Port envd from e2b with internalized shared packages and Connect RPC
- Copy envd source from e2b-dev/infra, internalize shared dependencies
into envd/internal/shared/ (keys, filesystem, id, smap, utils)
- Switch from gRPC to Connect RPC for all envd services
- Update module paths to git.omukk.dev/wrenn/{sandbox,sandbox/envd}
- Add proto specs (process, filesystem) with buf-based code generation
- Implement full envd: process exec, filesystem ops, port forwarding,
cgroup management, MMDS integration, and HTTP API
- Update main module dependencies (firecracker SDK, pgx, goose, etc.)
- Remove placeholder .gitkeep files replaced by real implementations
This commit is contained in:
@ -1,17 +1,62 @@
|
||||
LDFLAGS := -s -w
|
||||
BUILD := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
LDFLAGS := -s -w -X=main.commitSHA=$(BUILD)
|
||||
BUILDS := ../builds
|
||||
|
||||
.PHONY: build clean fmt vet
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Build
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: build build-debug
|
||||
|
||||
build:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o envd .
|
||||
@file envd | grep -q "statically linked" || \
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o $(BUILDS)/envd .
|
||||
@file $(BUILDS)/envd | grep -q "statically linked" || \
|
||||
(echo "ERROR: envd is not statically linked!" && exit 1)
|
||||
|
||||
clean:
|
||||
rm -f envd
|
||||
build-debug:
|
||||
CGO_ENABLED=1 go build -race -gcflags=all="-N -l" -ldflags="-X=main.commitSHA=$(BUILD)" -o $(BUILDS)/debug/envd .
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Run (debug mode, not inside a VM)
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: run-debug
|
||||
|
||||
run-debug: build-debug
|
||||
$(BUILDS)/debug/envd -isnotfc -port 49983
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Code Generation
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: generate proto openapi
|
||||
|
||||
generate: proto openapi
|
||||
|
||||
proto:
|
||||
cd spec && buf generate --template buf.gen.yaml
|
||||
|
||||
openapi:
|
||||
go generate ./internal/api/...
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Quality
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: fmt vet test tidy
|
||||
|
||||
fmt:
|
||||
gofmt -w .
|
||||
|
||||
vet:
|
||||
go vet ./...
|
||||
|
||||
test:
|
||||
go test -race -v ./...
|
||||
|
||||
tidy:
|
||||
go mod tidy
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Clean
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: clean
|
||||
|
||||
clean:
|
||||
rm -f $(BUILDS)/envd $(BUILDS)/debug/envd
|
||||
|
||||
43
envd/go.mod
43
envd/go.mod
@ -1,9 +1,42 @@
|
||||
module github.com/wrenn-dev/envd
|
||||
module git.omukk.dev/wrenn/sandbox/envd
|
||||
|
||||
go 1.23.0
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
github.com/mdlayher/vsock v1.2.1
|
||||
google.golang.org/grpc v1.71.0
|
||||
google.golang.org/protobuf v1.36.5
|
||||
connectrpc.com/authn v0.1.0
|
||||
connectrpc.com/connect v1.19.1
|
||||
connectrpc.com/cors v0.1.0
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/creack/pty v1.1.24
|
||||
github.com/dchest/uniuri v1.2.0
|
||||
github.com/e2b-dev/fsnotify v0.0.1
|
||||
github.com/go-chi/chi/v5 v5.2.5
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/oapi-codegen/runtime v1.2.0
|
||||
github.com/orcaman/concurrent-map/v2 v2.0.1
|
||||
github.com/rs/cors v1.11.1
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/shirou/gopsutil/v4 v4.26.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/txn2/txeh v1.8.0
|
||||
golang.org/x/sys v0.42.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/ebitengine/purego v0.10.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.16 // indirect
|
||||
github.com/tklauser/numcpus v0.11.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
golang.org/x/crypto v0.41.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
92
envd/go.sum
Normal file
92
envd/go.sum
Normal file
@ -0,0 +1,92 @@
|
||||
connectrpc.com/authn v0.1.0 h1:m5weACjLWwgwcjttvUDyTPICJKw74+p2obBVrf8hT9E=
|
||||
connectrpc.com/authn v0.1.0/go.mod h1:AwNZK/KYbqaJzRYadTuAaoz6sYQSPdORPqh1TOPIkgY=
|
||||
connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
|
||||
connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
|
||||
connectrpc.com/cors v0.1.0 h1:f3gTXJyDZPrDIZCQ567jxfD9PAIpopHiRDnJRt3QuOQ=
|
||||
connectrpc.com/cors v0.1.0/go.mod h1:v8SJZCPfHtGH1zsm+Ttajpozd4cYIUryl4dFB6QEpfg=
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
||||
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
|
||||
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
|
||||
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
|
||||
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
|
||||
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g=
|
||||
github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY=
|
||||
github.com/e2b-dev/fsnotify v0.0.1 h1:7j0I98HD6VehAuK/bcslvW4QDynAULtOuMZtImihjVk=
|
||||
github.com/e2b-dev/fsnotify v0.0.1/go.mod h1:jAuDjregRrUixKneTRQwPI847nNuPFg3+n5QM/ku/JM=
|
||||
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
|
||||
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/oapi-codegen/runtime v1.2.0 h1:RvKc1CVS1QeKSNzO97FBQbSMZyQ8s6rZd+LpmzwHMP4=
|
||||
github.com/oapi-codegen/runtime v1.2.0/go.mod h1:Y7ZhmmlE8ikZOmuHRRndiIm7nf3xcVv+YMweKgG1DT0=
|
||||
github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c=
|
||||
github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
|
||||
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/shirou/gopsutil/v4 v4.26.2 h1:X8i6sicvUFih4BmYIGT1m2wwgw2VG9YgrDTi7cIRGUI=
|
||||
github.com/shirou/gopsutil/v4 v4.26.2/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
|
||||
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
|
||||
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
|
||||
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
|
||||
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
|
||||
github.com/txn2/txeh v1.8.0 h1:G1vZgom6+P/xWwU53AMOpcZgC5ni382ukcPP1TDVYHk=
|
||||
github.com/txn2/txeh v1.8.0/go.mod h1:rRI3Egi3+AFmEXQjft051YdYbxeCT3nFmBLsNCZZaxM=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
|
||||
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
|
||||
568
envd/internal/api/api.gen.go
Normal file
568
envd/internal/api/api.gen.go
Normal file
@ -0,0 +1,568 @@
|
||||
// Package api provides primitives to interact with the openapi HTTP API.
|
||||
//
|
||||
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT.
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/oapi-codegen/runtime"
|
||||
openapi_types "github.com/oapi-codegen/runtime/types"
|
||||
)
|
||||
|
||||
const (
|
||||
AccessTokenAuthScopes = "AccessTokenAuth.Scopes"
|
||||
)
|
||||
|
||||
// Defines values for EntryInfoType.
|
||||
const (
|
||||
File EntryInfoType = "file"
|
||||
)
|
||||
|
||||
// EntryInfo defines model for EntryInfo.
|
||||
type EntryInfo struct {
|
||||
// Name Name of the file
|
||||
Name string `json:"name"`
|
||||
|
||||
// Path Path to the file
|
||||
Path string `json:"path"`
|
||||
|
||||
// Type Type of the file
|
||||
Type EntryInfoType `json:"type"`
|
||||
}
|
||||
|
||||
// EntryInfoType Type of the file
|
||||
type EntryInfoType string
|
||||
|
||||
// EnvVars Environment variables to set
|
||||
type EnvVars map[string]string
|
||||
|
||||
// Error defines model for Error.
|
||||
type Error struct {
|
||||
// Code Error code
|
||||
Code int `json:"code"`
|
||||
|
||||
// Message Error message
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Metrics Resource usage metrics
|
||||
type Metrics struct {
|
||||
// CpuCount Number of CPU cores
|
||||
CpuCount *int `json:"cpu_count,omitempty"`
|
||||
|
||||
// CpuUsedPct CPU usage percentage
|
||||
CpuUsedPct *float32 `json:"cpu_used_pct,omitempty"`
|
||||
|
||||
// DiskTotal Total disk space in bytes
|
||||
DiskTotal *int `json:"disk_total,omitempty"`
|
||||
|
||||
// DiskUsed Used disk space in bytes
|
||||
DiskUsed *int `json:"disk_used,omitempty"`
|
||||
|
||||
// MemTotal Total virtual memory in bytes
|
||||
MemTotal *int `json:"mem_total,omitempty"`
|
||||
|
||||
// MemUsed Used virtual memory in bytes
|
||||
MemUsed *int `json:"mem_used,omitempty"`
|
||||
|
||||
// Ts Unix timestamp in UTC for current sandbox time
|
||||
Ts *int64 `json:"ts,omitempty"`
|
||||
}
|
||||
|
||||
// VolumeMount Volume
|
||||
type VolumeMount struct {
|
||||
NfsTarget string `json:"nfs_target"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// FilePath defines model for FilePath.
|
||||
type FilePath = string
|
||||
|
||||
// Signature defines model for Signature.
|
||||
type Signature = string
|
||||
|
||||
// SignatureExpiration defines model for SignatureExpiration.
|
||||
type SignatureExpiration = int
|
||||
|
||||
// User defines model for User.
|
||||
type User = string
|
||||
|
||||
// FileNotFound defines model for FileNotFound.
|
||||
type FileNotFound = Error
|
||||
|
||||
// InternalServerError defines model for InternalServerError.
|
||||
type InternalServerError = Error
|
||||
|
||||
// InvalidPath defines model for InvalidPath.
|
||||
type InvalidPath = Error
|
||||
|
||||
// InvalidUser defines model for InvalidUser.
|
||||
type InvalidUser = Error
|
||||
|
||||
// NotEnoughDiskSpace defines model for NotEnoughDiskSpace.
|
||||
type NotEnoughDiskSpace = Error
|
||||
|
||||
// UploadSuccess defines model for UploadSuccess.
|
||||
type UploadSuccess = []EntryInfo
|
||||
|
||||
// GetFilesParams defines parameters for GetFiles.
|
||||
type GetFilesParams struct {
|
||||
// Path Path to the file, URL encoded. Can be relative to user's home directory.
|
||||
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
|
||||
|
||||
// Username User used for setting the owner, or resolving relative paths.
|
||||
Username *User `form:"username,omitempty" json:"username,omitempty"`
|
||||
|
||||
// Signature Signature used for file access permission verification.
|
||||
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
|
||||
|
||||
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
|
||||
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
|
||||
}
|
||||
|
||||
// PostFilesMultipartBody defines parameters for PostFiles.
|
||||
type PostFilesMultipartBody struct {
|
||||
File *openapi_types.File `json:"file,omitempty"`
|
||||
}
|
||||
|
||||
// PostFilesParams defines parameters for PostFiles.
|
||||
type PostFilesParams struct {
|
||||
// Path Path to the file, URL encoded. Can be relative to user's home directory.
|
||||
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
|
||||
|
||||
// Username User used for setting the owner, or resolving relative paths.
|
||||
Username *User `form:"username,omitempty" json:"username,omitempty"`
|
||||
|
||||
// Signature Signature used for file access permission verification.
|
||||
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
|
||||
|
||||
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
|
||||
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
|
||||
}
|
||||
|
||||
// PostInitJSONBody defines parameters for PostInit.
|
||||
type PostInitJSONBody struct {
|
||||
// AccessToken Access token for secure access to envd service
|
||||
AccessToken *SecureToken `json:"accessToken,omitempty"`
|
||||
|
||||
// DefaultUser The default user to use for operations
|
||||
DefaultUser *string `json:"defaultUser,omitempty"`
|
||||
|
||||
// DefaultWorkdir The default working directory to use for operations
|
||||
DefaultWorkdir *string `json:"defaultWorkdir,omitempty"`
|
||||
|
||||
// EnvVars Environment variables to set
|
||||
EnvVars *EnvVars `json:"envVars,omitempty"`
|
||||
|
||||
// HyperloopIP IP address of the hyperloop server to connect to
|
||||
HyperloopIP *string `json:"hyperloopIP,omitempty"`
|
||||
|
||||
// Timestamp The current timestamp in RFC3339 format
|
||||
Timestamp *time.Time `json:"timestamp,omitempty"`
|
||||
VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"`
|
||||
}
|
||||
|
||||
// PostFilesMultipartRequestBody defines body for PostFiles for multipart/form-data ContentType.
|
||||
type PostFilesMultipartRequestBody PostFilesMultipartBody
|
||||
|
||||
// PostInitJSONRequestBody defines body for PostInit for application/json ContentType.
|
||||
type PostInitJSONRequestBody PostInitJSONBody
|
||||
|
||||
// ServerInterface represents all server handlers.
|
||||
type ServerInterface interface {
|
||||
// Get the environment variables
|
||||
// (GET /envs)
|
||||
GetEnvs(w http.ResponseWriter, r *http.Request)
|
||||
// Download a file
|
||||
// (GET /files)
|
||||
GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams)
|
||||
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
|
||||
// (POST /files)
|
||||
PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams)
|
||||
// Check the health of the service
|
||||
// (GET /health)
|
||||
GetHealth(w http.ResponseWriter, r *http.Request)
|
||||
// Set initial vars, ensure the time and metadata is synced with the host
|
||||
// (POST /init)
|
||||
PostInit(w http.ResponseWriter, r *http.Request)
|
||||
// Get the stats of the service
|
||||
// (GET /metrics)
|
||||
GetMetrics(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
|
||||
|
||||
type Unimplemented struct{}
|
||||
|
||||
// Get the environment variables
|
||||
// (GET /envs)
|
||||
func (_ Unimplemented) GetEnvs(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Download a file
|
||||
// (GET /files)
|
||||
func (_ Unimplemented) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
|
||||
// (POST /files)
|
||||
func (_ Unimplemented) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Check the health of the service
|
||||
// (GET /health)
|
||||
func (_ Unimplemented) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Set initial vars, ensure the time and metadata is synced with the host
|
||||
// (POST /init)
|
||||
func (_ Unimplemented) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Get the stats of the service
|
||||
// (GET /metrics)
|
||||
func (_ Unimplemented) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// ServerInterfaceWrapper converts contexts to parameters.
|
||||
type ServerInterfaceWrapper struct {
|
||||
Handler ServerInterface
|
||||
HandlerMiddlewares []MiddlewareFunc
|
||||
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
|
||||
}
|
||||
|
||||
type MiddlewareFunc func(http.Handler) http.Handler
|
||||
|
||||
// GetEnvs operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetEnvs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetEnvs(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetFiles operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var err error
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Parameter object where we will unmarshal all parameters from the context
|
||||
var params GetFilesParams
|
||||
|
||||
// ------------- Optional query parameter "path" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), ¶ms.Path)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "username" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), ¶ms.Username)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), ¶ms.Signature)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature_expiration" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetFiles(w, r, params)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// PostFiles operation middleware
|
||||
func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var err error
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Parameter object where we will unmarshal all parameters from the context
|
||||
var params PostFilesParams
|
||||
|
||||
// ------------- Optional query parameter "path" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), ¶ms.Path)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "username" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), ¶ms.Username)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), ¶ms.Signature)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature_expiration" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.PostFiles(w, r, params)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetHealth operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetHealth(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// PostInit operation middleware
|
||||
func (siw *ServerInterfaceWrapper) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.PostInit(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetMetrics operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetMetrics(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
type UnescapedCookieParamError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *UnescapedCookieParamError) Error() string {
|
||||
return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
|
||||
}
|
||||
|
||||
func (e *UnescapedCookieParamError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type UnmarshalingParamError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *UnmarshalingParamError) Error() string {
|
||||
return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
|
||||
}
|
||||
|
||||
func (e *UnmarshalingParamError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type RequiredParamError struct {
|
||||
ParamName string
|
||||
}
|
||||
|
||||
func (e *RequiredParamError) Error() string {
|
||||
return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
|
||||
}
|
||||
|
||||
type RequiredHeaderError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *RequiredHeaderError) Error() string {
|
||||
return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
|
||||
}
|
||||
|
||||
func (e *RequiredHeaderError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type InvalidParamFormatError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *InvalidParamFormatError) Error() string {
|
||||
return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
|
||||
}
|
||||
|
||||
func (e *InvalidParamFormatError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type TooManyValuesForParamError struct {
|
||||
ParamName string
|
||||
Count int
|
||||
}
|
||||
|
||||
func (e *TooManyValuesForParamError) Error() string {
|
||||
return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
|
||||
}
|
||||
|
||||
// Handler creates http.Handler with routing matching OpenAPI spec.
|
||||
func Handler(si ServerInterface) http.Handler {
|
||||
return HandlerWithOptions(si, ChiServerOptions{})
|
||||
}
|
||||
|
||||
type ChiServerOptions struct {
|
||||
BaseURL string
|
||||
BaseRouter chi.Router
|
||||
Middlewares []MiddlewareFunc
|
||||
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
|
||||
}
|
||||
|
||||
// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
|
||||
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
|
||||
return HandlerWithOptions(si, ChiServerOptions{
|
||||
BaseRouter: r,
|
||||
})
|
||||
}
|
||||
|
||||
func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
|
||||
return HandlerWithOptions(si, ChiServerOptions{
|
||||
BaseURL: baseURL,
|
||||
BaseRouter: r,
|
||||
})
|
||||
}
|
||||
|
||||
// HandlerWithOptions creates http.Handler with additional options
|
||||
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
|
||||
r := options.BaseRouter
|
||||
|
||||
if r == nil {
|
||||
r = chi.NewRouter()
|
||||
}
|
||||
if options.ErrorHandlerFunc == nil {
|
||||
options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
wrapper := ServerInterfaceWrapper{
|
||||
Handler: si,
|
||||
HandlerMiddlewares: options.Middlewares,
|
||||
ErrorHandlerFunc: options.ErrorHandlerFunc,
|
||||
}
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/envs", wrapper.GetEnvs)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/files", wrapper.GetFiles)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Post(options.BaseURL+"/files", wrapper.PostFiles)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/health", wrapper.GetHealth)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Post(options.BaseURL+"/init", wrapper.PostInit)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
129
envd/internal/api/auth.go
Normal file
129
envd/internal/api/auth.go
Normal file
@ -0,0 +1,129 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
const (
|
||||
SigningReadOperation = "read"
|
||||
SigningWriteOperation = "write"
|
||||
|
||||
accessTokenHeader = "X-Access-Token"
|
||||
)
|
||||
|
||||
// paths that are always allowed without general authentication
|
||||
// POST/init is secured via MMDS hash validation instead
|
||||
var authExcludedPaths = []string{
|
||||
"GET/health",
|
||||
"GET/files",
|
||||
"POST/files",
|
||||
"POST/init",
|
||||
}
|
||||
|
||||
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if a.accessToken.IsSet() {
|
||||
authHeader := req.Header.Get(accessTokenHeader)
|
||||
|
||||
// check if this path is allowed without authentication (e.g., health check, endpoints supporting signing)
|
||||
allowedPath := slices.Contains(authExcludedPaths, req.Method+req.URL.Path)
|
||||
|
||||
if !a.accessToken.Equals(authHeader) && !allowedPath {
|
||||
a.logger.Error().Msg("Trying to access secured envd without correct access token")
|
||||
|
||||
err := fmt.Errorf("unauthorized access, please provide a valid access token or method signing if supported")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *API) generateSignature(path string, username string, operation string, signatureExpiration *int64) (string, error) {
|
||||
tokenBytes, err := a.accessToken.Bytes()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("access token is not set: %w", err)
|
||||
}
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
var signature string
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
|
||||
if signatureExpiration == nil {
|
||||
signature = strings.Join([]string{path, operation, username, string(tokenBytes)}, ":")
|
||||
} else {
|
||||
signature = strings.Join([]string{path, operation, username, string(tokenBytes), strconv.FormatInt(*signatureExpiration, 10)}, ":")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(signature))), nil
|
||||
}
|
||||
|
||||
func (a *API) validateSigning(r *http.Request, signature *string, signatureExpiration *int, username *string, path string, operation string) (err error) {
|
||||
var expectedSignature string
|
||||
|
||||
// no need to validate signing key if access token is not set
|
||||
if !a.accessToken.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// check if access token is sent in the header
|
||||
tokenFromHeader := r.Header.Get(accessTokenHeader)
|
||||
if tokenFromHeader != "" {
|
||||
if !a.accessToken.Equals(tokenFromHeader) {
|
||||
return fmt.Errorf("access token present in header but does not match")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if signature == nil {
|
||||
return fmt.Errorf("missing signature query parameter")
|
||||
}
|
||||
|
||||
// Empty string is used when no username is provided and the default user should be used
|
||||
signatureUsername := ""
|
||||
if username != nil {
|
||||
signatureUsername = *username
|
||||
}
|
||||
|
||||
if signatureExpiration == nil {
|
||||
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, nil)
|
||||
} else {
|
||||
exp := int64(*signatureExpiration)
|
||||
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, &exp)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("error generating signing key")
|
||||
|
||||
return errors.New("invalid signature")
|
||||
}
|
||||
|
||||
// signature validation
|
||||
if expectedSignature != *signature {
|
||||
return fmt.Errorf("invalid signature")
|
||||
}
|
||||
|
||||
// signature expiration
|
||||
if signatureExpiration != nil {
|
||||
exp := int64(*signatureExpiration)
|
||||
if exp < time.Now().Unix() {
|
||||
return fmt.Errorf("signature is already expired")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
62
envd/internal/api/auth_test.go
Normal file
62
envd/internal/api/auth_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
func TestKeyGenerationAlgorithmIsStable(t *testing.T) {
|
||||
t.Parallel()
|
||||
apiToken := "secret-access-token"
|
||||
secureToken := &SecureToken{}
|
||||
err := secureToken.Set([]byte(apiToken))
|
||||
require.NoError(t, err)
|
||||
api := &API{accessToken: secureToken}
|
||||
|
||||
path := "/path/to/demo.txt"
|
||||
username := "root"
|
||||
operation := "write"
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
signature, err := api.generateSignature(path, username, operation, ×tamp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// locally generated signature
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s:%s", path, operation, username, apiToken, strconv.FormatInt(timestamp, 10))
|
||||
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
|
||||
|
||||
assert.Equal(t, localSignature, signature)
|
||||
}
|
||||
|
||||
func TestKeyGenerationAlgorithmWithoutExpirationIsStable(t *testing.T) {
|
||||
t.Parallel()
|
||||
apiToken := "secret-access-token"
|
||||
secureToken := &SecureToken{}
|
||||
err := secureToken.Set([]byte(apiToken))
|
||||
require.NoError(t, err)
|
||||
api := &API{accessToken: secureToken}
|
||||
|
||||
path := "/path/to/resource.txt"
|
||||
username := "user"
|
||||
operation := "read"
|
||||
|
||||
signature, err := api.generateSignature(path, username, operation, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// locally generated signature
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s", path, operation, username, apiToken)
|
||||
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
|
||||
|
||||
assert.Equal(t, localSignature, signature)
|
||||
}
|
||||
8
envd/internal/api/cfg.yaml
Normal file
8
envd/internal/api/cfg.yaml
Normal file
@ -0,0 +1,8 @@
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/deepmap/oapi-codegen/HEAD/configuration-schema.json
|
||||
|
||||
package: api
|
||||
output: api.gen.go
|
||||
generate:
|
||||
models: true
|
||||
chi-server: true
|
||||
client: false
|
||||
173
envd/internal/api/download.go
Normal file
173
envd/internal/api/download.go
Normal file
@ -0,0 +1,173 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
)
|
||||
|
||||
func (a *API) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
|
||||
defer r.Body.Close()
|
||||
|
||||
var errorCode int
|
||||
var errMsg error
|
||||
|
||||
var path string
|
||||
if params.Path != nil {
|
||||
path = *params.Path
|
||||
}
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
// signing authorization if needed
|
||||
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningReadOperation)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
|
||||
jsonError(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
l := a.logger.
|
||||
Err(errMsg).
|
||||
Str("method", r.Method+" "+r.URL.Path).
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("path", path).
|
||||
Str("username", username)
|
||||
|
||||
if errMsg != nil {
|
||||
l = l.Int("error_code", errorCode)
|
||||
}
|
||||
|
||||
l.Msg("File read")
|
||||
}()
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||
errorCode = http.StatusUnauthorized
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resolvedPath, err := permissions.ExpandAndResolve(path, u, a.defaults.Workdir)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error expanding and resolving path '%s': %w", path, err)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
stat, err := os.Stat(resolvedPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
errMsg = fmt.Errorf("path '%s' does not exist", resolvedPath)
|
||||
errorCode = http.StatusNotFound
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
errMsg = fmt.Errorf("error checking if path exists '%s': %w", resolvedPath, err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
errMsg = fmt.Errorf("path '%s' is a directory", resolvedPath)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Validate Accept-Encoding header
|
||||
encoding, err := parseAcceptEncoding(r)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error parsing Accept-Encoding: %w", err)
|
||||
errorCode = http.StatusNotAcceptable
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Tell caches to store separate variants for different Accept-Encoding values
|
||||
w.Header().Set("Vary", "Accept-Encoding")
|
||||
|
||||
// Fall back to identity for Range or conditional requests to preserve http.ServeContent
|
||||
// behavior (206 Partial Content, 304 Not Modified). However, we must check if identity
|
||||
// is acceptable per the Accept-Encoding header.
|
||||
hasRangeOrConditional := r.Header.Get("Range") != "" ||
|
||||
r.Header.Get("If-Modified-Since") != "" ||
|
||||
r.Header.Get("If-None-Match") != "" ||
|
||||
r.Header.Get("If-Range") != ""
|
||||
if hasRangeOrConditional {
|
||||
if !isIdentityAcceptable(r) {
|
||||
errMsg = fmt.Errorf("identity encoding not acceptable for Range or conditional request")
|
||||
errorCode = http.StatusNotAcceptable
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
encoding = EncodingIdentity
|
||||
}
|
||||
|
||||
file, err := os.Open(resolvedPath)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error opening file '%s': %w", resolvedPath, err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": filepath.Base(resolvedPath)}))
|
||||
|
||||
// Serve with gzip encoding if requested.
|
||||
if encoding == EncodingGzip {
|
||||
w.Header().Set("Content-Encoding", EncodingGzip)
|
||||
|
||||
// Set Content-Type based on file extension, preserving the original type
|
||||
contentType := mime.TypeByExtension(filepath.Ext(path))
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
|
||||
gw := gzip.NewWriter(w)
|
||||
defer gw.Close()
|
||||
|
||||
_, err = io.Copy(gw, file)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error writing gzip response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, path, stat.ModTime(), file)
|
||||
}
|
||||
401
envd/internal/api/download_test.go
Normal file
401
envd/internal/api/download_test.go
Normal file
@ -0,0 +1,401 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
func TestGetFilesContentDisposition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
name: "simple filename",
|
||||
filename: "test.txt",
|
||||
expectedHeader: `inline; filename=test.txt`,
|
||||
},
|
||||
{
|
||||
name: "filename with extension",
|
||||
filename: "presentation.pptx",
|
||||
expectedHeader: `inline; filename=presentation.pptx`,
|
||||
},
|
||||
{
|
||||
name: "filename with multiple dots",
|
||||
filename: "archive.tar.gz",
|
||||
expectedHeader: `inline; filename=archive.tar.gz`,
|
||||
},
|
||||
{
|
||||
name: "filename with spaces",
|
||||
filename: "my document.pdf",
|
||||
expectedHeader: `inline; filename="my document.pdf"`,
|
||||
},
|
||||
{
|
||||
name: "filename with quotes",
|
||||
filename: `file"name.txt`,
|
||||
expectedHeader: `inline; filename="file\"name.txt"`,
|
||||
},
|
||||
{
|
||||
name: "filename with backslash",
|
||||
filename: `file\name.txt`,
|
||||
expectedHeader: `inline; filename="file\\name.txt"`,
|
||||
},
|
||||
{
|
||||
name: "unicode filename",
|
||||
filename: "\u6587\u6863.pdf", // 文档.pdf in Chinese
|
||||
expectedHeader: "inline; filename*=utf-8''%E6%96%87%E6%A1%A3.pdf",
|
||||
},
|
||||
{
|
||||
name: "dotfile preserved",
|
||||
filename: ".env",
|
||||
expectedHeader: `inline; filename=.env`,
|
||||
},
|
||||
{
|
||||
name: "dotfile with extension preserved",
|
||||
filename: ".gitignore",
|
||||
expectedHeader: `inline; filename=.gitignore`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a temp directory and file
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, tt.filename)
|
||||
err := os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify Content-Disposition header
|
||||
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||
assert.Equal(t, tt.expectedHeader, contentDisposition, "Content-Disposition header should be set with correct filename")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a temp directory with nested structure
|
||||
tempDir := t.TempDir()
|
||||
nestedDir := filepath.Join(tempDir, "subdir", "another")
|
||||
err = os.MkdirAll(nestedDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
filename := "document.pdf"
|
||||
tempFile := filepath.Join(nestedDir, filename)
|
||||
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify Content-Disposition header uses only the base filename, not the full path
|
||||
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||
assert.Equal(t, `inline; filename=document.pdf`, contentDisposition, "Content-Disposition should contain only the filename, not the path")
|
||||
}
|
||||
|
||||
func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a temp directory with a test file
|
||||
tempDir := t.TempDir()
|
||||
filename := "document.pdf"
|
||||
tempFile := filepath.Join(tempDir, filename)
|
||||
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip; q=1,*; q=0")
|
||||
req.Header.Set("Range", "bytes=0-4") // Request first 5 bytes
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestGetFiles_GzipDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalContent := []byte("hello world, this is a test file for gzip compression")
|
||||
|
||||
// Create a temp file with known content
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test.txt")
|
||||
err = os.WriteFile(tempFile, originalContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
|
||||
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
|
||||
|
||||
// Decompress the gzip response body
|
||||
gzReader, err := gzip.NewReader(resp.Body)
|
||||
require.NoError(t, err)
|
||||
defer gzReader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(gzReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalContent, decompressed)
|
||||
}
|
||||
|
||||
func TestPostFiles_GzipUpload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalContent := []byte("hello world, this is a test file uploaded with gzip")
|
||||
|
||||
// Build a multipart body
|
||||
var multipartBuf bytes.Buffer
|
||||
mpWriter := multipart.NewWriter(&multipartBuf)
|
||||
part, err := mpWriter.CreateFormFile("file", "uploaded.txt")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = mpWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Gzip-compress the entire multipart body
|
||||
var gzBuf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&gzBuf)
|
||||
_, err = gzWriter.Write(multipartBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
err = gzWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
tempDir := t.TempDir()
|
||||
destPath := filepath.Join(tempDir, "uploaded.txt")
|
||||
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||
req.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
params := PostFilesParams{
|
||||
Path: &destPath,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.PostFiles(w, req, params)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify the file was written with the original (decompressed) content
|
||||
data, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, data)
|
||||
}
|
||||
|
||||
func TestGzipUploadThenGzipDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalContent := []byte("round-trip gzip test: upload compressed, download compressed, verify match")
|
||||
|
||||
// --- Upload with gzip ---
|
||||
|
||||
// Build a multipart body
|
||||
var multipartBuf bytes.Buffer
|
||||
mpWriter := multipart.NewWriter(&multipartBuf)
|
||||
part, err := mpWriter.CreateFormFile("file", "roundtrip.txt")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = mpWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Gzip-compress the entire multipart body
|
||||
var gzBuf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&gzBuf)
|
||||
_, err = gzWriter.Write(multipartBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
err = gzWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
destPath := filepath.Join(tempDir, "roundtrip.txt")
|
||||
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||
uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||
uploadReq.Header.Set("Content-Encoding", "gzip")
|
||||
uploadW := httptest.NewRecorder()
|
||||
|
||||
uploadParams := PostFilesParams{
|
||||
Path: &destPath,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.PostFiles(uploadW, uploadReq, uploadParams)
|
||||
|
||||
uploadResp := uploadW.Result()
|
||||
defer uploadResp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusOK, uploadResp.StatusCode)
|
||||
|
||||
// --- Download with gzip ---
|
||||
|
||||
downloadReq := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(destPath), nil)
|
||||
downloadReq.Header.Set("Accept-Encoding", "gzip")
|
||||
downloadW := httptest.NewRecorder()
|
||||
|
||||
downloadParams := GetFilesParams{
|
||||
Path: &destPath,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(downloadW, downloadReq, downloadParams)
|
||||
|
||||
downloadResp := downloadW.Result()
|
||||
defer downloadResp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusOK, downloadResp.StatusCode)
|
||||
assert.Equal(t, "gzip", downloadResp.Header.Get("Content-Encoding"))
|
||||
|
||||
// Decompress and verify content matches original
|
||||
gzReader, err := gzip.NewReader(downloadResp.Body)
|
||||
require.NoError(t, err)
|
||||
defer gzReader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(gzReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalContent, decompressed)
|
||||
}
|
||||
227
envd/internal/api/encoding.go
Normal file
227
envd/internal/api/encoding.go
Normal file
@ -0,0 +1,227 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// EncodingGzip is the gzip content encoding.
|
||||
EncodingGzip = "gzip"
|
||||
// EncodingIdentity means no encoding (passthrough).
|
||||
EncodingIdentity = "identity"
|
||||
// EncodingWildcard means any encoding is acceptable.
|
||||
EncodingWildcard = "*"
|
||||
)
|
||||
|
||||
// SupportedEncodings lists the content encodings supported for file transfer.
|
||||
// The order matters - encodings are checked in order of preference.
|
||||
var SupportedEncodings = []string{
|
||||
EncodingGzip,
|
||||
}
|
||||
|
||||
// encodingWithQuality holds an encoding name and its quality value.
|
||||
type encodingWithQuality struct {
|
||||
encoding string
|
||||
quality float64
|
||||
}
|
||||
|
||||
// isSupportedEncoding checks if the given encoding is in the supported list.
|
||||
// Per RFC 7231, content-coding values are case-insensitive.
|
||||
func isSupportedEncoding(encoding string) bool {
|
||||
return slices.Contains(SupportedEncodings, strings.ToLower(encoding))
|
||||
}
|
||||
|
||||
// parseEncodingWithQuality parses an encoding value and extracts the quality.
|
||||
// Returns the encoding name (lowercased) and quality value (default 1.0 if not specified).
|
||||
// Per RFC 7231, content-coding values are case-insensitive.
|
||||
func parseEncodingWithQuality(value string) encodingWithQuality {
|
||||
value = strings.TrimSpace(value)
|
||||
quality := 1.0
|
||||
|
||||
if idx := strings.Index(value, ";"); idx != -1 {
|
||||
params := value[idx+1:]
|
||||
value = strings.TrimSpace(value[:idx])
|
||||
|
||||
// Parse q=X.X parameter
|
||||
for param := range strings.SplitSeq(params, ";") {
|
||||
param = strings.TrimSpace(param)
|
||||
if strings.HasPrefix(strings.ToLower(param), "q=") {
|
||||
if q, err := strconv.ParseFloat(param[2:], 64); err == nil {
|
||||
quality = q
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize encoding to lowercase per RFC 7231
|
||||
return encodingWithQuality{encoding: strings.ToLower(value), quality: quality}
|
||||
}
|
||||
|
||||
// parseEncoding extracts the encoding name from a header value, stripping quality.
|
||||
func parseEncoding(value string) string {
|
||||
return parseEncodingWithQuality(value).encoding
|
||||
}
|
||||
|
||||
// parseContentEncoding parses the Content-Encoding header and returns the encoding.
|
||||
// Returns an error if an unsupported encoding is specified.
|
||||
// If no Content-Encoding header is present, returns empty string.
|
||||
func parseContentEncoding(r *http.Request) (string, error) {
|
||||
header := r.Header.Get("Content-Encoding")
|
||||
if header == "" {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
encoding := parseEncoding(header)
|
||||
|
||||
if encoding == EncodingIdentity {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
if !isSupportedEncoding(encoding) {
|
||||
return "", fmt.Errorf("unsupported Content-Encoding: %s, supported: %v", header, SupportedEncodings)
|
||||
}
|
||||
|
||||
return encoding, nil
|
||||
}
|
||||
|
||||
// parseAcceptEncodingHeader parses the Accept-Encoding header and returns
|
||||
// the parsed encodings along with the identity rejection state.
|
||||
// Per RFC 7231 Section 5.3.4, identity is acceptable unless excluded by
|
||||
// "identity;q=0" or "*;q=0" without a more specific entry for identity with q>0.
|
||||
func parseAcceptEncodingHeader(header string) ([]encodingWithQuality, bool) {
|
||||
if header == "" {
|
||||
return nil, false // identity not rejected when header is empty
|
||||
}
|
||||
|
||||
// Parse all encodings with their quality values
|
||||
var encodings []encodingWithQuality
|
||||
for value := range strings.SplitSeq(header, ",") {
|
||||
eq := parseEncodingWithQuality(value)
|
||||
encodings = append(encodings, eq)
|
||||
}
|
||||
|
||||
// Check if identity is rejected per RFC 7231 Section 5.3.4:
|
||||
// identity is acceptable unless excluded by "identity;q=0" or "*;q=0"
|
||||
// without a more specific entry for identity with q>0.
|
||||
identityRejected := false
|
||||
identityExplicitlyAccepted := false
|
||||
wildcardRejected := false
|
||||
|
||||
for _, eq := range encodings {
|
||||
switch eq.encoding {
|
||||
case EncodingIdentity:
|
||||
if eq.quality == 0 {
|
||||
identityRejected = true
|
||||
} else {
|
||||
identityExplicitlyAccepted = true
|
||||
}
|
||||
case EncodingWildcard:
|
||||
if eq.quality == 0 {
|
||||
wildcardRejected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if wildcardRejected && !identityExplicitlyAccepted {
|
||||
identityRejected = true
|
||||
}
|
||||
|
||||
return encodings, identityRejected
|
||||
}
|
||||
|
||||
// isIdentityAcceptable checks if identity encoding is acceptable based on the
|
||||
// Accept-Encoding header. Per RFC 7231 section 5.3.4, identity is always
|
||||
// implicitly acceptable unless explicitly rejected with q=0.
|
||||
func isIdentityAcceptable(r *http.Request) bool {
|
||||
header := r.Header.Get("Accept-Encoding")
|
||||
_, identityRejected := parseAcceptEncodingHeader(header)
|
||||
|
||||
return !identityRejected
|
||||
}
|
||||
|
||||
// parseAcceptEncoding parses the Accept-Encoding header and returns the best
|
||||
// supported encoding based on quality values. Per RFC 7231 section 5.3.4,
|
||||
// identity is always implicitly acceptable unless explicitly rejected with q=0.
|
||||
// If no Accept-Encoding header is present, returns empty string (identity).
|
||||
func parseAcceptEncoding(r *http.Request) (string, error) {
|
||||
header := r.Header.Get("Accept-Encoding")
|
||||
if header == "" {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
encodings, identityRejected := parseAcceptEncodingHeader(header)
|
||||
|
||||
// Sort by quality value (highest first)
|
||||
sort.Slice(encodings, func(i, j int) bool {
|
||||
return encodings[i].quality > encodings[j].quality
|
||||
})
|
||||
|
||||
// Find the best supported encoding
|
||||
for _, eq := range encodings {
|
||||
// Skip encodings with q=0 (explicitly rejected)
|
||||
if eq.quality == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if eq.encoding == EncodingIdentity {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
// Wildcard means any encoding is acceptable - return a supported encoding if identity is rejected
|
||||
if eq.encoding == EncodingWildcard {
|
||||
if identityRejected && len(SupportedEncodings) > 0 {
|
||||
return SupportedEncodings[0], nil
|
||||
}
|
||||
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
if isSupportedEncoding(eq.encoding) {
|
||||
return eq.encoding, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Per RFC 7231, identity is implicitly acceptable unless rejected
|
||||
if !identityRejected {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
// Identity rejected and no supported encodings found
|
||||
return "", fmt.Errorf("no acceptable encoding found, supported: %v", SupportedEncodings)
|
||||
}
|
||||
|
||||
// getDecompressedBody returns a reader that decompresses the request body based on
|
||||
// Content-Encoding header. Returns the original body if no encoding is specified.
|
||||
// Returns an error if an unsupported encoding is specified.
|
||||
// The caller is responsible for closing both the returned ReadCloser and the
|
||||
// original request body (r.Body) separately.
|
||||
func getDecompressedBody(r *http.Request) (io.ReadCloser, error) {
|
||||
encoding, err := parseContentEncoding(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if encoding == EncodingIdentity {
|
||||
return r.Body, nil
|
||||
}
|
||||
|
||||
switch encoding {
|
||||
case EncodingGzip:
|
||||
gzReader, err := gzip.NewReader(r.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
|
||||
return gzReader, nil
|
||||
default:
|
||||
// This shouldn't happen if isSupportedEncoding is correct
|
||||
return nil, fmt.Errorf("encoding %s is supported but not implemented", encoding)
|
||||
}
|
||||
}
|
||||
494
envd/internal/api/encoding_test.go
Normal file
494
envd/internal/api/encoding_test.go
Normal file
@ -0,0 +1,494 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsSupportedEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("gzip is supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("gzip"))
|
||||
})
|
||||
|
||||
t.Run("GZIP is supported (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("GZIP"))
|
||||
})
|
||||
|
||||
t.Run("Gzip is supported (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("Gzip"))
|
||||
})
|
||||
|
||||
t.Run("br is not supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.False(t, isSupportedEncoding("br"))
|
||||
})
|
||||
|
||||
t.Run("deflate is not supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.False(t, isSupportedEncoding("deflate"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseEncodingWithQuality(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns encoding with default quality 1.0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("parses quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=0.5")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.5, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("parses quality value with whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip ; q=0.8")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.8, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("handles q=0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=0")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("handles invalid quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=invalid")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001) // defaults to 1.0 on parse error
|
||||
})
|
||||
|
||||
t.Run("trims whitespace from encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality(" gzip ")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("normalizes encoding to lowercase", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("GZIP")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
})
|
||||
|
||||
t.Run("normalizes mixed case encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("Gzip;q=0.5")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.5, eq.quality, 0.001)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns encoding as-is", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip"))
|
||||
})
|
||||
|
||||
t.Run("trims whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding(" gzip "))
|
||||
})
|
||||
|
||||
t.Run("strips quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip;q=1.0"))
|
||||
})
|
||||
|
||||
t.Run("strips quality value with whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip ; q=0.5"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseContentEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns identity when no header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "GZIP")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is Gzip (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "Gzip")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "identity")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns error for unsupported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "br")
|
||||
|
||||
_, err := parseContentEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
|
||||
assert.Contains(t, err.Error(), "supported: [gzip]")
|
||||
})
|
||||
|
||||
t.Run("handles gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "gzip;q=1.0")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseAcceptEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns identity when no header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Accept-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Accept-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "GZIP")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when gzip is among multiple encodings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate, gzip, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=1.0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for wildcard encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity for unsupported encoding only", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity when only unsupported encodings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("selects gzip when it has highest quality", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br;q=0.5, gzip;q=1.0, deflate;q=0.8")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("selects gzip even with lower quality when others unsupported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br;q=1.0, gzip;q=0.5")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity when it has higher quality than gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0.5, identity;q=1.0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("skips encoding with q=0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0, identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity when gzip rejected and no other supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns error when identity explicitly rejected and no supported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, identity;q=0")
|
||||
|
||||
_, err := parseAcceptEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no acceptable encoding found")
|
||||
})
|
||||
|
||||
t.Run("returns gzip for wildcard when identity rejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*, identity;q=0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding) // wildcard with identity rejected returns supported encoding
|
||||
})
|
||||
|
||||
t.Run("returns error when wildcard rejected and no explicit identity", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0")
|
||||
|
||||
_, err := parseAcceptEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no acceptable encoding found")
|
||||
})
|
||||
|
||||
t.Run("returns identity when wildcard rejected but identity explicitly accepted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0, identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when wildcard rejected but gzip explicitly accepted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0, gzip")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingGzip, encoding)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetDecompressedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns original body when no Content-Encoding header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
content := []byte("test content")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, req.Body, body, "should return original body")
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
})
|
||||
|
||||
t.Run("decompresses gzip body when Content-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
originalContent := []byte("test content to compress")
|
||||
|
||||
var compressed bytes.Buffer
|
||||
gw := gzip.NewWriter(&compressed)
|
||||
_, err := gw.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = gw.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
defer body.Close()
|
||||
|
||||
assert.NotEqual(t, req.Body, body, "should return a new gzip reader")
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, data)
|
||||
})
|
||||
|
||||
t.Run("returns error for invalid gzip data", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
invalidGzip := []byte("this is not gzip data")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(invalidGzip))
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
_, err := getDecompressedBody(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to create gzip reader")
|
||||
})
|
||||
|
||||
t.Run("returns original body for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
content := []byte("test content")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
|
||||
req.Header.Set("Content-Encoding", "identity")
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, req.Body, body, "should return original body")
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
})
|
||||
|
||||
t.Run("returns error for unsupported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
content := []byte("test content")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
|
||||
req.Header.Set("Content-Encoding", "br")
|
||||
|
||||
_, err := getDecompressedBody(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
|
||||
})
|
||||
|
||||
t.Run("handles gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
originalContent := []byte("test content to compress")
|
||||
|
||||
var compressed bytes.Buffer
|
||||
gw := gzip.NewWriter(&compressed)
|
||||
_, err := gw.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = gw.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "gzip;q=1.0")
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
defer body.Close()
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, data)
|
||||
})
|
||||
}
|
||||
29
envd/internal/api/envs.go
Normal file
29
envd/internal/api/envs.go
Normal file
@ -0,0 +1,29 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
)
|
||||
|
||||
func (a *API) GetEnvs(w http.ResponseWriter, _ *http.Request) {
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
a.logger.Debug().Str(string(logs.OperationIDKey), operationID).Msg("Getting env vars")
|
||||
|
||||
envs := make(EnvVars)
|
||||
a.defaults.EnvVars.Range(func(key, value string) bool {
|
||||
envs[key] = value
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(envs); err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("Failed to encode env vars")
|
||||
}
|
||||
}
|
||||
21
envd/internal/api/error.go
Normal file
21
envd/internal/api/error.go
Normal file
@ -0,0 +1,21 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func jsonError(w http.ResponseWriter, code int, err error) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
w.WriteHeader(code)
|
||||
encodeErr := json.NewEncoder(w).Encode(Error{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
})
|
||||
if encodeErr != nil {
|
||||
http.Error(w, errors.Join(encodeErr, err).Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
3
envd/internal/api/generate.go
Normal file
3
envd/internal/api/generate.go
Normal file
@ -0,0 +1,3 @@
|
||||
package api
|
||||
|
||||
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config cfg.yaml ../../spec/envd.yaml
|
||||
314
envd/internal/api/init.go
Normal file
314
envd/internal/api/init.go
Normal file
@ -0,0 +1,314 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/txn2/txeh"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAccessTokenMismatch = errors.New("access token validation failed")
|
||||
ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
|
||||
)
|
||||
|
||||
const (
|
||||
maxTimeInPast = 50 * time.Millisecond
|
||||
maxTimeInFuture = 5 * time.Second
|
||||
)
|
||||
|
||||
// validateInitAccessToken validates the access token for /init requests.
|
||||
// Token is valid if it matches the existing token OR the MMDS hash.
|
||||
// If neither exists, first-time setup is allowed.
|
||||
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
|
||||
requestTokenSet := requestToken.IsSet()
|
||||
|
||||
// Fast path: token matches existing
|
||||
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check MMDS only if token didn't match existing
|
||||
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
|
||||
|
||||
switch {
|
||||
case matchesMMDS:
|
||||
return nil
|
||||
case !a.accessToken.IsSet() && !mmdsExists:
|
||||
return nil // first-time setup
|
||||
case !requestTokenSet:
|
||||
return ErrAccessTokenResetNotAuthorized
|
||||
default:
|
||||
return ErrAccessTokenMismatch
|
||||
}
|
||||
}
|
||||
|
||||
// checkMMDSHash checks if the request token matches the MMDS hash.
|
||||
// Returns (matches, mmdsExists).
|
||||
//
|
||||
// The MMDS hash is set by the orchestrator during Resume:
|
||||
// - hash(token): requires this specific token
|
||||
// - hash(""): explicitly allows nil token (token reset authorized)
|
||||
// - "": MMDS not properly configured, no authorization granted
|
||||
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
|
||||
if a.isNotFC {
|
||||
return false, false
|
||||
}
|
||||
|
||||
mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if mmdsHash == "" {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if !requestToken.IsSet() {
|
||||
return mmdsHash == keys.HashAccessToken(""), true
|
||||
}
|
||||
|
||||
tokenBytes, err := requestToken.Bytes()
|
||||
if err != nil {
|
||||
return false, true
|
||||
}
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
|
||||
}
|
||||
|
||||
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
|
||||
|
||||
if r.Body != nil {
|
||||
// Read raw body so we can wipe it after parsing
|
||||
body, err := io.ReadAll(r.Body)
|
||||
// Ensure body is wiped after we're done
|
||||
defer memguard.WipeBytes(body)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to read request body: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var initRequest PostInitJSONBody
|
||||
if len(body) > 0 {
|
||||
err = json.Unmarshal(body, &initRequest)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to decode request: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure request token is destroyed if not transferred via TakeFrom.
|
||||
// This handles: validation failures, timestamp-based skips, and any early returns.
|
||||
// Safe because Destroy() is nil-safe and TakeFrom clears the source.
|
||||
defer initRequest.AccessToken.Destroy()
|
||||
|
||||
a.initLock.Lock()
|
||||
defer a.initLock.Unlock()
|
||||
|
||||
// Update data only if the request is newer or if there's no timestamp at all
|
||||
if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
|
||||
err = a.SetData(ctx, logger, initRequest)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
default:
|
||||
logger.Error().Msgf("Failed to set data: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
w.Write([]byte(err.Error()))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() { //nolint:contextcheck // TODO: fix this later
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
|
||||
}()
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "")
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
|
||||
// Validate access token before proceeding with any action
|
||||
// The request must provide a token that is either:
|
||||
// 1. Matches the existing access token (if set), OR
|
||||
// 2. Matches the MMDS hash (for token change during resume)
|
||||
if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if data.Timestamp != nil {
|
||||
// Check if current time differs significantly from the received timestamp
|
||||
if shouldSetSystemTime(time.Now(), *data.Timestamp) {
|
||||
logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp)
|
||||
ts := unix.NsecToTimespec(data.Timestamp.UnixNano())
|
||||
err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to set system time: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
if data.EnvVars != nil {
|
||||
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
|
||||
|
||||
for key, value := range *data.EnvVars {
|
||||
logger.Debug().Msgf("Setting env var for %s", key)
|
||||
a.defaults.EnvVars.Store(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if data.AccessToken.IsSet() {
|
||||
logger.Debug().Msg("Setting access token")
|
||||
a.accessToken.TakeFrom(data.AccessToken)
|
||||
} else if a.accessToken.IsSet() {
|
||||
logger.Debug().Msg("Clearing access token")
|
||||
a.accessToken.Destroy()
|
||||
}
|
||||
|
||||
if data.HyperloopIP != nil {
|
||||
go a.SetupHyperloop(*data.HyperloopIP)
|
||||
}
|
||||
|
||||
if data.DefaultUser != nil && *data.DefaultUser != "" {
|
||||
logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
|
||||
a.defaults.User = *data.DefaultUser
|
||||
}
|
||||
|
||||
if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
|
||||
logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
|
||||
a.defaults.Workdir = data.DefaultWorkdir
|
||||
}
|
||||
|
||||
if data.VolumeMounts != nil {
|
||||
for _, volume := range *data.VolumeMounts {
|
||||
logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
|
||||
|
||||
go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
|
||||
commands := [][]string{
|
||||
{"mkdir", "-p", path},
|
||||
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
|
||||
}
|
||||
|
||||
for _, command := range commands {
|
||||
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
|
||||
|
||||
logger := a.getLogger(err)
|
||||
|
||||
logger.
|
||||
Strs("command", command).
|
||||
Str("output", string(data)).
|
||||
Msg("Mount NFS")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) SetupHyperloop(address string) {
|
||||
a.hyperloopLock.Lock()
|
||||
defer a.hyperloopLock.Unlock()
|
||||
|
||||
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
|
||||
a.logger.Error().Err(err).Msg("failed to modify hosts file")
|
||||
} else {
|
||||
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
|
||||
}
|
||||
}
|
||||
|
||||
const eventsHost = "events.wrenn.local"
|
||||
|
||||
func rewriteHostsFile(address, path string) error {
|
||||
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
|
||||
ReadFilePath: path,
|
||||
WriteFilePath: path,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create hosts: %w", err)
|
||||
}
|
||||
|
||||
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
|
||||
// This will remove any existing entries for events.wrenn.local first
|
||||
ipFamily, err := getIPFamily(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get ip family: %w", err)
|
||||
}
|
||||
|
||||
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
|
||||
return nil // nothing to be done
|
||||
}
|
||||
|
||||
hosts.AddHost(address, eventsHost)
|
||||
|
||||
return hosts.Save()
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidAddress = errors.New("invalid IP address")
|
||||
ErrUnknownAddressFormat = errors.New("unknown IP address format")
|
||||
)
|
||||
|
||||
func getIPFamily(address string) (txeh.IPFamily, error) {
|
||||
addressIP, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case addressIP.Is4():
|
||||
return txeh.IPFamilyV4, nil
|
||||
case addressIP.Is6():
|
||||
return txeh.IPFamilyV6, nil
|
||||
default:
|
||||
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp,
|
||||
// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than
|
||||
// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime.
|
||||
func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool {
|
||||
return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture))
|
||||
}
|
||||
587
envd/internal/api/init_test.go
Normal file
587
envd/internal/api/init_test.go
Normal file
@ -0,0 +1,587 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
utilsShared "git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
|
||||
)
|
||||
|
||||
func TestSimpleCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := map[string]func(string) string{
|
||||
"both newlines": func(s string) string { return s },
|
||||
"no newline prefix": func(s string) string { return strings.TrimPrefix(s, "\n") },
|
||||
"no newline suffix": func(s string) string { return strings.TrimSuffix(s, "\n") },
|
||||
"no newline prefix or suffix": strings.TrimSpace,
|
||||
}
|
||||
|
||||
for name, preprocessor := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
value := `
|
||||
# comment
|
||||
127.0.0.1 one.host
|
||||
127.0.0.2 two.host
|
||||
`
|
||||
value = preprocessor(value)
|
||||
inputPath := filepath.Join(tempDir, "hosts")
|
||||
err := os.WriteFile(inputPath, []byte(value), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = rewriteHostsFile("127.0.0.3", inputPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := os.ReadFile(inputPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, `# comment
|
||||
127.0.0.1 one.host
|
||||
127.0.0.2 two.host
|
||||
127.0.0.3 events.wrenn.local`, strings.TrimSpace(string(data)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSetSystemTime(t *testing.T) {
|
||||
t.Parallel()
|
||||
sandboxTime := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hostTime time.Time
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "sandbox time far ahead of host time (should set)",
|
||||
hostTime: sandboxTime.Add(-10 * time.Second),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "sandbox time at maxTimeInPast boundary ahead of host time (should not set)",
|
||||
hostTime: sandboxTime.Add(-50 * time.Millisecond),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time just within maxTimeInPast ahead of host time (should not set)",
|
||||
hostTime: sandboxTime.Add(-40 * time.Millisecond),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time slightly ahead of host time (should not set)",
|
||||
hostTime: sandboxTime.Add(-10 * time.Millisecond),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time equals host time (should not set)",
|
||||
hostTime: sandboxTime,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time slightly behind host time (should not set)",
|
||||
hostTime: sandboxTime.Add(1 * time.Second),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time just within maxTimeInFuture behind host time (should not set)",
|
||||
hostTime: sandboxTime.Add(4 * time.Second),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time at maxTimeInFuture boundary behind host time (should not set)",
|
||||
hostTime: sandboxTime.Add(5 * time.Second),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time far behind host time (should set)",
|
||||
hostTime: sandboxTime.Add(1 * time.Minute),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := shouldSetSystemTime(tt.hostTime, sandboxTime)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func secureTokenPtr(s string) *SecureToken {
|
||||
token := &SecureToken{}
|
||||
_ = token.Set([]byte(s))
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
type mockMMDSClient struct {
|
||||
hash string
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockMMDSClient) GetAccessTokenHash(_ context.Context) (string, error) {
|
||||
return m.hash, m.err
|
||||
}
|
||||
|
||||
func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
if accessToken != nil {
|
||||
api.accessToken.TakeFrom(accessToken)
|
||||
}
|
||||
api.mmdsClient = mmdsClient
|
||||
|
||||
return api
|
||||
}
|
||||
|
||||
func TestValidateInitAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken *SecureToken
|
||||
requestToken *SecureToken
|
||||
mmdsHash string
|
||||
mmdsErr error
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "fast path: token matches existing",
|
||||
accessToken: secureTokenPtr("secret-token"),
|
||||
requestToken: secureTokenPtr("secret-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "MMDS match: token hash matches MMDS hash",
|
||||
accessToken: secureTokenPtr("old-token"),
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: keys.HashAccessToken("new-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "first-time setup: no existing token, MMDS error",
|
||||
accessToken: nil,
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "first-time setup: no existing token, empty MMDS hash",
|
||||
accessToken: nil,
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "first-time setup: both tokens nil, no MMDS",
|
||||
accessToken: nil,
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "mismatch: existing token differs from request, no MMDS",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: secureTokenPtr("wrong-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
},
|
||||
{
|
||||
name: "mismatch: existing token differs from request, MMDS hash mismatch",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: secureTokenPtr("wrong-token"),
|
||||
mmdsHash: keys.HashAccessToken("different-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
},
|
||||
{
|
||||
name: "conflict: existing token, nil request, MMDS exists",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: keys.HashAccessToken("some-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
},
|
||||
{
|
||||
name: "conflict: existing token, nil request, no MMDS",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
|
||||
api := newTestAPI(tt.accessToken, mmdsClient)
|
||||
|
||||
err := api.validateInitAccessToken(ctx, tt.requestToken)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckMMDSHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
|
||||
t.Run("returns match when token hash equals MMDS hash", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
token := "my-secret-token"
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(token), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr(token))
|
||||
|
||||
assert.True(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns no match when token hash differs from MMDS hash", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("different-token"), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("my-token"))
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns exists but no match when request token is nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("some-token"), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, nil)
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns false, false when MMDS returns error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns false, false when MMDS returns empty hash with non-nil request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns false, false when MMDS returns empty hash with nil request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, nil)
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns true, true when MMDS returns hash of empty string with nil request (explicit reset)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(""), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, nil)
|
||||
|
||||
assert.True(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetData(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := zerolog.Nop()
|
||||
|
||||
t.Run("access token updates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
existingToken *SecureToken
|
||||
requestToken *SecureToken
|
||||
mmdsHash string
|
||||
mmdsErr error
|
||||
wantErr error
|
||||
wantFinalToken *SecureToken
|
||||
}{
|
||||
{
|
||||
name: "first-time setup: sets initial token",
|
||||
existingToken: nil,
|
||||
requestToken: secureTokenPtr("initial-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
wantFinalToken: secureTokenPtr("initial-token"),
|
||||
},
|
||||
{
|
||||
name: "first-time setup: nil request token leaves token unset",
|
||||
existingToken: nil,
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
wantFinalToken: nil,
|
||||
},
|
||||
{
|
||||
name: "re-init with same token: token unchanged",
|
||||
existingToken: secureTokenPtr("same-token"),
|
||||
requestToken: secureTokenPtr("same-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
wantFinalToken: secureTokenPtr("same-token"),
|
||||
},
|
||||
{
|
||||
name: "resume with MMDS: updates token when hash matches",
|
||||
existingToken: secureTokenPtr("old-token"),
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: keys.HashAccessToken("new-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
wantFinalToken: secureTokenPtr("new-token"),
|
||||
},
|
||||
{
|
||||
name: "resume with MMDS: fails when hash doesn't match",
|
||||
existingToken: secureTokenPtr("old-token"),
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: keys.HashAccessToken("different-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
wantFinalToken: secureTokenPtr("old-token"),
|
||||
},
|
||||
{
|
||||
name: "fails when existing token and request token mismatch without MMDS",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: secureTokenPtr("wrong-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "conflict when existing token but nil request token",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "conflict when existing token but nil request with MMDS present",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: keys.HashAccessToken("some-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "conflict when MMDS returns empty hash and request is nil (prevents unauthorized reset)",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "resets token when MMDS returns hash of empty string and request is nil (explicit reset)",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: keys.HashAccessToken(""),
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
wantFinalToken: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
|
||||
api := newTestAPI(tt.existingToken, mmdsClient)
|
||||
|
||||
data := PostInitJSONBody{
|
||||
AccessToken: tt.requestToken,
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if tt.wantFinalToken == nil {
|
||||
assert.False(t, api.accessToken.IsSet(), "expected token to not be set")
|
||||
} else {
|
||||
require.True(t, api.accessToken.IsSet(), "expected token to be set")
|
||||
assert.True(t, api.accessToken.EqualsSecure(tt.wantFinalToken), "expected token to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets environment variables", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
envVars := EnvVars{"FOO": "bar", "BAZ": "qux"}
|
||||
data := PostInitJSONBody{
|
||||
EnvVars: &envVars,
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
val, ok := api.defaults.EnvVars.Load("FOO")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "bar", val)
|
||||
val, ok = api.defaults.EnvVars.Load("BAZ")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "qux", val)
|
||||
})
|
||||
|
||||
t.Run("sets default user", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultUser: utilsShared.ToPtr("testuser"),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testuser", api.defaults.User)
|
||||
})
|
||||
|
||||
t.Run("does not set default user when empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
api.defaults.User = "original"
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultUser: utilsShared.ToPtr(""),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "original", api.defaults.User)
|
||||
})
|
||||
|
||||
t.Run("sets default workdir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultWorkdir: utilsShared.ToPtr("/home/user"),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, api.defaults.Workdir)
|
||||
assert.Equal(t, "/home/user", *api.defaults.Workdir)
|
||||
})
|
||||
|
||||
t.Run("does not set default workdir when empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
originalWorkdir := "/original"
|
||||
api.defaults.Workdir = &originalWorkdir
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultWorkdir: utilsShared.ToPtr(""),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, api.defaults.Workdir)
|
||||
assert.Equal(t, "/original", *api.defaults.Workdir)
|
||||
})
|
||||
|
||||
t.Run("sets multiple fields at once", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
envVars := EnvVars{"KEY": "value"}
|
||||
data := PostInitJSONBody{
|
||||
AccessToken: secureTokenPtr("token"),
|
||||
DefaultUser: utilsShared.ToPtr("user"),
|
||||
DefaultWorkdir: utilsShared.ToPtr("/workdir"),
|
||||
EnvVars: &envVars,
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, api.accessToken.Equals("token"), "expected token to match")
|
||||
assert.Equal(t, "user", api.defaults.User)
|
||||
assert.Equal(t, "/workdir", *api.defaults.Workdir)
|
||||
val, ok := api.defaults.EnvVars.Load("KEY")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value", val)
|
||||
})
|
||||
}
|
||||
212
envd/internal/api/secure_token.go
Normal file
212
envd/internal/api/secure_token.go
Normal file
@ -0,0 +1,212 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenNotSet = errors.New("access token not set")
|
||||
ErrTokenEmpty = errors.New("empty token not allowed")
|
||||
)
|
||||
|
||||
// SecureToken wraps memguard for secure token storage.
|
||||
// It uses LockedBuffer which provides memory locking, guard pages,
|
||||
// and secure zeroing on destroy.
|
||||
type SecureToken struct {
|
||||
mu sync.RWMutex
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// Set securely replaces the token, destroying the old one first.
|
||||
// The old token memory is zeroed before the new token is stored.
|
||||
// The input byte slice is wiped after copying to secure memory.
|
||||
// Returns ErrTokenEmpty if token is empty - use Destroy() to clear the token instead.
|
||||
func (s *SecureToken) Set(token []byte) error {
|
||||
if len(token) == 0 {
|
||||
return ErrTokenEmpty
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Destroy old token first (zeros memory)
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
|
||||
// Create new LockedBuffer from bytes (source slice is wiped by memguard)
|
||||
s.buffer = memguard.NewBufferFromBytes(token)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler to securely parse a JSON string
|
||||
// directly into memguard, wiping the input bytes after copying.
|
||||
//
|
||||
// Access tokens are hex-encoded HMAC-SHA256 hashes (64 chars of [0-9a-f]),
|
||||
// so they never contain JSON escape sequences.
|
||||
func (s *SecureToken) UnmarshalJSON(data []byte) error {
|
||||
// JSON strings are quoted, so minimum valid is `""` (2 bytes).
|
||||
if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return errors.New("invalid secure token JSON string")
|
||||
}
|
||||
|
||||
content := data[1 : len(data)-1]
|
||||
|
||||
// Access tokens are hex strings - reject if contains backslash
|
||||
if bytes.ContainsRune(content, '\\') {
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return errors.New("invalid secure token: unexpected escape sequence")
|
||||
}
|
||||
|
||||
if len(content) == 0 {
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return ErrTokenEmpty
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
|
||||
// Allocate secure buffer and copy directly into it
|
||||
s.buffer = memguard.NewBuffer(len(content))
|
||||
copy(s.buffer.Bytes(), content)
|
||||
|
||||
// Wipe the input data
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TakeFrom transfers the token from src to this SecureToken, destroying any
|
||||
// existing token. The source token is cleared after transfer.
|
||||
// This avoids copying the underlying bytes.
|
||||
func (s *SecureToken) TakeFrom(src *SecureToken) {
|
||||
if src == nil || s == src {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract buffer from source
|
||||
src.mu.Lock()
|
||||
buffer := src.buffer
|
||||
src.buffer = nil
|
||||
src.mu.Unlock()
|
||||
|
||||
// Install buffer in destination
|
||||
s.mu.Lock()
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
}
|
||||
s.buffer = buffer
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Equals checks if token matches using constant-time comparison.
|
||||
// Returns false if the receiver is nil.
|
||||
func (s *SecureToken) Equals(token string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return false
|
||||
}
|
||||
|
||||
return s.buffer.EqualTo([]byte(token))
|
||||
}
|
||||
|
||||
// EqualsSecure compares this token with another SecureToken using constant-time comparison.
|
||||
// Returns false if either receiver or other is nil.
|
||||
func (s *SecureToken) EqualsSecure(other *SecureToken) bool {
|
||||
if s == nil || other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s == other {
|
||||
return s.IsSet()
|
||||
}
|
||||
|
||||
// Get a copy of other's bytes (avoids holding two locks simultaneously)
|
||||
otherBytes, err := other.Bytes()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer memguard.WipeBytes(otherBytes)
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return false
|
||||
}
|
||||
|
||||
return s.buffer.EqualTo(otherBytes)
|
||||
}
|
||||
|
||||
// IsSet returns true if a token is stored.
|
||||
// Returns false if the receiver is nil.
|
||||
func (s *SecureToken) IsSet() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.buffer != nil && s.buffer.IsAlive()
|
||||
}
|
||||
|
||||
// Bytes returns a copy of the token bytes (for signature generation).
|
||||
// The caller should zero the returned slice after use.
|
||||
// Returns ErrTokenNotSet if the receiver is nil.
|
||||
func (s *SecureToken) Bytes() ([]byte, error) {
|
||||
if s == nil {
|
||||
return nil, ErrTokenNotSet
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return nil, ErrTokenNotSet
|
||||
}
|
||||
|
||||
// Return a copy (unavoidable for signature generation)
|
||||
src := s.buffer.Bytes()
|
||||
result := make([]byte, len(src))
|
||||
copy(result, src)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Destroy securely wipes the token from memory.
|
||||
// No-op if the receiver is nil.
|
||||
func (s *SecureToken) Destroy() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
}
|
||||
461
envd/internal/api/secure_token_test.go
Normal file
461
envd/internal/api/secure_token_test.go
Normal file
@ -0,0 +1,461 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSecureTokenSetAndEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Initially not set
|
||||
assert.False(t, st.IsSet(), "token should not be set initially")
|
||||
assert.False(t, st.Equals("any-token"), "equals should return false when not set")
|
||||
|
||||
// Set token
|
||||
err := st.Set([]byte("test-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet(), "token should be set after Set()")
|
||||
assert.True(t, st.Equals("test-token"), "equals should return true for correct token")
|
||||
assert.False(t, st.Equals("wrong-token"), "equals should return false for wrong token")
|
||||
assert.False(t, st.Equals(""), "equals should return false for empty token")
|
||||
}
|
||||
|
||||
func TestSecureTokenReplace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Set initial token
|
||||
err := st.Set([]byte("first-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.Equals("first-token"))
|
||||
|
||||
// Replace with new token (old one should be destroyed)
|
||||
err = st.Set([]byte("second-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.Equals("second-token"), "should match new token")
|
||||
assert.False(t, st.Equals("first-token"), "should not match old token")
|
||||
}
|
||||
|
||||
func TestSecureTokenDestroy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Set and then destroy
|
||||
err := st.Set([]byte("test-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet())
|
||||
|
||||
st.Destroy()
|
||||
assert.False(t, st.IsSet(), "token should not be set after Destroy()")
|
||||
assert.False(t, st.Equals("test-token"), "equals should return false after Destroy()")
|
||||
|
||||
// Destroy on already destroyed should be safe
|
||||
st.Destroy()
|
||||
assert.False(t, st.IsSet())
|
||||
|
||||
// Nil receiver should be safe
|
||||
var nilToken *SecureToken
|
||||
assert.False(t, nilToken.IsSet(), "nil receiver should return false for IsSet()")
|
||||
assert.False(t, nilToken.Equals("anything"), "nil receiver should return false for Equals()")
|
||||
assert.False(t, nilToken.EqualsSecure(st), "nil receiver should return false for EqualsSecure()")
|
||||
nilToken.Destroy() // should not panic
|
||||
|
||||
_, err = nilToken.Bytes()
|
||||
require.ErrorIs(t, err, ErrTokenNotSet, "nil receiver should return ErrTokenNotSet for Bytes()")
|
||||
}
|
||||
|
||||
func TestSecureTokenBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Bytes should return error when not set
|
||||
_, err := st.Bytes()
|
||||
require.ErrorIs(t, err, ErrTokenNotSet)
|
||||
|
||||
// Set token and get bytes
|
||||
err = st.Set([]byte("test-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
bytes, err := st.Bytes()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("test-token"), bytes)
|
||||
|
||||
// Zero out the bytes (as caller should do)
|
||||
memguard.WipeBytes(bytes)
|
||||
|
||||
// Original should still be intact
|
||||
assert.True(t, st.Equals("test-token"), "original token should still work after zeroing copy")
|
||||
|
||||
// After destroy, bytes should fail
|
||||
st.Destroy()
|
||||
_, err = st.Bytes()
|
||||
assert.ErrorIs(t, err, ErrTokenNotSet)
|
||||
}
|
||||
|
||||
func TestSecureTokenConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("initial-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 100
|
||||
|
||||
// Concurrent reads
|
||||
for range numGoroutines {
|
||||
wg.Go(func() {
|
||||
st.IsSet()
|
||||
st.Equals("initial-token")
|
||||
})
|
||||
}
|
||||
|
||||
// Concurrent writes
|
||||
for i := range 10 {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
st.Set([]byte("token-" + string(rune('a'+idx))))
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should still be in a valid state
|
||||
assert.True(t, st.IsSet())
|
||||
}
|
||||
|
||||
func TestSecureTokenEmptyToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Setting empty token should return an error
|
||||
err := st.Set([]byte{})
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.False(t, st.IsSet(), "token should not be set after empty token error")
|
||||
|
||||
// Setting nil should also return an error
|
||||
err = st.Set(nil)
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.False(t, st.IsSet(), "token should not be set after nil token error")
|
||||
}
|
||||
|
||||
func TestSecureTokenEmptyTokenDoesNotClearExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Set a valid token first
|
||||
err := st.Set([]byte("valid-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet())
|
||||
|
||||
// Attempting to set empty token should fail and preserve existing token
|
||||
err = st.Set([]byte{})
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.True(t, st.IsSet(), "existing token should be preserved after empty token error")
|
||||
assert.True(t, st.Equals("valid-token"), "existing token value should be unchanged")
|
||||
}
|
||||
|
||||
func TestSecureTokenUnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("unmarshals valid JSON string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`"my-secret-token"`))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet())
|
||||
assert.True(t, st.Equals("my-secret-token"))
|
||||
})
|
||||
|
||||
t.Run("returns error for empty string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`""`))
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.False(t, st.IsSet())
|
||||
})
|
||||
|
||||
t.Run("returns error for invalid JSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`not-valid-json`))
|
||||
require.Error(t, err)
|
||||
assert.False(t, st.IsSet())
|
||||
})
|
||||
|
||||
t.Run("replaces existing token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("old-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = st.UnmarshalJSON([]byte(`"new-token"`))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.Equals("new-token"))
|
||||
assert.False(t, st.Equals("old-token"))
|
||||
})
|
||||
|
||||
t.Run("wipes input buffer after parsing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a buffer with a known token
|
||||
input := []byte(`"secret-token-12345"`)
|
||||
original := make([]byte, len(input))
|
||||
copy(original, input)
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the token was stored correctly
|
||||
assert.True(t, st.Equals("secret-token-12345"))
|
||||
|
||||
// Verify the input buffer was wiped (all zeros)
|
||||
for i, b := range input {
|
||||
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wipes input buffer on error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a buffer with an empty token (will error)
|
||||
input := []byte(`""`)
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON(input)
|
||||
require.Error(t, err)
|
||||
|
||||
// Verify the input buffer was still wiped
|
||||
for i, b := range input {
|
||||
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects escape sequences", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`"token\nwith\nnewlines"`))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "escape sequence")
|
||||
assert.False(t, st.IsSet())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureTokenSetWipesInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("wipes input buffer after storing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a buffer with a known token
|
||||
input := []byte("my-secret-token")
|
||||
original := make([]byte, len(input))
|
||||
copy(original, input)
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.Set(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the token was stored correctly
|
||||
assert.True(t, st.Equals("my-secret-token"))
|
||||
|
||||
// Verify the input buffer was wiped (all zeros)
|
||||
for i, b := range input {
|
||||
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureTokenTakeFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("transfers token from source to destination", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := &SecureToken{}
|
||||
err := src.Set([]byte("source-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst := &SecureToken{}
|
||||
dst.TakeFrom(src)
|
||||
|
||||
assert.True(t, dst.IsSet())
|
||||
assert.True(t, dst.Equals("source-token"))
|
||||
assert.False(t, src.IsSet(), "source should be empty after transfer")
|
||||
})
|
||||
|
||||
t.Run("replaces existing destination token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := &SecureToken{}
|
||||
err := src.Set([]byte("new-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst := &SecureToken{}
|
||||
err = dst.Set([]byte("old-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst.TakeFrom(src)
|
||||
|
||||
assert.True(t, dst.Equals("new-token"))
|
||||
assert.False(t, dst.Equals("old-token"))
|
||||
assert.False(t, src.IsSet())
|
||||
})
|
||||
|
||||
t.Run("handles nil source", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dst := &SecureToken{}
|
||||
err := dst.Set([]byte("existing-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst.TakeFrom(nil)
|
||||
|
||||
assert.True(t, dst.IsSet(), "destination should be unchanged with nil source")
|
||||
assert.True(t, dst.Equals("existing-token"))
|
||||
})
|
||||
|
||||
t.Run("handles empty source", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := &SecureToken{}
|
||||
dst := &SecureToken{}
|
||||
err := dst.Set([]byte("existing-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst.TakeFrom(src)
|
||||
|
||||
assert.False(t, dst.IsSet(), "destination should be cleared when source is empty")
|
||||
})
|
||||
|
||||
t.Run("self-transfer is no-op and does not deadlock", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st.TakeFrom(st)
|
||||
|
||||
assert.True(t, st.IsSet(), "token should remain set after self-transfer")
|
||||
assert.True(t, st.Equals("token"), "token value should be unchanged")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureTokenEqualsSecure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns true for matching tokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
err := st1.Set([]byte("same-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st2 := &SecureToken{}
|
||||
err = st2.Set([]byte("same-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, st1.EqualsSecure(st2))
|
||||
assert.True(t, st2.EqualsSecure(st1))
|
||||
})
|
||||
|
||||
t.Run("concurrent TakeFrom and EqualsSecure do not deadlock", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// This test verifies the fix for the lock ordering deadlock bug.
|
||||
|
||||
const iterations = 100
|
||||
|
||||
for range iterations {
|
||||
a := &SecureToken{}
|
||||
err := a.Set([]byte("token-a"))
|
||||
require.NoError(t, err)
|
||||
|
||||
b := &SecureToken{}
|
||||
err = b.Set([]byte("token-b"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Goroutine 1: a.TakeFrom(b)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a.TakeFrom(b)
|
||||
}()
|
||||
|
||||
// Goroutine 2: b.EqualsSecure(a)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b.EqualsSecure(a)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false for different tokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
err := st1.Set([]byte("token-a"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st2 := &SecureToken{}
|
||||
err = st2.Set([]byte("token-b"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, st1.EqualsSecure(st2))
|
||||
})
|
||||
|
||||
t.Run("returns false when comparing with nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, st.EqualsSecure(nil))
|
||||
})
|
||||
|
||||
t.Run("returns false when other is not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
err := st1.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st2 := &SecureToken{}
|
||||
|
||||
assert.False(t, st1.EqualsSecure(st2))
|
||||
})
|
||||
|
||||
t.Run("returns false when self is not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
|
||||
st2 := &SecureToken{}
|
||||
err := st2.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, st1.EqualsSecure(st2))
|
||||
})
|
||||
|
||||
t.Run("self-comparison returns true when set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, st.EqualsSecure(st), "self-comparison should return true and not deadlock")
|
||||
})
|
||||
|
||||
t.Run("self-comparison returns false when not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
|
||||
assert.False(t, st.EqualsSecure(st), "self-comparison on unset token should return false")
|
||||
})
|
||||
}
|
||||
93
envd/internal/api/store.go
Normal file
93
envd/internal/api/store.go
Normal file
@ -0,0 +1,93 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
// MMDSClient provides access to MMDS metadata.
|
||||
type MMDSClient interface {
|
||||
GetAccessTokenHash(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
// DefaultMMDSClient is the production implementation that calls the real MMDS endpoint.
|
||||
type DefaultMMDSClient struct{}
|
||||
|
||||
func (c *DefaultMMDSClient) GetAccessTokenHash(ctx context.Context) (string, error) {
|
||||
return host.GetAccessTokenHashFromMMDS(ctx)
|
||||
}
|
||||
|
||||
type API struct {
|
||||
isNotFC bool
|
||||
logger *zerolog.Logger
|
||||
accessToken *SecureToken
|
||||
defaults *execcontext.Defaults
|
||||
|
||||
mmdsChan chan *host.MMDSOpts
|
||||
hyperloopLock sync.Mutex
|
||||
mmdsClient MMDSClient
|
||||
|
||||
lastSetTime *utils.AtomicMax
|
||||
initLock sync.Mutex
|
||||
}
|
||||
|
||||
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool) *API {
|
||||
return &API{
|
||||
logger: l,
|
||||
defaults: defaults,
|
||||
mmdsChan: mmdsChan,
|
||||
isNotFC: isNotFC,
|
||||
mmdsClient: &DefaultMMDSClient{},
|
||||
lastSetTime: utils.NewAtomicMax(),
|
||||
accessToken: &SecureToken{},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
a.logger.Trace().Msg("Health check")
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "")
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *API) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
a.logger.Trace().Msg("Get metrics")
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
metrics, err := host.GetMetrics()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to get metrics")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(metrics); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to encode metrics")
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) getLogger(err error) *zerolog.Event {
|
||||
if err != nil {
|
||||
return a.logger.Error().Err(err) //nolint:zerologlint // this is only prep
|
||||
}
|
||||
|
||||
return a.logger.Info() //nolint:zerologlint // this is only prep
|
||||
}
|
||||
309
envd/internal/api/upload.go
Normal file
309
envd/internal/api/upload.go
Normal file
@ -0,0 +1,309 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
var ErrNoDiskSpace = fmt.Errorf("not enough disk space available")
|
||||
|
||||
func processFile(r *http.Request, path string, part io.Reader, uid, gid int, logger zerolog.Logger) (int, error) {
|
||||
logger.Debug().
|
||||
Str("path", path).
|
||||
Msg("File processing")
|
||||
|
||||
err := permissions.EnsureDirs(filepath.Dir(path), uid, gid)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error ensuring directories: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
canBePreChowned := false
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
errMsg := fmt.Errorf("error getting file info: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, errMsg
|
||||
} else if err == nil {
|
||||
if stat.IsDir() {
|
||||
err := fmt.Errorf("path is a directory: %s", path)
|
||||
|
||||
return http.StatusBadRequest, err
|
||||
}
|
||||
canBePreChowned = true
|
||||
}
|
||||
|
||||
hasBeenChowned := false
|
||||
if canBePreChowned {
|
||||
err = os.Chown(path, uid, gid)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
err = fmt.Errorf("error changing file ownership: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
} else {
|
||||
hasBeenChowned = true
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o666)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.ENOSPC) {
|
||||
err = fmt.Errorf("not enough inodes available: %w", err)
|
||||
|
||||
return http.StatusInsufficientStorage, err
|
||||
}
|
||||
|
||||
err := fmt.Errorf("error opening file: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
defer file.Close()
|
||||
|
||||
if !hasBeenChowned {
|
||||
err = os.Chown(path, uid, gid)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error changing file ownership: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = file.ReadFrom(part)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.ENOSPC) {
|
||||
err = ErrNoDiskSpace
|
||||
if r.ContentLength > 0 {
|
||||
err = fmt.Errorf("attempted to write %d bytes: %w", r.ContentLength, err)
|
||||
}
|
||||
|
||||
return http.StatusInsufficientStorage, err
|
||||
}
|
||||
|
||||
err = fmt.Errorf("error writing file: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
return http.StatusNoContent, nil
|
||||
}
|
||||
|
||||
func resolvePath(part *multipart.Part, paths *UploadSuccess, u *user.User, defaultPath *string, params PostFilesParams) (string, error) {
|
||||
var pathToResolve string
|
||||
|
||||
if params.Path != nil {
|
||||
pathToResolve = *params.Path
|
||||
} else {
|
||||
var err error
|
||||
customPart := utils.NewCustomPart(part)
|
||||
pathToResolve, err = customPart.FileNameWithPath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting multipart custom part file name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
filePath, err := permissions.ExpandAndResolve(pathToResolve, u, defaultPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error resolving path: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range *paths {
|
||||
if entry.Path == filePath {
|
||||
var alreadyUploaded []string
|
||||
for _, uploadedFile := range *paths {
|
||||
if uploadedFile.Path != filePath {
|
||||
alreadyUploaded = append(alreadyUploaded, uploadedFile.Path)
|
||||
}
|
||||
}
|
||||
|
||||
errMsg := fmt.Errorf("you cannot upload multiple files to the same path '%s' in one upload request, only the first specified file was uploaded", filePath)
|
||||
|
||||
if len(alreadyUploaded) > 1 {
|
||||
errMsg = fmt.Errorf("%w, also the following files were uploaded: %v", errMsg, strings.Join(alreadyUploaded, ", "))
|
||||
}
|
||||
|
||||
return "", errMsg
|
||||
}
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func (a *API) handlePart(r *http.Request, part *multipart.Part, paths UploadSuccess, u *user.User, uid, gid int, operationID string, params PostFilesParams) (*EntryInfo, int, error) {
|
||||
defer part.Close()
|
||||
|
||||
if part.FormName() != "file" {
|
||||
return nil, http.StatusOK, nil
|
||||
}
|
||||
|
||||
filePath, err := resolvePath(part, &paths, u, a.defaults.Workdir, params)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, err
|
||||
}
|
||||
|
||||
logger := a.logger.
|
||||
With().
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("event_type", "file_processing").
|
||||
Logger()
|
||||
|
||||
status, err := processFile(r, filePath, part, uid, gid, logger)
|
||||
if err != nil {
|
||||
return nil, status, err
|
||||
}
|
||||
|
||||
return &EntryInfo{
|
||||
Path: filePath,
|
||||
Name: filepath.Base(filePath),
|
||||
Type: File,
|
||||
}, http.StatusOK, nil
|
||||
}
|
||||
|
||||
func (a *API) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
|
||||
// Capture original body to ensure it's always closed
|
||||
originalBody := r.Body
|
||||
defer originalBody.Close()
|
||||
|
||||
var errorCode int
|
||||
var errMsg error
|
||||
|
||||
var path string
|
||||
if params.Path != nil {
|
||||
path = *params.Path
|
||||
}
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
// signing authorization if needed
|
||||
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningWriteOperation)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
|
||||
jsonError(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
l := a.logger.
|
||||
Err(errMsg).
|
||||
Str("method", r.Method+" "+r.URL.Path).
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("path", path).
|
||||
Str("username", username)
|
||||
|
||||
if errMsg != nil {
|
||||
l = l.Int("error_code", errorCode)
|
||||
}
|
||||
|
||||
l.Msg("File write")
|
||||
}()
|
||||
|
||||
// Handle gzip-encoded request body
|
||||
body, err := getDecompressedBody(r)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error decompressing request body: %w", err)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
r.Body = body
|
||||
|
||||
f, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error parsing multipart form: %w", err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||
errorCode = http.StatusUnauthorized
|
||||
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
uid, gid, err := permissions.GetUserIdInts(u)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error getting user ids: %w", err)
|
||||
|
||||
jsonError(w, http.StatusInternalServerError, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
paths := UploadSuccess{}
|
||||
|
||||
for {
|
||||
part, partErr := f.NextPart()
|
||||
|
||||
if partErr == io.EOF {
|
||||
// We're done reading the parts.
|
||||
break
|
||||
} else if partErr != nil {
|
||||
errMsg = fmt.Errorf("error reading form: %w", partErr)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
entry, status, err := a.handlePart(r, part, paths, u, uid, gid, operationID, params)
|
||||
if err != nil {
|
||||
errorCode = status
|
||||
errMsg = err
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if entry != nil {
|
||||
paths = append(paths, *entry)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(paths)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error marshaling response: %w", err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
249
envd/internal/api/upload_test.go
Normal file
249
envd/internal/api/upload_test.go
Normal file
@ -0,0 +1,249 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProcessFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
uid := os.Getuid()
|
||||
gid := os.Getgid()
|
||||
|
||||
newRequest := func(content []byte) (*http.Request, io.Reader) {
|
||||
request := &http.Request{
|
||||
ContentLength: int64(len(content)),
|
||||
}
|
||||
buffer := bytes.NewBuffer(content)
|
||||
|
||||
return request, buffer
|
||||
}
|
||||
|
||||
var emptyReq http.Request
|
||||
var emptyPart *bytes.Buffer
|
||||
var emptyLogger zerolog.Logger
|
||||
|
||||
t.Run("failed to ensure directories", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
httpStatus, err := processFile(&emptyReq, "/proc/invalid/not-real", emptyPart, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, httpStatus)
|
||||
assert.ErrorContains(t, err, "error ensuring directories: ")
|
||||
})
|
||||
|
||||
t.Run("attempt to replace directory with a file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
httpStatus, err := processFile(&emptyReq, tempDir, emptyPart, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, httpStatus, err.Error())
|
||||
assert.ErrorContains(t, err, "path is a directory: ")
|
||||
})
|
||||
|
||||
t.Run("fail to create file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
httpStatus, err := processFile(&emptyReq, "/proc/invalid-filename", emptyPart, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, httpStatus)
|
||||
assert.ErrorContains(t, err, "error opening file: ")
|
||||
})
|
||||
|
||||
t.Run("out of disk space", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
mountSize := 1024
|
||||
tempDir := createTmpfsMount(t, mountSize)
|
||||
|
||||
// create test file
|
||||
firstFileSize := mountSize / 2
|
||||
tempFile1 := filepath.Join(tempDir, "test-file-1")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(),
|
||||
"dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", firstFileSize), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a new file that would fill up the
|
||||
secondFileContents := make([]byte, mountSize*2)
|
||||
for index := range secondFileContents {
|
||||
secondFileContents[index] = 'a'
|
||||
}
|
||||
|
||||
// try to replace it
|
||||
request, buffer := newRequest(secondFileContents)
|
||||
tempFile2 := filepath.Join(tempDir, "test-file-2")
|
||||
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
|
||||
assert.ErrorContains(t, err, "attempted to write 2048 bytes: not enough disk space")
|
||||
})
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test-file")
|
||||
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
|
||||
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
|
||||
data, err := os.ReadFile(tempFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
})
|
||||
|
||||
t.Run("overwrite file on full disk", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
sizeInBytes := 1024
|
||||
tempDir := createTmpfsMount(t, 1024)
|
||||
|
||||
// create test file
|
||||
tempFile := filepath.Join(tempDir, "test-file")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// try to replace it
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
})
|
||||
|
||||
t.Run("write new file on full disk", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
sizeInBytes := 1024
|
||||
tempDir := createTmpfsMount(t, 1024)
|
||||
|
||||
// create test file
|
||||
tempFile1 := filepath.Join(tempDir, "test-file")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// try to write a new file
|
||||
tempFile2 := filepath.Join(tempDir, "test-file-2")
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
|
||||
require.ErrorContains(t, err, "not enough disk space available")
|
||||
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
|
||||
})
|
||||
|
||||
t.Run("write new file with no inodes available", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
tempDir := createTmpfsMountWithInodes(t, 1024, 2)
|
||||
|
||||
// create test file
|
||||
tempFile1 := filepath.Join(tempDir, "test-file")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", 100), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// try to write a new file
|
||||
tempFile2 := filepath.Join(tempDir, "test-file-2")
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
|
||||
require.ErrorContains(t, err, "not enough inodes available")
|
||||
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
|
||||
})
|
||||
|
||||
t.Run("update sysfs or other virtual fs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
|
||||
}
|
||||
|
||||
filePath := "/sys/fs/cgroup/user.slice/cpu.weight"
|
||||
newContent := []byte("102\n")
|
||||
request, buffer := newRequest(newContent)
|
||||
|
||||
httpStatus, err := processFile(request, filePath, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newContent, data)
|
||||
})
|
||||
|
||||
t.Run("replace file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test-file")
|
||||
|
||||
err := os.WriteFile(tempFile, []byte("old-contents"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
newContent := []byte("new-file-contents")
|
||||
request, buffer := newRequest(newContent)
|
||||
|
||||
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
|
||||
data, err := os.ReadFile(tempFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newContent, data)
|
||||
})
|
||||
}
|
||||
|
||||
func createTmpfsMount(t *testing.T, sizeInBytes int) string {
|
||||
t.Helper()
|
||||
|
||||
return createTmpfsMountWithInodes(t, sizeInBytes, 5)
|
||||
}
|
||||
|
||||
func createTmpfsMountWithInodes(t *testing.T, sizeInBytes, inodesCount int) string {
|
||||
t.Helper()
|
||||
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
|
||||
}
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cmd := exec.CommandContext(t.Context(),
|
||||
"mount",
|
||||
"tmpfs",
|
||||
tempDir,
|
||||
"-t", "tmpfs",
|
||||
"-o", fmt.Sprintf("size=%d,nr_inodes=%d", sizeInBytes, inodesCount))
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
ctx := context.WithoutCancel(t.Context())
|
||||
cmd := exec.CommandContext(ctx, "umount", tempDir)
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return tempDir
|
||||
}
|
||||
37
envd/internal/execcontext/context.go
Normal file
37
envd/internal/execcontext/context.go
Normal file
@ -0,0 +1,37 @@
|
||||
package execcontext
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
type Defaults struct {
|
||||
EnvVars *utils.Map[string, string]
|
||||
User string
|
||||
Workdir *string
|
||||
}
|
||||
|
||||
func ResolveDefaultWorkdir(workdir string, defaultWorkdir *string) string {
|
||||
if workdir != "" {
|
||||
return workdir
|
||||
}
|
||||
|
||||
if defaultWorkdir != nil {
|
||||
return *defaultWorkdir
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func ResolveDefaultUsername(username *string, defaultUsername string) (string, error) {
|
||||
if username != nil {
|
||||
return *username, nil
|
||||
}
|
||||
|
||||
if defaultUsername != "" {
|
||||
return defaultUsername, nil
|
||||
}
|
||||
|
||||
return "", errors.New("username not provided")
|
||||
}
|
||||
93
envd/internal/host/metrics.go
Normal file
93
envd/internal/host/metrics.go
Normal file
@ -0,0 +1,93 @@
|
||||
package host
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/cpu"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
Timestamp int64 `json:"ts"` // Unix Timestamp in UTC
|
||||
|
||||
CPUCount uint32 `json:"cpu_count"` // Total CPU cores
|
||||
CPUUsedPercent float32 `json:"cpu_used_pct"` // Percent rounded to 2 decimal places
|
||||
|
||||
// Deprecated: kept for backwards compatibility with older orchestrators.
|
||||
MemTotalMiB uint64 `json:"mem_total_mib"` // Total virtual memory in MiB
|
||||
|
||||
// Deprecated: kept for backwards compatibility with older orchestrators.
|
||||
MemUsedMiB uint64 `json:"mem_used_mib"` // Used virtual memory in MiB
|
||||
|
||||
MemTotal uint64 `json:"mem_total"` // Total virtual memory in bytes
|
||||
MemUsed uint64 `json:"mem_used"` // Used virtual memory in bytes
|
||||
|
||||
DiskUsed uint64 `json:"disk_used"` // Used disk space in bytes
|
||||
DiskTotal uint64 `json:"disk_total"` // Total disk space in bytes
|
||||
}
|
||||
|
||||
func GetMetrics() (*Metrics, error) {
|
||||
v, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memUsedMiB := v.Used / 1024 / 1024
|
||||
memTotalMiB := v.Total / 1024 / 1024
|
||||
|
||||
cpuTotal, err := cpu.Counts(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cpuUsedPcts, err := cpu.Percent(0, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cpuUsedPct := cpuUsedPcts[0]
|
||||
cpuUsedPctRounded := float32(cpuUsedPct)
|
||||
if cpuUsedPct > 0 {
|
||||
cpuUsedPctRounded = float32(math.Round(cpuUsedPct*100) / 100)
|
||||
}
|
||||
|
||||
diskMetrics, err := diskStats("/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Metrics{
|
||||
Timestamp: time.Now().UTC().Unix(),
|
||||
CPUCount: uint32(cpuTotal),
|
||||
CPUUsedPercent: cpuUsedPctRounded,
|
||||
MemUsedMiB: memUsedMiB,
|
||||
MemTotalMiB: memTotalMiB,
|
||||
MemTotal: v.Total,
|
||||
MemUsed: v.Used,
|
||||
DiskUsed: diskMetrics.Total - diskMetrics.Available,
|
||||
DiskTotal: diskMetrics.Total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type diskSpace struct {
|
||||
Total uint64
|
||||
Available uint64
|
||||
}
|
||||
|
||||
func diskStats(path string) (diskSpace, error) {
|
||||
var st unix.Statfs_t
|
||||
if err := unix.Statfs(path, &st); err != nil {
|
||||
return diskSpace{}, err
|
||||
}
|
||||
|
||||
block := uint64(st.Bsize)
|
||||
|
||||
// all data blocks
|
||||
total := st.Blocks * block
|
||||
// blocks available
|
||||
available := st.Bavail * block
|
||||
|
||||
return diskSpace{Total: total, Available: available}, nil
|
||||
}
|
||||
182
envd/internal/host/mmds.go
Normal file
182
envd/internal/host/mmds.go
Normal file
@ -0,0 +1,182 @@
|
||||
package host
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
WrennRunDir = "/run/wrenn" // store sandbox metadata files here
|
||||
|
||||
mmdsDefaultAddress = "169.254.169.254"
|
||||
mmdsTokenExpiration = 60 * time.Second
|
||||
|
||||
mmdsAccessTokenRequestClientTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var mmdsAccessTokenClient = &http.Client{
|
||||
Timeout: mmdsAccessTokenRequestClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
},
|
||||
}
|
||||
|
||||
type MMDSOpts struct {
|
||||
SandboxID string `json:"instanceID"`
|
||||
TemplateID string `json:"envID"`
|
||||
LogsCollectorAddress string `json:"address"`
|
||||
AccessTokenHash string `json:"accessTokenHash"`
|
||||
}
|
||||
|
||||
func (opts *MMDSOpts) Update(sandboxID, templateID, collectorAddress string) {
|
||||
opts.SandboxID = sandboxID
|
||||
opts.TemplateID = templateID
|
||||
opts.LogsCollectorAddress = collectorAddress
|
||||
}
|
||||
|
||||
func (opts *MMDSOpts) AddOptsToJSON(jsonLogs []byte) ([]byte, error) {
|
||||
parsed := make(map[string]any)
|
||||
|
||||
err := json.Unmarshal(jsonLogs, &parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsed["instanceID"] = opts.SandboxID
|
||||
parsed["envID"] = opts.TemplateID
|
||||
|
||||
data, err := json.Marshal(parsed)
|
||||
|
||||
return data, err
|
||||
}
|
||||
|
||||
func getMMDSToken(ctx context.Context, client *http.Client) (string, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://"+mmdsDefaultAddress+"/latest/api/token", &bytes.Buffer{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
request.Header["X-metadata-token-ttl-seconds"] = []string{fmt.Sprint(mmdsTokenExpiration.Seconds())}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token := string(body)
|
||||
|
||||
if len(token) == 0 {
|
||||
return "", fmt.Errorf("mmds token is an empty string")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func getMMDSOpts(ctx context.Context, client *http.Client, token string) (*MMDSOpts, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+mmdsDefaultAddress, &bytes.Buffer{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request.Header["X-metadata-token"] = []string{token}
|
||||
request.Header["Accept"] = []string{"application/json"}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts MMDSOpts
|
||||
|
||||
err = json.Unmarshal(body, &opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &opts, nil
|
||||
}
|
||||
|
||||
// GetAccessTokenHashFromMMDS reads the access token hash from MMDS.
|
||||
// This is used to validate that /init requests come from the orchestrator.
|
||||
func GetAccessTokenHashFromMMDS(ctx context.Context) (string, error) {
|
||||
token, err := getMMDSToken(ctx, mmdsAccessTokenClient)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get MMDS token: %w", err)
|
||||
}
|
||||
|
||||
opts, err := getMMDSOpts(ctx, mmdsAccessTokenClient, token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get MMDS opts: %w", err)
|
||||
}
|
||||
|
||||
return opts.AccessTokenHash, nil
|
||||
}
|
||||
|
||||
func PollForMMDSOpts(ctx context.Context, mmdsChan chan<- *MMDSOpts, envVars *utils.Map[string, string]) {
|
||||
httpClient := &http.Client{}
|
||||
defer httpClient.CloseIdleConnections()
|
||||
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "context cancelled while waiting for mmds opts")
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
token, err := getMMDSToken(ctx, httpClient)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error getting mmds token: %v\n", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
mmdsOpts, err := getMMDSOpts(ctx, httpClient, token)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error getting mmds opts: %v\n", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
envVars.Store("WRENN_SANDBOX_ID", mmdsOpts.SandboxID)
|
||||
envVars.Store("WRENN_TEMPLATE_ID", mmdsOpts.TemplateID)
|
||||
|
||||
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_SANDBOX_ID"), []byte(mmdsOpts.SandboxID), 0o666); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error writing sandbox ID file: %v\n", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_TEMPLATE_ID"), []byte(mmdsOpts.TemplateID), 0o666); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error writing template ID file: %v\n", err)
|
||||
}
|
||||
|
||||
if mmdsOpts.LogsCollectorAddress != "" {
|
||||
mmdsChan <- mmdsOpts
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
47
envd/internal/logs/bufferedEvents.go
Normal file
47
envd/internal/logs/bufferedEvents.go
Normal file
@ -0,0 +1,47 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxBufferSize = 2 << 15
|
||||
defaultTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
func LogBufferedDataEvents(dataCh <-chan []byte, logger *zerolog.Logger, eventType string) {
|
||||
timer := time.NewTicker(defaultTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
var buffer []byte
|
||||
defer func() {
|
||||
if len(buffer) > 0 {
|
||||
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event (flush)")
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
if len(buffer) > 0 {
|
||||
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
|
||||
buffer = nil
|
||||
}
|
||||
case data, ok := <-dataCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
buffer = append(buffer, data...)
|
||||
|
||||
if len(buffer) >= defaultMaxBufferSize {
|
||||
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
|
||||
buffer = nil
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
172
envd/internal/logs/exporter/exporter.go
Normal file
172
envd/internal/logs/exporter/exporter.go
Normal file
@ -0,0 +1,172 @@
|
||||
package exporter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
)
|
||||
|
||||
const ExporterTimeout = 10 * time.Second
|
||||
|
||||
type HTTPExporter struct {
|
||||
client http.Client
|
||||
logs [][]byte
|
||||
isNotFC bool
|
||||
mmdsOpts *host.MMDSOpts
|
||||
|
||||
// Concurrency coordination
|
||||
triggers chan struct{}
|
||||
logLock sync.RWMutex
|
||||
mmdsLock sync.RWMutex
|
||||
startOnce sync.Once
|
||||
}
|
||||
|
||||
func NewHTTPLogsExporter(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *HTTPExporter {
|
||||
exporter := &HTTPExporter{
|
||||
client: http.Client{
|
||||
Timeout: ExporterTimeout,
|
||||
},
|
||||
triggers: make(chan struct{}, 1),
|
||||
isNotFC: isNotFC,
|
||||
startOnce: sync.Once{},
|
||||
mmdsOpts: &host.MMDSOpts{
|
||||
SandboxID: "unknown",
|
||||
TemplateID: "unknown",
|
||||
LogsCollectorAddress: "",
|
||||
},
|
||||
}
|
||||
|
||||
go exporter.listenForMMDSOptsAndStart(ctx, mmdsChan)
|
||||
|
||||
return exporter
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) sendInstanceLogs(ctx context.Context, logs []byte, address string) error {
|
||||
if address == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, address, bytes.NewBuffer(logs))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
response, err := w.client.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printLog(logs []byte) {
|
||||
fmt.Fprintf(os.Stdout, "%v", string(logs))
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) listenForMMDSOptsAndStart(ctx context.Context, mmdsChan <-chan *host.MMDSOpts) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case mmdsOpts, ok := <-mmdsChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
w.mmdsLock.Lock()
|
||||
w.mmdsOpts.Update(mmdsOpts.SandboxID, mmdsOpts.TemplateID, mmdsOpts.LogsCollectorAddress)
|
||||
w.mmdsLock.Unlock()
|
||||
|
||||
w.startOnce.Do(func() {
|
||||
go w.start(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) start(ctx context.Context) {
|
||||
for range w.triggers {
|
||||
logs := w.getAllLogs()
|
||||
|
||||
if len(logs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if w.isNotFC {
|
||||
for _, log := range logs {
|
||||
fmt.Fprintf(os.Stdout, "%v", string(log))
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
for _, logLine := range logs {
|
||||
w.mmdsLock.RLock()
|
||||
logLineWithOpts, err := w.mmdsOpts.AddOptsToJSON(logLine)
|
||||
w.mmdsLock.RUnlock()
|
||||
if err != nil {
|
||||
log.Printf("error adding instance logging options (%+v) to JSON (%+v) with logs : %v\n", w.mmdsOpts, logLine, err)
|
||||
|
||||
printLog(logLine)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = w.sendInstanceLogs(ctx, logLineWithOpts, w.mmdsOpts.LogsCollectorAddress)
|
||||
if err != nil {
|
||||
log.Printf("error sending instance logs: %+v", err)
|
||||
|
||||
printLog(logLine)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) resumeProcessing() {
|
||||
select {
|
||||
case w.triggers <- struct{}{}:
|
||||
default:
|
||||
// Exporter processing already triggered
|
||||
// This is expected behavior if the exporter is already processing logs
|
||||
}
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) Write(logs []byte) (int, error) {
|
||||
logsCopy := make([]byte, len(logs))
|
||||
copy(logsCopy, logs)
|
||||
|
||||
go w.addLogs(logsCopy)
|
||||
|
||||
return len(logs), nil
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) getAllLogs() [][]byte {
|
||||
w.logLock.Lock()
|
||||
defer w.logLock.Unlock()
|
||||
|
||||
logs := w.logs
|
||||
w.logs = nil
|
||||
|
||||
return logs
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) addLogs(logs []byte) {
|
||||
w.logLock.Lock()
|
||||
defer w.logLock.Unlock()
|
||||
|
||||
w.logs = append(w.logs, logs)
|
||||
|
||||
w.resumeProcessing()
|
||||
}
|
||||
172
envd/internal/logs/interceptor.go
Normal file
172
envd/internal/logs/interceptor.go
Normal file
@ -0,0 +1,172 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type OperationID string
|
||||
|
||||
const (
|
||||
OperationIDKey OperationID = "operation_id"
|
||||
DefaultHTTPMethod string = "POST"
|
||||
)
|
||||
|
||||
var operationID = atomic.Int32{}
|
||||
|
||||
func AssignOperationID() string {
|
||||
id := operationID.Add(1)
|
||||
|
||||
return strconv.Itoa(int(id))
|
||||
}
|
||||
|
||||
func AddRequestIDToContext(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, OperationIDKey, AssignOperationID())
|
||||
}
|
||||
|
||||
func formatMethod(method string) string {
|
||||
parts := strings.Split(method, ".")
|
||||
if len(parts) < 2 {
|
||||
return method
|
||||
}
|
||||
|
||||
split := strings.Split(parts[1], "/")
|
||||
if len(split) < 2 {
|
||||
return method
|
||||
}
|
||||
|
||||
servicePart := split[0]
|
||||
servicePart = strings.ToUpper(servicePart[:1]) + servicePart[1:]
|
||||
|
||||
methodPart := split[1]
|
||||
methodPart = strings.ToLower(methodPart[:1]) + methodPart[1:]
|
||||
|
||||
return fmt.Sprintf("%s %s", servicePart, methodPart)
|
||||
}
|
||||
|
||||
func NewUnaryLogInterceptor(logger *zerolog.Logger) connect.UnaryInterceptorFunc {
|
||||
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
|
||||
return connect.UnaryFunc(func(
|
||||
ctx context.Context,
|
||||
req connect.AnyRequest,
|
||||
) (connect.AnyResponse, error) {
|
||||
ctx = AddRequestIDToContext(ctx)
|
||||
|
||||
res, err := next(ctx, req)
|
||||
|
||||
l := logger.
|
||||
Err(err).
|
||||
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
|
||||
|
||||
if err != nil {
|
||||
l = l.Int("error_code", int(connect.CodeOf(err)))
|
||||
}
|
||||
|
||||
if req != nil {
|
||||
l = l.Interface("request", req.Any())
|
||||
}
|
||||
|
||||
if res != nil && err == nil {
|
||||
l = l.Interface("response", res.Any())
|
||||
}
|
||||
|
||||
if res == nil && err == nil {
|
||||
l = l.Interface("response", nil)
|
||||
}
|
||||
|
||||
l.Msg(formatMethod(req.Spec().Procedure))
|
||||
|
||||
return res, err
|
||||
})
|
||||
}
|
||||
|
||||
return connect.UnaryInterceptorFunc(interceptor)
|
||||
}
|
||||
|
||||
func LogServerStreamWithoutEvents[T any, R any](
|
||||
ctx context.Context,
|
||||
logger *zerolog.Logger,
|
||||
req *connect.Request[R],
|
||||
stream *connect.ServerStream[T],
|
||||
handler func(ctx context.Context, req *connect.Request[R], stream *connect.ServerStream[T]) error,
|
||||
) error {
|
||||
ctx = AddRequestIDToContext(ctx)
|
||||
|
||||
l := logger.Debug().
|
||||
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
|
||||
|
||||
if req != nil {
|
||||
l = l.Interface("request", req.Any())
|
||||
}
|
||||
|
||||
l.Msg(fmt.Sprintf("%s (server stream start)", formatMethod(req.Spec().Procedure)))
|
||||
|
||||
err := handler(ctx, req, stream)
|
||||
|
||||
logEvent := getErrDebugLogEvent(logger, err).
|
||||
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
|
||||
|
||||
if err != nil {
|
||||
logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
|
||||
} else {
|
||||
logEvent = logEvent.Interface("response", nil)
|
||||
}
|
||||
|
||||
logEvent.Msg(fmt.Sprintf("%s (server stream end)", formatMethod(req.Spec().Procedure)))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func LogClientStreamWithoutEvents[T any, R any](
|
||||
ctx context.Context,
|
||||
logger *zerolog.Logger,
|
||||
stream *connect.ClientStream[T],
|
||||
handler func(ctx context.Context, stream *connect.ClientStream[T]) (*connect.Response[R], error),
|
||||
) (*connect.Response[R], error) {
|
||||
ctx = AddRequestIDToContext(ctx)
|
||||
|
||||
logger.Debug().
|
||||
Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string)).
|
||||
Msg(fmt.Sprintf("%s (client stream start)", formatMethod(stream.Spec().Procedure)))
|
||||
|
||||
res, err := handler(ctx, stream)
|
||||
|
||||
logEvent := getErrDebugLogEvent(logger, err).
|
||||
Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
|
||||
|
||||
if err != nil {
|
||||
logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
|
||||
}
|
||||
|
||||
if res != nil && err == nil {
|
||||
logEvent = logEvent.Interface("response", res.Any())
|
||||
}
|
||||
|
||||
if res == nil && err == nil {
|
||||
logEvent = logEvent.Interface("response", nil)
|
||||
}
|
||||
|
||||
logEvent.Msg(fmt.Sprintf("%s (client stream end)", formatMethod(stream.Spec().Procedure)))
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Return logger with error level if err is not nil, otherwise return logger with debug level
|
||||
func getErrDebugLogEvent(logger *zerolog.Logger, err error) *zerolog.Event {
|
||||
if err != nil {
|
||||
return logger.Error().Err(err) //nolint:zerologlint // this builds an event, it is not expected to return it
|
||||
}
|
||||
|
||||
return logger.Debug() //nolint:zerologlint // this builds an event, it is not expected to return it
|
||||
}
|
||||
35
envd/internal/logs/logger.go
Normal file
35
envd/internal/logs/logger.go
Normal file
@ -0,0 +1,35 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs/exporter"
|
||||
)
|
||||
|
||||
func NewLogger(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *zerolog.Logger {
|
||||
zerolog.TimestampFieldName = "timestamp"
|
||||
zerolog.TimeFieldFormat = time.RFC3339Nano
|
||||
|
||||
exporters := []io.Writer{}
|
||||
|
||||
if isNotFC {
|
||||
exporters = append(exporters, os.Stdout)
|
||||
} else {
|
||||
exporters = append(exporters, exporter.NewHTTPLogsExporter(ctx, isNotFC, mmdsChan), os.Stdout)
|
||||
}
|
||||
|
||||
l := zerolog.
|
||||
New(io.MultiWriter(exporters...)).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger().
|
||||
Level(zerolog.DebugLevel)
|
||||
|
||||
return &l
|
||||
}
|
||||
47
envd/internal/permissions/authenticate.go
Normal file
47
envd/internal/permissions/authenticate.go
Normal file
@ -0,0 +1,47 @@
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/user"
|
||||
|
||||
"connectrpc.com/authn"
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
)
|
||||
|
||||
func AuthenticateUsername(_ context.Context, req authn.Request) (any, error) {
|
||||
username, _, ok := req.BasicAuth()
|
||||
if !ok {
|
||||
// When no username is provided, ignore the authentication method (not all endpoints require it)
|
||||
// Missing user is then handled in the GetAuthUser function
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
u, err := GetUser(username)
|
||||
if err != nil {
|
||||
return nil, authn.Errorf("invalid username: '%s'", username)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func GetAuthUser(ctx context.Context, defaultUser string) (*user.User, error) {
|
||||
u, ok := authn.GetInfo(ctx).(*user.User)
|
||||
if !ok {
|
||||
username, err := execcontext.ResolveDefaultUsername(nil, defaultUser)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("no user specified"))
|
||||
}
|
||||
|
||||
u, err := GetUser(username)
|
||||
if err != nil {
|
||||
return nil, authn.Errorf("invalid default user: '%s'", username)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
29
envd/internal/permissions/keepalive.go
Normal file
29
envd/internal/permissions/keepalive.go
Normal file
@ -0,0 +1,29 @@
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
)
|
||||
|
||||
const defaultKeepAliveInterval = 90 * time.Second
|
||||
|
||||
func GetKeepAliveTicker[T any](req *connect.Request[T]) (*time.Ticker, func()) {
|
||||
keepAliveIntervalHeader := req.Header().Get("Keepalive-Ping-Interval")
|
||||
|
||||
var interval time.Duration
|
||||
|
||||
keepAliveIntervalInt, err := strconv.Atoi(keepAliveIntervalHeader)
|
||||
if err != nil {
|
||||
interval = defaultKeepAliveInterval
|
||||
} else {
|
||||
interval = time.Duration(keepAliveIntervalInt) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
|
||||
return ticker, func() {
|
||||
ticker.Reset(interval)
|
||||
}
|
||||
}
|
||||
96
envd/internal/permissions/path.go
Normal file
96
envd/internal/permissions/path.go
Normal file
@ -0,0 +1,96 @@
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
)
|
||||
|
||||
func expand(path, homedir string) (string, error) {
|
||||
if len(path) == 0 {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
if path[0] != '~' {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
if len(path) > 1 && path[1] != '/' && path[1] != '\\' {
|
||||
return "", errors.New("cannot expand user-specific home dir")
|
||||
}
|
||||
|
||||
return filepath.Join(homedir, path[1:]), nil
|
||||
}
|
||||
|
||||
func ExpandAndResolve(path string, user *user.User, defaultPath *string) (string, error) {
|
||||
path = execcontext.ResolveDefaultWorkdir(path, defaultPath)
|
||||
|
||||
path, err := expand(path, user.HomeDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to expand path '%s' for user '%s': %w", path, user.Username, err)
|
||||
}
|
||||
|
||||
if filepath.IsAbs(path) {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// The filepath.Abs can correctly resolve paths like /home/user/../file
|
||||
path = filepath.Join(user.HomeDir, path)
|
||||
|
||||
abs, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve path '%s' for user '%s' with home dir '%s': %w", path, user.Username, user.HomeDir, err)
|
||||
}
|
||||
|
||||
return abs, nil
|
||||
}
|
||||
|
||||
func getSubpaths(path string) (subpaths []string) {
|
||||
for {
|
||||
subpaths = append(subpaths, path)
|
||||
|
||||
path = filepath.Dir(path)
|
||||
if path == "/" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
slices.Reverse(subpaths)
|
||||
|
||||
return subpaths
|
||||
}
|
||||
|
||||
func EnsureDirs(path string, uid, gid int) error {
|
||||
subpaths := getSubpaths(path)
|
||||
for _, subpath := range subpaths {
|
||||
info, err := os.Stat(subpath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to stat directory: %w", err)
|
||||
}
|
||||
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
err = os.Mkdir(subpath, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
err = os.Chown(subpath, uid, gid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to chown directory: %w", err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("path is a file: %s", subpath)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
44
envd/internal/permissions/user.go
Normal file
44
envd/internal/permissions/user.go
Normal file
@ -0,0 +1,44 @@
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func GetUserIdUints(u *user.User) (uid, gid uint32, err error) {
|
||||
newUID, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
|
||||
}
|
||||
|
||||
newGID, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
|
||||
}
|
||||
|
||||
return uint32(newUID), uint32(newGID), nil
|
||||
}
|
||||
|
||||
func GetUserIdInts(u *user.User) (uid, gid int, err error) {
|
||||
newUID, err := strconv.ParseInt(u.Uid, 10, strconv.IntSize)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
|
||||
}
|
||||
|
||||
newGID, err := strconv.ParseInt(u.Gid, 10, strconv.IntSize)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
|
||||
}
|
||||
|
||||
return int(newUID), int(newGID), nil
|
||||
}
|
||||
|
||||
func GetUser(username string) (u *user.User, err error) {
|
||||
u, err = user.Lookup(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
218
envd/internal/port/forward.go
Normal file
218
envd/internal/port/forward.go
Normal file
@ -0,0 +1,218 @@
|
||||
// portf (port forward) periodaically scans opened TCP ports on the 127.0.0.1 (or localhost)
|
||||
// and launches `socat` process for every such port in the background.
|
||||
// socat forward traffic from `sourceIP`:port to the 127.0.0.1:port.
|
||||
|
||||
// WARNING: portf isn't thread safe!
|
||||
|
||||
package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||
)
|
||||
|
||||
type PortState string
|
||||
|
||||
const (
|
||||
PortStateForward PortState = "FORWARD"
|
||||
PortStateDelete PortState = "DELETE"
|
||||
)
|
||||
|
||||
var defaultGatewayIP = net.IPv4(169, 254, 0, 21)
|
||||
|
||||
type PortToForward struct {
|
||||
socat *exec.Cmd
|
||||
// Process ID of the process that's listening on port.
|
||||
pid int32
|
||||
// family version of the ip.
|
||||
family uint32
|
||||
state PortState
|
||||
port uint32
|
||||
}
|
||||
|
||||
type Forwarder struct {
|
||||
logger *zerolog.Logger
|
||||
cgroupManager cgroups.Manager
|
||||
// Map of ports that are being currently forwarded.
|
||||
ports map[string]*PortToForward
|
||||
scannerSubscriber *ScannerSubscriber
|
||||
sourceIP net.IP
|
||||
}
|
||||
|
||||
func NewForwarder(
|
||||
logger *zerolog.Logger,
|
||||
scanner *Scanner,
|
||||
cgroupManager cgroups.Manager,
|
||||
) *Forwarder {
|
||||
scannerSub := scanner.AddSubscriber(
|
||||
logger,
|
||||
"port-forwarder",
|
||||
// We only want to forward ports that are actively listening on localhost.
|
||||
&ScannerFilter{
|
||||
IPs: []string{"127.0.0.1", "localhost", "::1"},
|
||||
State: "LISTEN",
|
||||
},
|
||||
)
|
||||
|
||||
return &Forwarder{
|
||||
logger: logger,
|
||||
sourceIP: defaultGatewayIP,
|
||||
ports: make(map[string]*PortToForward),
|
||||
scannerSubscriber: scannerSub,
|
||||
cgroupManager: cgroupManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) StartForwarding(ctx context.Context) {
|
||||
if f.scannerSubscriber == nil {
|
||||
f.logger.Error().Msg("Cannot start forwarding because scanner subscriber is nil")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
// procs is an array of currently opened ports.
|
||||
if procs, ok := <-f.scannerSubscriber.Messages; ok {
|
||||
// Now we are going to refresh all ports that are being forwarded in the `ports` map. Maybe add new ones
|
||||
// and maybe remove some.
|
||||
|
||||
// Go through the ports that are currently being forwarded and set all of them
|
||||
// to the `DELETE` state. We don't know yet if they will be there after refresh.
|
||||
for _, v := range f.ports {
|
||||
v.state = PortStateDelete
|
||||
}
|
||||
|
||||
// Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state.
|
||||
// This will make sure we won't delete them later.
|
||||
for _, p := range procs {
|
||||
key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port)
|
||||
|
||||
// We check if the opened port is in our map of forwarded ports.
|
||||
val, portOk := f.ports[key]
|
||||
if portOk {
|
||||
// Just mark the port as being forwarded so we don't delete it.
|
||||
// The actual socat process that handles forwarding should be running from the last iteration.
|
||||
val.state = PortStateForward
|
||||
} else {
|
||||
f.logger.Debug().
|
||||
Str("ip", p.Laddr.IP).
|
||||
Uint32("port", p.Laddr.Port).
|
||||
Uint32("family", familyToIPVersion(p.Family)).
|
||||
Str("state", p.Status).
|
||||
Msg("Detected new opened port on localhost that is not forwarded")
|
||||
|
||||
// The opened port wasn't in the map so we create a new PortToForward and start forwarding.
|
||||
ptf := &PortToForward{
|
||||
pid: p.Pid,
|
||||
port: p.Laddr.Port,
|
||||
state: PortStateForward,
|
||||
family: familyToIPVersion(p.Family),
|
||||
}
|
||||
f.ports[key] = ptf
|
||||
f.startPortForwarding(ctx, ptf)
|
||||
}
|
||||
}
|
||||
|
||||
// We go through the ports map one more time and stop forwarding all ports
|
||||
// that stayed marked as "DELETE".
|
||||
for _, v := range f.ports {
|
||||
if v.state == PortStateDelete {
|
||||
f.stopPortForwarding(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) {
|
||||
// https://unix.stackexchange.com/questions/311492/redirect-application-listening-on-localhost-to-listening-on-external-interface
|
||||
// socat -d -d TCP4-LISTEN:4000,bind=169.254.0.21,fork TCP4:localhost:4000
|
||||
// reuseaddr is used to fix the "Address already in use" error when restarting socat quickly.
|
||||
cmd := exec.CommandContext(ctx,
|
||||
"socat", "-d", "-d", "-d",
|
||||
fmt.Sprintf("TCP4-LISTEN:%v,bind=%s,reuseaddr,fork", p.port, f.sourceIP.To4()),
|
||||
fmt.Sprintf("TCP%d:localhost:%v", p.family, p.port),
|
||||
)
|
||||
|
||||
cgroupFD, ok := f.cgroupManager.GetFileDescriptor(cgroups.ProcessTypeSocat)
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
CgroupFD: cgroupFD,
|
||||
UseCgroupFD: ok,
|
||||
}
|
||||
|
||||
f.logger.Debug().
|
||||
Str("socatCmd", cmd.String()).
|
||||
Int32("pid", p.pid).
|
||||
Uint32("family", p.family).
|
||||
IPAddr("sourceIP", f.sourceIP.To4()).
|
||||
Uint32("port", p.port).
|
||||
Msg("About to start port forwarding")
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
f.logger.
|
||||
Error().
|
||||
Str("socatCmd", cmd.String()).
|
||||
Err(err).
|
||||
Msg("Failed to start port forwarding - failed to start socat")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := cmd.Wait(); err != nil {
|
||||
f.logger.
|
||||
Debug().
|
||||
Str("socatCmd", cmd.String()).
|
||||
Err(err).
|
||||
Msg("Port forwarding socat process exited")
|
||||
}
|
||||
}()
|
||||
|
||||
p.socat = cmd
|
||||
}
|
||||
|
||||
func (f *Forwarder) stopPortForwarding(p *PortToForward) {
|
||||
if p.socat == nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() { p.socat = nil }()
|
||||
|
||||
logger := f.logger.With().
|
||||
Str("socatCmd", p.socat.String()).
|
||||
Int32("pid", p.pid).
|
||||
Uint32("family", p.family).
|
||||
IPAddr("sourceIP", f.sourceIP.To4()).
|
||||
Uint32("port", p.port).
|
||||
Logger()
|
||||
|
||||
logger.Debug().Msg("Stopping port forwarding")
|
||||
|
||||
if err := syscall.Kill(-p.socat.Process.Pid, syscall.SIGKILL); err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to kill process group")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Stopped port forwarding")
|
||||
}
|
||||
|
||||
func familyToIPVersion(family uint32) uint32 {
|
||||
switch family {
|
||||
case syscall.AF_INET:
|
||||
return 4
|
||||
case syscall.AF_INET6:
|
||||
return 6
|
||||
default:
|
||||
return 0 // Unknown or unsupported family
|
||||
}
|
||||
}
|
||||
59
envd/internal/port/scan.go
Normal file
59
envd/internal/port/scan.go
Normal file
@ -0,0 +1,59 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/shirou/gopsutil/v4/net"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap"
|
||||
)
|
||||
|
||||
type Scanner struct {
|
||||
Processes chan net.ConnectionStat
|
||||
scanExit chan struct{}
|
||||
subs *smap.Map[*ScannerSubscriber]
|
||||
period time.Duration
|
||||
}
|
||||
|
||||
func (s *Scanner) Destroy() {
|
||||
close(s.scanExit)
|
||||
}
|
||||
|
||||
func NewScanner(period time.Duration) *Scanner {
|
||||
return &Scanner{
|
||||
period: period,
|
||||
subs: smap.New[*ScannerSubscriber](),
|
||||
scanExit: make(chan struct{}),
|
||||
Processes: make(chan net.ConnectionStat),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
|
||||
subscriber := NewScannerSubscriber(logger, id, filter)
|
||||
s.subs.Insert(id, subscriber)
|
||||
|
||||
return subscriber
|
||||
}
|
||||
|
||||
func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) {
|
||||
s.subs.Remove(sub.ID())
|
||||
sub.Destroy()
|
||||
}
|
||||
|
||||
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
|
||||
func (s *Scanner) ScanAndBroadcast() {
|
||||
for {
|
||||
// tcp monitors both ipv4 and ipv6 connections.
|
||||
processes, _ := net.Connections("tcp")
|
||||
for _, sub := range s.subs.Items() {
|
||||
sub.Signal(processes)
|
||||
}
|
||||
select {
|
||||
case <-s.scanExit:
|
||||
return
|
||||
default:
|
||||
time.Sleep(s.period)
|
||||
}
|
||||
}
|
||||
}
|
||||
50
envd/internal/port/scanSubscriber.go
Normal file
50
envd/internal/port/scanSubscriber.go
Normal file
@ -0,0 +1,50 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/shirou/gopsutil/v4/net"
|
||||
)
|
||||
|
||||
// If we want to create a listener/subscriber pattern somewhere else we should move
|
||||
// from a concrete implementation to combination of generics and interfaces.
|
||||
|
||||
type ScannerSubscriber struct {
|
||||
logger *zerolog.Logger
|
||||
filter *ScannerFilter
|
||||
Messages chan ([]net.ConnectionStat)
|
||||
id string
|
||||
}
|
||||
|
||||
func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
|
||||
return &ScannerSubscriber{
|
||||
logger: logger,
|
||||
id: id,
|
||||
filter: filter,
|
||||
Messages: make(chan []net.ConnectionStat),
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *ScannerSubscriber) ID() string {
|
||||
return ss.id
|
||||
}
|
||||
|
||||
func (ss *ScannerSubscriber) Destroy() {
|
||||
close(ss.Messages)
|
||||
}
|
||||
|
||||
func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) {
|
||||
// Filter isn't specified. Accept everything.
|
||||
if ss.filter == nil {
|
||||
ss.Messages <- proc
|
||||
} else {
|
||||
filtered := []net.ConnectionStat{}
|
||||
for i := range proc {
|
||||
// We need to access the list directly otherwise there will be implicit memory aliasing
|
||||
// If the filter matched a process, we will send it to a channel.
|
||||
if ss.filter.Match(&proc[i]) {
|
||||
filtered = append(filtered, proc[i])
|
||||
}
|
||||
}
|
||||
ss.Messages <- filtered
|
||||
}
|
||||
}
|
||||
27
envd/internal/port/scanfilter.go
Normal file
27
envd/internal/port/scanfilter.go
Normal file
@ -0,0 +1,27 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/net"
|
||||
)
|
||||
|
||||
type ScannerFilter struct {
|
||||
State string
|
||||
IPs []string
|
||||
}
|
||||
|
||||
func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool {
|
||||
// Filter is an empty struct.
|
||||
if sf.State == "" && len(sf.IPs) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
ipMatch := slices.Contains(sf.IPs, proc.Laddr.IP)
|
||||
|
||||
if ipMatch && sf.State == proc.Status {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
127
envd/internal/services/cgroups/cgroup2.go
Normal file
127
envd/internal/services/cgroups/cgroup2.go
Normal file
@ -0,0 +1,127 @@
|
||||
package cgroups
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type Cgroup2Manager struct {
|
||||
cgroupFDs map[ProcessType]int
|
||||
}
|
||||
|
||||
var _ Manager = (*Cgroup2Manager)(nil)
|
||||
|
||||
type cgroup2Config struct {
|
||||
rootPath string
|
||||
processTypes map[ProcessType]Cgroup2Config
|
||||
}
|
||||
|
||||
type Cgroup2ManagerOption func(*cgroup2Config)
|
||||
|
||||
func WithCgroup2RootSysFSPath(path string) Cgroup2ManagerOption {
|
||||
return func(config *cgroup2Config) {
|
||||
config.rootPath = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithCgroup2ProcessType(processType ProcessType, path string, properties map[string]string) Cgroup2ManagerOption {
|
||||
return func(config *cgroup2Config) {
|
||||
if config.processTypes == nil {
|
||||
config.processTypes = make(map[ProcessType]Cgroup2Config)
|
||||
}
|
||||
config.processTypes[processType] = Cgroup2Config{Path: path, Properties: properties}
|
||||
}
|
||||
}
|
||||
|
||||
type Cgroup2Config struct {
|
||||
Path string
|
||||
Properties map[string]string
|
||||
}
|
||||
|
||||
func NewCgroup2Manager(opts ...Cgroup2ManagerOption) (*Cgroup2Manager, error) {
|
||||
config := cgroup2Config{
|
||||
rootPath: "/sys/fs/cgroup",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&config)
|
||||
}
|
||||
|
||||
cgroupFDs, err := createCgroups(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cgroups: %w", err)
|
||||
}
|
||||
|
||||
return &Cgroup2Manager{cgroupFDs: cgroupFDs}, nil
|
||||
}
|
||||
|
||||
func createCgroups(configs cgroup2Config) (map[ProcessType]int, error) {
|
||||
var (
|
||||
results = make(map[ProcessType]int)
|
||||
errs []error
|
||||
)
|
||||
|
||||
for procType, config := range configs.processTypes {
|
||||
fullPath := filepath.Join(configs.rootPath, config.Path)
|
||||
fd, err := createCgroup(fullPath, config.Properties)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to create %s cgroup: %w", procType, err))
|
||||
|
||||
continue
|
||||
}
|
||||
results[procType] = fd
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
for procType, fd := range results {
|
||||
err := unix.Close(fd)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func createCgroup(fullPath string, properties map[string]string) (int, error) {
|
||||
if err := os.MkdirAll(fullPath, 0o755); err != nil {
|
||||
return -1, fmt.Errorf("failed to create cgroup root: %w", err)
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for name, value := range properties {
|
||||
if err := os.WriteFile(filepath.Join(fullPath, name), []byte(value), 0o644); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to write cgroup property: %w", err))
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return -1, errors.Join(errs...)
|
||||
}
|
||||
|
||||
return unix.Open(fullPath, unix.O_RDONLY, 0)
|
||||
}
|
||||
|
||||
func (c Cgroup2Manager) GetFileDescriptor(procType ProcessType) (int, bool) {
|
||||
fd, ok := c.cgroupFDs[procType]
|
||||
|
||||
return fd, ok
|
||||
}
|
||||
|
||||
func (c Cgroup2Manager) Close() error {
|
||||
var errs []error
|
||||
for procType, fd := range c.cgroupFDs {
|
||||
if err := unix.Close(fd); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, err))
|
||||
}
|
||||
delete(c.cgroupFDs, procType)
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
185
envd/internal/services/cgroups/cgroup2_test.go
Normal file
185
envd/internal/services/cgroups/cgroup2_test.go
Normal file
@ -0,0 +1,185 @@
|
||||
package cgroups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
oneByte = 1
|
||||
kilobyte = 1024 * oneByte
|
||||
megabyte = 1024 * kilobyte
|
||||
)
|
||||
|
||||
func TestCgroupRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("must run as root")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
maxTimeout := time.Second * 5
|
||||
|
||||
t.Run("process does not die without cgroups", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// create manager
|
||||
m, err := NewCgroup2Manager()
|
||||
require.NoError(t, err)
|
||||
|
||||
// create new child process
|
||||
cmd := startProcess(t, m, "not-a-real-one")
|
||||
|
||||
// wait for child process to die
|
||||
err = waitForProcess(t, cmd, maxTimeout)
|
||||
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
})
|
||||
|
||||
t.Run("process dies with cgroups", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cgroupPath := createCgroupPath(t, "real-one")
|
||||
|
||||
// create manager
|
||||
m, err := NewCgroup2Manager(
|
||||
WithCgroup2ProcessType(ProcessTypePTY, cgroupPath, map[string]string{
|
||||
"memory.max": strconv.Itoa(1 * megabyte),
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := m.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
// create new child process
|
||||
cmd := startProcess(t, m, ProcessTypePTY)
|
||||
|
||||
// wait for child process to die
|
||||
err = waitForProcess(t, cmd, maxTimeout)
|
||||
|
||||
// verify process exited correctly
|
||||
var exitErr *exec.ExitError
|
||||
require.ErrorAs(t, err, &exitErr)
|
||||
assert.Equal(t, "signal: killed", exitErr.Error())
|
||||
assert.False(t, exitErr.Exited())
|
||||
assert.False(t, exitErr.Success())
|
||||
assert.Equal(t, -1, exitErr.ExitCode())
|
||||
|
||||
// dig a little deeper
|
||||
ws, ok := exitErr.Sys().(syscall.WaitStatus)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, syscall.SIGKILL, ws.Signal())
|
||||
assert.True(t, ws.Signaled())
|
||||
assert.False(t, ws.Stopped())
|
||||
assert.False(t, ws.Continued())
|
||||
assert.False(t, ws.CoreDump())
|
||||
assert.False(t, ws.Exited())
|
||||
assert.Equal(t, -1, ws.ExitStatus())
|
||||
})
|
||||
|
||||
t.Run("process cannot be spawned because memory limit is too low", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cgroupPath := createCgroupPath(t, "real-one")
|
||||
|
||||
// create manager
|
||||
m, err := NewCgroup2Manager(
|
||||
WithCgroup2ProcessType(ProcessTypeSocat, cgroupPath, map[string]string{
|
||||
"memory.max": strconv.Itoa(1 * kilobyte),
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := m.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
// create new child process
|
||||
cmd := startProcess(t, m, ProcessTypeSocat)
|
||||
|
||||
// wait for child process to die
|
||||
err = waitForProcess(t, cmd, maxTimeout)
|
||||
|
||||
// verify process exited correctly
|
||||
var exitErr *exec.ExitError
|
||||
require.ErrorAs(t, err, &exitErr)
|
||||
assert.Equal(t, "exit status 253", exitErr.Error())
|
||||
assert.True(t, exitErr.Exited())
|
||||
assert.False(t, exitErr.Success())
|
||||
assert.Equal(t, 253, exitErr.ExitCode())
|
||||
|
||||
// dig a little deeper
|
||||
ws, ok := exitErr.Sys().(syscall.WaitStatus)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, syscall.Signal(-1), ws.Signal())
|
||||
assert.False(t, ws.Signaled())
|
||||
assert.False(t, ws.Stopped())
|
||||
assert.False(t, ws.Continued())
|
||||
assert.False(t, ws.CoreDump())
|
||||
assert.True(t, ws.Exited())
|
||||
assert.Equal(t, 253, ws.ExitStatus())
|
||||
})
|
||||
}
|
||||
|
||||
func createCgroupPath(t *testing.T, s string) string {
|
||||
t.Helper()
|
||||
|
||||
randPart := rand.Int()
|
||||
|
||||
return fmt.Sprintf("envd-test-%s-%d", s, randPart)
|
||||
}
|
||||
|
||||
func startProcess(t *testing.T, m *Cgroup2Manager, pt ProcessType) *exec.Cmd {
|
||||
t.Helper()
|
||||
|
||||
cmdName, args := "bash", []string{"-c", `sleep 1 && tail /dev/zero`}
|
||||
cmd := exec.CommandContext(t.Context(), cmdName, args...)
|
||||
|
||||
fd, ok := m.GetFileDescriptor(pt)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
UseCgroupFD: ok,
|
||||
CgroupFD: fd,
|
||||
}
|
||||
|
||||
err := cmd.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func waitForProcess(t *testing.T, cmd *exec.Cmd, timeout time.Duration) error {
|
||||
t.Helper()
|
||||
|
||||
done := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
done <- cmd.Wait()
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), timeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
return err
|
||||
}
|
||||
}
|
||||
14
envd/internal/services/cgroups/iface.go
Normal file
14
envd/internal/services/cgroups/iface.go
Normal file
@ -0,0 +1,14 @@
|
||||
package cgroups
|
||||
|
||||
type ProcessType string
|
||||
|
||||
const (
|
||||
ProcessTypePTY ProcessType = "pty"
|
||||
ProcessTypeUser ProcessType = "user"
|
||||
ProcessTypeSocat ProcessType = "socat"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetFileDescriptor(procType ProcessType) (int, bool)
|
||||
Close() error
|
||||
}
|
||||
17
envd/internal/services/cgroups/noop.go
Normal file
17
envd/internal/services/cgroups/noop.go
Normal file
@ -0,0 +1,17 @@
|
||||
package cgroups
|
||||
|
||||
type NoopManager struct{}
|
||||
|
||||
var _ Manager = (*NoopManager)(nil)
|
||||
|
||||
func NewNoopManager() *NoopManager {
|
||||
return &NoopManager{}
|
||||
}
|
||||
|
||||
func (n NoopManager) GetFileDescriptor(ProcessType) (int, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (n NoopManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
184
envd/internal/services/filesystem/dir.go
Normal file
184
envd/internal/services/filesystem/dir.go
Normal file
@ -0,0 +1,184 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
)
|
||||
|
||||
func (s Service) ListDir(ctx context.Context, req *connect.Request[rpc.ListDirRequest]) (*connect.Response[rpc.ListDirResponse], error) {
|
||||
depth := req.Msg.GetDepth()
|
||||
if depth == 0 {
|
||||
depth = 1 // default depth to current directory
|
||||
}
|
||||
|
||||
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestedPath := req.Msg.GetPath()
|
||||
|
||||
// Expand the path so we can return absolute paths in the response.
|
||||
requestedPath, err = permissions.ExpandAndResolve(requestedPath, u, s.defaults.Workdir)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
resolvedPath, err := followSymlink(requestedPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = checkIfDirectory(resolvedPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entries, err := walkDir(requestedPath, resolvedPath, int(depth))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.ListDirResponse{
|
||||
Entries: entries,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s Service) MakeDir(ctx context.Context, req *connect.Request[rpc.MakeDirRequest]) (*connect.Response[rpc.MakeDirResponse], error) {
|
||||
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dirPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
stat, err := os.Stat(dirPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if stat.IsDir() {
|
||||
return nil, connect.NewError(connect.CodeAlreadyExists, fmt.Errorf("directory already exists: %s", dirPath))
|
||||
}
|
||||
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path already exists but it is not a directory: %s", dirPath))
|
||||
}
|
||||
|
||||
uid, gid, userErr := permissions.GetUserIdInts(u)
|
||||
if userErr != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, userErr)
|
||||
}
|
||||
|
||||
userErr = permissions.EnsureDirs(dirPath, uid, gid)
|
||||
if userErr != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, userErr)
|
||||
}
|
||||
|
||||
entry, err := entryInfo(dirPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.MakeDirResponse{
|
||||
Entry: entry,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// followSymlink resolves a symbolic link to its target path.
|
||||
func followSymlink(path string) (string, error) {
|
||||
// Resolve symlinks
|
||||
resolvedPath, err := filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", connect.NewError(connect.CodeNotFound, fmt.Errorf("path not found: %w", err))
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "too many links") {
|
||||
return "", connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("cyclic symlink or chain >255 links at %q", path))
|
||||
}
|
||||
|
||||
return "", connect.NewError(connect.CodeInternal, fmt.Errorf("error resolving symlink: %w", err))
|
||||
}
|
||||
|
||||
return resolvedPath, nil
|
||||
}
|
||||
|
||||
// checkIfDirectory checks if the given path is a directory.
|
||||
func checkIfDirectory(path string) error {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return connect.NewError(connect.CodeNotFound, fmt.Errorf("directory not found: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
|
||||
}
|
||||
|
||||
if !stat.IsDir() {
|
||||
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path is not a directory: %s", path))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// walkDir walks the directory tree starting from dirPath up to the specified depth (doesn't follow symlinks).
|
||||
func walkDir(requestedPath string, dirPath string, depth int) (entries []*rpc.EntryInfo, err error) {
|
||||
err = filepath.WalkDir(dirPath, func(path string, _ os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip the root directory itself
|
||||
if path == dirPath {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Calculate current depth
|
||||
relPath, err := filepath.Rel(dirPath, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
currentDepth := len(strings.Split(relPath, string(os.PathSeparator)))
|
||||
|
||||
if currentDepth > depth {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
entryInfo, err := entryInfo(path)
|
||||
if err != nil {
|
||||
var connectErr *connect.Error
|
||||
if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeNotFound {
|
||||
// Skip entries that don't exist anymore
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Return the requested path as the base path instead of the symlink-resolved path
|
||||
path = filepath.Join(requestedPath, relPath)
|
||||
entryInfo.Path = path
|
||||
|
||||
entries = append(entries, entryInfo)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error reading directory %s: %w", dirPath, err))
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
405
envd/internal/services/filesystem/dir_test.go
Normal file
405
envd/internal/services/filesystem/dir_test.go
Normal file
@ -0,0 +1,405 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"connectrpc.com/authn"
|
||||
"connectrpc.com/connect"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
)
|
||||
|
||||
func TestListDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup temp root and user
|
||||
root := t.TempDir()
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup directory structure
|
||||
testFolder := filepath.Join(root, "test")
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(testFolder, "test-dir", "sub-dir-1"), 0o755))
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(testFolder, "test-dir", "sub-dir-2"), 0o755))
|
||||
filePath := filepath.Join(testFolder, "test-dir", "sub-dir-1", "file.txt")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("Hello, World!"), 0o644))
|
||||
|
||||
// Service instance
|
||||
svc := mockService()
|
||||
|
||||
// Helper to inject user into context
|
||||
injectUser := func(ctx context.Context, u *user.User) context.Context {
|
||||
return authn.SetInfo(ctx, u)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
depth uint32
|
||||
expectedPaths []string
|
||||
}{
|
||||
{
|
||||
name: "depth 0 lists only root directory",
|
||||
depth: 0,
|
||||
expectedPaths: []string{
|
||||
filepath.Join(testFolder, "test-dir"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "depth 1 lists root directory",
|
||||
depth: 1,
|
||||
expectedPaths: []string{
|
||||
filepath.Join(testFolder, "test-dir"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "depth 2 lists first level of subdirectories (in this case the root directory)",
|
||||
depth: 2,
|
||||
expectedPaths: []string{
|
||||
filepath.Join(testFolder, "test-dir"),
|
||||
filepath.Join(testFolder, "test-dir", "sub-dir-1"),
|
||||
filepath.Join(testFolder, "test-dir", "sub-dir-2"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "depth 3 lists all directories and files",
|
||||
depth: 3,
|
||||
expectedPaths: []string{
|
||||
filepath.Join(testFolder, "test-dir"),
|
||||
filepath.Join(testFolder, "test-dir", "sub-dir-1"),
|
||||
filepath.Join(testFolder, "test-dir", "sub-dir-2"),
|
||||
filepath.Join(testFolder, "test-dir", "sub-dir-1", "file.txt"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := injectUser(t.Context(), u)
|
||||
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||
Path: testFolder,
|
||||
Depth: tt.depth,
|
||||
})
|
||||
resp, err := svc.ListDir(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Msg)
|
||||
assert.Len(t, resp.Msg.GetEntries(), len(tt.expectedPaths))
|
||||
actualPaths := make([]string, len(resp.Msg.GetEntries()))
|
||||
for i, entry := range resp.Msg.GetEntries() {
|
||||
actualPaths[i] = entry.GetPath()
|
||||
}
|
||||
assert.ElementsMatch(t, tt.expectedPaths, actualPaths)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDirNonExistingPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := mockService()
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
ctx := authn.SetInfo(t.Context(), u)
|
||||
|
||||
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||
Path: "/non-existing-path",
|
||||
Depth: 1,
|
||||
})
|
||||
_, err = svc.ListDir(ctx, req)
|
||||
require.Error(t, err)
|
||||
var connectErr *connect.Error
|
||||
ok := errors.As(err, &connectErr)
|
||||
assert.True(t, ok, "expected error to be of type *connect.Error")
|
||||
assert.Equal(t, connect.CodeNotFound, connectErr.Code())
|
||||
}
|
||||
|
||||
func TestListDirRelativePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup temp root and user
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup directory structure
|
||||
testRelativePath := fmt.Sprintf("test-%s", uuid.New())
|
||||
testFolderPath := filepath.Join(u.HomeDir, testRelativePath)
|
||||
filePath := filepath.Join(testFolderPath, "file.txt")
|
||||
require.NoError(t, os.MkdirAll(testFolderPath, 0o755))
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("Hello, World!"), 0o644))
|
||||
|
||||
// Service instance
|
||||
svc := mockService()
|
||||
ctx := authn.SetInfo(t.Context(), u)
|
||||
|
||||
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||
Path: testRelativePath,
|
||||
Depth: 1,
|
||||
})
|
||||
resp, err := svc.ListDir(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Msg)
|
||||
|
||||
expectedPaths := []string{
|
||||
filepath.Join(testFolderPath, "file.txt"),
|
||||
}
|
||||
assert.Len(t, resp.Msg.GetEntries(), len(expectedPaths))
|
||||
|
||||
actualPaths := make([]string, len(resp.Msg.GetEntries()))
|
||||
for i, entry := range resp.Msg.GetEntries() {
|
||||
actualPaths[i] = entry.GetPath()
|
||||
}
|
||||
assert.ElementsMatch(t, expectedPaths, actualPaths)
|
||||
}
|
||||
|
||||
func TestListDir_Symlinks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root := t.TempDir()
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
ctx := authn.SetInfo(t.Context(), u)
|
||||
|
||||
symlinkRoot := filepath.Join(root, "test-symlinks")
|
||||
require.NoError(t, os.MkdirAll(symlinkRoot, 0o755))
|
||||
|
||||
// 1. Prepare a real directory + file that a symlink will point to
|
||||
realDir := filepath.Join(symlinkRoot, "real-dir")
|
||||
require.NoError(t, os.MkdirAll(realDir, 0o755))
|
||||
filePath := filepath.Join(realDir, "file.txt")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("hello via symlink"), 0o644))
|
||||
|
||||
// 2. Prepare a standalone real file (points-to-file scenario)
|
||||
realFile := filepath.Join(symlinkRoot, "real-file.txt")
|
||||
require.NoError(t, os.WriteFile(realFile, []byte("i am a plain file"), 0o644))
|
||||
|
||||
// 3. Create the three symlinks
|
||||
linkToDir := filepath.Join(symlinkRoot, "link-dir") // → directory
|
||||
linkToFile := filepath.Join(symlinkRoot, "link-file") // → file
|
||||
cyclicLink := filepath.Join(symlinkRoot, "cyclic") // → itself
|
||||
require.NoError(t, os.Symlink(realDir, linkToDir))
|
||||
require.NoError(t, os.Symlink(realFile, linkToFile))
|
||||
require.NoError(t, os.Symlink(cyclicLink, cyclicLink))
|
||||
|
||||
svc := mockService()
|
||||
|
||||
t.Run("symlink to directory behaves like directory and the content looks like inside the directory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||
Path: linkToDir,
|
||||
Depth: 1,
|
||||
})
|
||||
resp, err := svc.ListDir(ctx, req)
|
||||
require.NoError(t, err)
|
||||
expected := []string{
|
||||
filepath.Join(linkToDir, "file.txt"),
|
||||
}
|
||||
actual := make([]string, len(resp.Msg.GetEntries()))
|
||||
for i, e := range resp.Msg.GetEntries() {
|
||||
actual[i] = e.GetPath()
|
||||
}
|
||||
assert.ElementsMatch(t, expected, actual)
|
||||
})
|
||||
|
||||
t.Run("link to file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||
Path: linkToFile,
|
||||
Depth: 1,
|
||||
})
|
||||
_, err := svc.ListDir(ctx, req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not a directory")
|
||||
})
|
||||
|
||||
t.Run("cyclic symlink surfaces 'too many links' → invalid-argument", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||
Path: cyclicLink,
|
||||
})
|
||||
_, err := svc.ListDir(ctx, req)
|
||||
require.Error(t, err)
|
||||
var connectErr *connect.Error
|
||||
ok := errors.As(err, &connectErr)
|
||||
assert.True(t, ok, "expected error to be of type *connect.Error")
|
||||
assert.Equal(t, connect.CodeFailedPrecondition, connectErr.Code())
|
||||
assert.Contains(t, connectErr.Error(), "cyclic symlink")
|
||||
})
|
||||
|
||||
t.Run("symlink not resolved if not root", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||
Path: symlinkRoot,
|
||||
Depth: 3,
|
||||
})
|
||||
res, err := svc.ListDir(ctx, req)
|
||||
require.NoError(t, err)
|
||||
expected := []string{
|
||||
filepath.Join(symlinkRoot, "cyclic"),
|
||||
filepath.Join(symlinkRoot, "link-dir"),
|
||||
filepath.Join(symlinkRoot, "link-file"),
|
||||
filepath.Join(symlinkRoot, "real-dir"),
|
||||
filepath.Join(symlinkRoot, "real-dir", "file.txt"),
|
||||
filepath.Join(symlinkRoot, "real-file.txt"),
|
||||
}
|
||||
actual := make([]string, len(res.Msg.GetEntries()))
|
||||
for i, e := range res.Msg.GetEntries() {
|
||||
actual[i] = e.GetPath()
|
||||
}
|
||||
assert.ElementsMatch(t, expected, actual, "symlinks should not be resolved when listing the symlink root directory")
|
||||
})
|
||||
}
|
||||
|
||||
// TestFollowSymlink_Success makes sure that followSymlink resolves symlinks,
|
||||
// while also being robust to the /var → /private/var indirection that exists on macOS.
|
||||
func TestFollowSymlink_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Base temporary directory. On macOS this lives under /var/folders/…
|
||||
// which itself is a symlink to /private/var/folders/….
|
||||
base := t.TempDir()
|
||||
|
||||
// Create a real directory that we ultimately want to resolve to.
|
||||
target := filepath.Join(base, "target")
|
||||
require.NoError(t, os.MkdirAll(target, 0o755))
|
||||
|
||||
// Create a symlink pointing at the real directory so we can verify that
|
||||
// followSymlink follows it.
|
||||
link := filepath.Join(base, "link")
|
||||
require.NoError(t, os.Symlink(target, link))
|
||||
|
||||
got, err := followSymlink(link)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Canonicalise the expected path too, so that /var → /private/var (macOS)
|
||||
// or any other benign symlink indirections don’t cause flaky tests.
|
||||
want, err := filepath.EvalSymlinks(link)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, want, got, "followSymlink should resolve and canonicalise symlinks")
|
||||
}
|
||||
|
||||
// TestFollowSymlink_MultiSymlinkChain verifies that followSymlink follows a chain
|
||||
// of several symlinks (non‑cyclic) correctly.
|
||||
func TestFollowSymlink_MultiSymlinkChain(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := t.TempDir()
|
||||
|
||||
// Final destination directory.
|
||||
target := filepath.Join(base, "target")
|
||||
require.NoError(t, os.MkdirAll(target, 0o755))
|
||||
|
||||
// Build a 3‑link chain: link1 → link2 → link3 → target.
|
||||
link3 := filepath.Join(base, "link3")
|
||||
require.NoError(t, os.Symlink(target, link3))
|
||||
|
||||
link2 := filepath.Join(base, "link2")
|
||||
require.NoError(t, os.Symlink(link3, link2))
|
||||
|
||||
link1 := filepath.Join(base, "link1")
|
||||
require.NoError(t, os.Symlink(link2, link1))
|
||||
|
||||
got, err := followSymlink(link1)
|
||||
require.NoError(t, err)
|
||||
|
||||
want, err := filepath.EvalSymlinks(link1)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, want, got, "followSymlink should resolve an arbitrary symlink chain")
|
||||
}
|
||||
|
||||
func TestFollowSymlink_NotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := followSymlink("/definitely/does/not/exist")
|
||||
require.Error(t, err)
|
||||
|
||||
var cerr *connect.Error
|
||||
require.ErrorAs(t, err, &cerr)
|
||||
require.Equal(t, connect.CodeNotFound, cerr.Code())
|
||||
}
|
||||
|
||||
func TestFollowSymlink_CyclicSymlink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
a := filepath.Join(dir, "a")
|
||||
b := filepath.Join(dir, "b")
|
||||
require.NoError(t, os.MkdirAll(a, 0o755))
|
||||
require.NoError(t, os.MkdirAll(b, 0o755))
|
||||
|
||||
// Create a two‑node loop: a/loop → b/loop, b/loop → a/loop.
|
||||
require.NoError(t, os.Symlink(filepath.Join(b, "loop"), filepath.Join(a, "loop")))
|
||||
require.NoError(t, os.Symlink(filepath.Join(a, "loop"), filepath.Join(b, "loop")))
|
||||
|
||||
_, err := followSymlink(filepath.Join(a, "loop"))
|
||||
require.Error(t, err)
|
||||
|
||||
var cerr *connect.Error
|
||||
require.ErrorAs(t, err, &cerr)
|
||||
require.Equal(t, connect.CodeFailedPrecondition, cerr.Code())
|
||||
require.Contains(t, cerr.Message(), "cyclic")
|
||||
}
|
||||
|
||||
func TestCheckIfDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, checkIfDirectory(dir))
|
||||
|
||||
file := filepath.Join(dir, "file.txt")
|
||||
require.NoError(t, os.WriteFile(file, []byte("hello"), 0o644))
|
||||
|
||||
err := checkIfDirectory(file)
|
||||
require.Error(t, err)
|
||||
|
||||
var cerr *connect.Error
|
||||
require.ErrorAs(t, err, &cerr)
|
||||
require.Equal(t, connect.CodeInvalidArgument, cerr.Code())
|
||||
}
|
||||
|
||||
func TestWalkDir_Depth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root := t.TempDir()
|
||||
sub := filepath.Join(root, "sub")
|
||||
subsub := filepath.Join(sub, "subsub")
|
||||
require.NoError(t, os.MkdirAll(subsub, 0o755))
|
||||
|
||||
entries, err := walkDir(root, root, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Collect the names for easier assertions.
|
||||
names := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
names = append(names, e.GetName())
|
||||
}
|
||||
|
||||
require.Contains(t, names, "sub")
|
||||
require.NotContains(t, names, "subsub", "entries beyond depth should be excluded")
|
||||
}
|
||||
|
||||
func TestWalkDir_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := walkDir("/does/not/exist", "/does/not/exist", 1)
|
||||
require.Error(t, err)
|
||||
|
||||
var cerr *connect.Error
|
||||
require.ErrorAs(t, err, &cerr)
|
||||
require.Equal(t, connect.CodeInternal, cerr.Code())
|
||||
}
|
||||
58
envd/internal/services/filesystem/move.go
Normal file
58
envd/internal/services/filesystem/move.go
Normal file
@ -0,0 +1,58 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
)
|
||||
|
||||
func (s Service) Move(ctx context.Context, req *connect.Request[rpc.MoveRequest]) (*connect.Response[rpc.MoveResponse], error) {
|
||||
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
source, err := permissions.ExpandAndResolve(req.Msg.GetSource(), u, s.defaults.Workdir)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
destination, err := permissions.ExpandAndResolve(req.Msg.GetDestination(), u, s.defaults.Workdir)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
uid, gid, userErr := permissions.GetUserIdInts(u)
|
||||
if userErr != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, userErr)
|
||||
}
|
||||
|
||||
userErr = permissions.EnsureDirs(filepath.Dir(destination), uid, gid)
|
||||
if userErr != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, userErr)
|
||||
}
|
||||
|
||||
err = os.Rename(source, destination)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("source file not found: %w", err))
|
||||
}
|
||||
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error renaming: %w", err))
|
||||
}
|
||||
|
||||
entry, err := entryInfo(destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.MoveResponse{
|
||||
Entry: entry,
|
||||
}), nil
|
||||
}
|
||||
364
envd/internal/services/filesystem/move_test.go
Normal file
364
envd/internal/services/filesystem/move_test.go
Normal file
@ -0,0 +1,364 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"connectrpc.com/authn"
|
||||
"connectrpc.com/connect"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
)
|
||||
|
||||
func TestMove(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup temp root and user
|
||||
root := t.TempDir()
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup source and destination directories
|
||||
sourceDir := filepath.Join(root, "source")
|
||||
destDir := filepath.Join(root, "destination")
|
||||
require.NoError(t, os.MkdirAll(sourceDir, 0o755))
|
||||
require.NoError(t, os.MkdirAll(destDir, 0o755))
|
||||
|
||||
// Create a test file to move
|
||||
sourceFile := filepath.Join(sourceDir, "test-file.txt")
|
||||
testContent := []byte("Hello, World!")
|
||||
require.NoError(t, os.WriteFile(sourceFile, testContent, 0o644))
|
||||
|
||||
// Destination file path
|
||||
destFile := filepath.Join(destDir, "test-file.txt")
|
||||
|
||||
// Service instance
|
||||
svc := mockService()
|
||||
|
||||
// Call the Move function
|
||||
ctx := authn.SetInfo(t.Context(), u)
|
||||
req := connect.NewRequest(&filesystem.MoveRequest{
|
||||
Source: sourceFile,
|
||||
Destination: destFile,
|
||||
})
|
||||
resp, err := svc.Move(ctx, req)
|
||||
|
||||
// Verify the move was successful
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, destFile, resp.Msg.GetEntry().GetPath())
|
||||
|
||||
// Verify the file exists at the destination
|
||||
_, err = os.Stat(destFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file no longer exists at the source
|
||||
_, err = os.Stat(sourceFile)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
// Verify the content of the moved file
|
||||
content, err := os.ReadFile(destFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testContent, content)
|
||||
}
|
||||
|
||||
func TestMoveDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup temp root and user
|
||||
root := t.TempDir()
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup source and destination directories
|
||||
sourceParent := filepath.Join(root, "source-parent")
|
||||
destParent := filepath.Join(root, "dest-parent")
|
||||
require.NoError(t, os.MkdirAll(sourceParent, 0o755))
|
||||
require.NoError(t, os.MkdirAll(destParent, 0o755))
|
||||
|
||||
// Create a test directory with files to move
|
||||
sourceDir := filepath.Join(sourceParent, "test-dir")
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(sourceDir, "subdir"), 0o755))
|
||||
|
||||
// Create some files in the directory
|
||||
file1 := filepath.Join(sourceDir, "file1.txt")
|
||||
file2 := filepath.Join(sourceDir, "subdir", "file2.txt")
|
||||
require.NoError(t, os.WriteFile(file1, []byte("File 1 content"), 0o644))
|
||||
require.NoError(t, os.WriteFile(file2, []byte("File 2 content"), 0o644))
|
||||
|
||||
// Destination directory path
|
||||
destDir := filepath.Join(destParent, "test-dir")
|
||||
|
||||
// Service instance
|
||||
svc := mockService()
|
||||
|
||||
// Call the Move function
|
||||
ctx := authn.SetInfo(t.Context(), u)
|
||||
req := connect.NewRequest(&filesystem.MoveRequest{
|
||||
Source: sourceDir,
|
||||
Destination: destDir,
|
||||
})
|
||||
resp, err := svc.Move(ctx, req)
|
||||
|
||||
// Verify the move was successful
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, destDir, resp.Msg.GetEntry().GetPath())
|
||||
|
||||
// Verify the directory exists at the destination
|
||||
_, err = os.Stat(destDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the files exist at the destination
|
||||
destFile1 := filepath.Join(destDir, "file1.txt")
|
||||
destFile2 := filepath.Join(destDir, "subdir", "file2.txt")
|
||||
_, err = os.Stat(destFile1)
|
||||
require.NoError(t, err)
|
||||
_, err = os.Stat(destFile2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the directory no longer exists at the source
|
||||
_, err = os.Stat(sourceDir)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
// Verify the content of the moved files
|
||||
content1, err := os.ReadFile(destFile1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("File 1 content"), content1)
|
||||
|
||||
content2, err := os.ReadFile(destFile2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("File 2 content"), content2)
|
||||
}
|
||||
|
||||
func TestMoveNonExistingFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup temp root and user
|
||||
root := t.TempDir()
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup destination directory
|
||||
destDir := filepath.Join(root, "destination")
|
||||
require.NoError(t, os.MkdirAll(destDir, 0o755))
|
||||
|
||||
// Non-existing source file
|
||||
sourceFile := filepath.Join(root, "non-existing-file.txt")
|
||||
|
||||
// Destination file path
|
||||
destFile := filepath.Join(destDir, "moved-file.txt")
|
||||
|
||||
// Service instance
|
||||
svc := mockService()
|
||||
|
||||
// Call the Move function
|
||||
ctx := authn.SetInfo(t.Context(), u)
|
||||
req := connect.NewRequest(&filesystem.MoveRequest{
|
||||
Source: sourceFile,
|
||||
Destination: destFile,
|
||||
})
|
||||
_, err = svc.Move(ctx, req)
|
||||
|
||||
// Verify the correct error is returned
|
||||
require.Error(t, err)
|
||||
|
||||
var connectErr *connect.Error
|
||||
ok := errors.As(err, &connectErr)
|
||||
assert.True(t, ok, "expected error to be of type *connect.Error")
|
||||
assert.Equal(t, connect.CodeNotFound, connectErr.Code())
|
||||
assert.Contains(t, connectErr.Message(), "source file not found")
|
||||
}
|
||||
|
||||
func TestMoveRelativePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup user
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup directory structure with unique name to avoid conflicts
|
||||
testRelativePath := fmt.Sprintf("test-move-%s", uuid.New())
|
||||
testFolderPath := filepath.Join(u.HomeDir, testRelativePath)
|
||||
require.NoError(t, os.MkdirAll(testFolderPath, 0o755))
|
||||
|
||||
// Create a test file to move
|
||||
sourceFile := filepath.Join(testFolderPath, "source-file.txt")
|
||||
testContent := []byte("Hello from relative path!")
|
||||
require.NoError(t, os.WriteFile(sourceFile, testContent, 0o644))
|
||||
|
||||
// Destination file path (also relative)
|
||||
destRelativePath := fmt.Sprintf("test-move-dest-%s", uuid.New())
|
||||
destFolderPath := filepath.Join(u.HomeDir, destRelativePath)
|
||||
require.NoError(t, os.MkdirAll(destFolderPath, 0o755))
|
||||
destFile := filepath.Join(destFolderPath, "moved-file.txt")
|
||||
|
||||
// Service instance
|
||||
svc := mockService()
|
||||
|
||||
// Call the Move function with relative paths
|
||||
ctx := authn.SetInfo(t.Context(), u)
|
||||
req := connect.NewRequest(&filesystem.MoveRequest{
|
||||
Source: filepath.Join(testRelativePath, "source-file.txt"), // Relative path
|
||||
Destination: filepath.Join(destRelativePath, "moved-file.txt"), // Relative path
|
||||
})
|
||||
resp, err := svc.Move(ctx, req)
|
||||
|
||||
// Verify the move was successful
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, destFile, resp.Msg.GetEntry().GetPath())
|
||||
|
||||
// Verify the file exists at the destination
|
||||
_, err = os.Stat(destFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file no longer exists at the source
|
||||
_, err = os.Stat(sourceFile)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
// Verify the content of the moved file
|
||||
content, err := os.ReadFile(destFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testContent, content)
|
||||
|
||||
// Clean up
|
||||
os.RemoveAll(testFolderPath)
|
||||
os.RemoveAll(destFolderPath)
|
||||
}
|
||||
|
||||
func TestMove_Symlinks(t *testing.T) { //nolint:tparallel // this test cannot be executed in parallel
|
||||
root := t.TempDir()
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
ctx := authn.SetInfo(t.Context(), u)
|
||||
|
||||
// Setup source and destination directories
|
||||
sourceRoot := filepath.Join(root, "source")
|
||||
destRoot := filepath.Join(root, "destination")
|
||||
require.NoError(t, os.MkdirAll(sourceRoot, 0o755))
|
||||
require.NoError(t, os.MkdirAll(destRoot, 0o755))
|
||||
|
||||
// 1. Prepare a real directory + file that a symlink will point to
|
||||
realDir := filepath.Join(sourceRoot, "real-dir")
|
||||
require.NoError(t, os.MkdirAll(realDir, 0o755))
|
||||
filePath := filepath.Join(realDir, "file.txt")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("hello via symlink"), 0o644))
|
||||
|
||||
// 2. Prepare a standalone real file (points-to-file scenario)
|
||||
realFile := filepath.Join(sourceRoot, "real-file.txt")
|
||||
require.NoError(t, os.WriteFile(realFile, []byte("i am a plain file"), 0o644))
|
||||
|
||||
// 3. Create symlinks
|
||||
linkToDir := filepath.Join(sourceRoot, "link-dir") // → directory
|
||||
linkToFile := filepath.Join(sourceRoot, "link-file") // → file
|
||||
require.NoError(t, os.Symlink(realDir, linkToDir))
|
||||
require.NoError(t, os.Symlink(realFile, linkToFile))
|
||||
|
||||
svc := mockService()
|
||||
|
||||
t.Run("move symlink to directory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
destPath := filepath.Join(destRoot, "moved-link-dir")
|
||||
|
||||
req := connect.NewRequest(&filesystem.MoveRequest{
|
||||
Source: linkToDir,
|
||||
Destination: destPath,
|
||||
})
|
||||
resp, err := svc.Move(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())
|
||||
|
||||
// Verify the symlink was moved
|
||||
_, err = os.Stat(destPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's still a symlink
|
||||
info, err := os.Lstat(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, 0, info.Mode()&os.ModeSymlink, "expected a symlink")
|
||||
|
||||
// Verify the symlink target is still correct
|
||||
target, err := os.Readlink(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, realDir, target)
|
||||
|
||||
// Verify the original symlink is gone
|
||||
_, err = os.Stat(linkToDir)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
// Verify the real directory still exists
|
||||
_, err = os.Stat(realDir)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("move symlink to file", func(t *testing.T) { //nolint:paralleltest
|
||||
destPath := filepath.Join(destRoot, "moved-link-file")
|
||||
|
||||
req := connect.NewRequest(&filesystem.MoveRequest{
|
||||
Source: linkToFile,
|
||||
Destination: destPath,
|
||||
})
|
||||
resp, err := svc.Move(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())
|
||||
|
||||
// Verify the symlink was moved
|
||||
_, err = os.Stat(destPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's still a symlink
|
||||
info, err := os.Lstat(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, 0, info.Mode()&os.ModeSymlink, "expected a symlink")
|
||||
|
||||
// Verify the symlink target is still correct
|
||||
target, err := os.Readlink(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, realFile, target)
|
||||
|
||||
// Verify the original symlink is gone
|
||||
_, err = os.Stat(linkToFile)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
// Verify the real file still exists
|
||||
_, err = os.Stat(realFile)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("move real file that is target of symlink", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a new symlink to the real file
|
||||
newLinkToFile := filepath.Join(sourceRoot, "new-link-file")
|
||||
require.NoError(t, os.Symlink(realFile, newLinkToFile))
|
||||
|
||||
destPath := filepath.Join(destRoot, "moved-real-file.txt")
|
||||
|
||||
req := connect.NewRequest(&filesystem.MoveRequest{
|
||||
Source: realFile,
|
||||
Destination: destPath,
|
||||
})
|
||||
resp, err := svc.Move(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())
|
||||
|
||||
// Verify the real file was moved
|
||||
_, err = os.Stat(destPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the original file is gone
|
||||
_, err = os.Stat(realFile)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
// Verify the symlink still exists but now points to a non-existent file
|
||||
_, err = os.Stat(newLinkToFile)
|
||||
require.Error(t, err, "symlink should point to non-existent file")
|
||||
})
|
||||
}
|
||||
31
envd/internal/services/filesystem/remove.go
Normal file
31
envd/internal/services/filesystem/remove.go
Normal file
@ -0,0 +1,31 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
)
|
||||
|
||||
func (s Service) Remove(ctx context.Context, req *connect.Request[rpc.RemoveRequest]) (*connect.Response[rpc.RemoveResponse], error) {
|
||||
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
err = os.RemoveAll(path)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error removing file or directory: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.RemoveResponse{}), nil
|
||||
}
|
||||
34
envd/internal/services/filesystem/service.go
Normal file
34
envd/internal/services/filesystem/service.go
Normal file
@ -0,0 +1,34 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem/filesystemconnect"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
logger *zerolog.Logger
|
||||
watchers *utils.Map[string, *FileWatcher]
|
||||
defaults *execcontext.Defaults
|
||||
}
|
||||
|
||||
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults) {
|
||||
service := Service{
|
||||
logger: l,
|
||||
watchers: utils.NewMap[string, *FileWatcher](),
|
||||
defaults: defaults,
|
||||
}
|
||||
|
||||
interceptors := connect.WithInterceptors(
|
||||
logs.NewUnaryLogInterceptor(l),
|
||||
)
|
||||
|
||||
path, handler := spec.NewFilesystemHandler(service, interceptors)
|
||||
|
||||
server.Mount(path, handler)
|
||||
}
|
||||
14
envd/internal/services/filesystem/service_test.go
Normal file
14
envd/internal/services/filesystem/service_test.go
Normal file
@ -0,0 +1,14 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
func mockService() Service {
|
||||
return Service{
|
||||
defaults: &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
},
|
||||
}
|
||||
}
|
||||
29
envd/internal/services/filesystem/stat.go
Normal file
29
envd/internal/services/filesystem/stat.go
Normal file
@ -0,0 +1,29 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
)
|
||||
|
||||
func (s Service) Stat(ctx context.Context, req *connect.Request[rpc.StatRequest]) (*connect.Response[rpc.StatResponse], error) {
|
||||
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
entry, err := entryInfo(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.StatResponse{Entry: entry}), nil
|
||||
}
|
||||
114
envd/internal/services/filesystem/stat_test.go
Normal file
114
envd/internal/services/filesystem/stat_test.go
Normal file
@ -0,0 +1,114 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"connectrpc.com/authn"
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
)
|
||||
|
||||
func TestStat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup temp root and user
|
||||
root := t.TempDir()
|
||||
// Get the actual path to the temp directory (symlinks can cause issues)
|
||||
root, err := filepath.EvalSymlinks(root)
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
group, err := user.LookupGroupId(u.Gid)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup directory structure
|
||||
testFolder := filepath.Join(root, "test")
|
||||
err = os.MkdirAll(testFolder, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
testFile := filepath.Join(testFolder, "file.txt")
|
||||
err = os.WriteFile(testFile, []byte("Hello, World!"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
linkedFile := filepath.Join(testFolder, "linked-file.txt")
|
||||
err = os.Symlink(testFile, linkedFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Service instance
|
||||
svc := mockService()
|
||||
|
||||
// Helper to inject user into context
|
||||
injectUser := func(ctx context.Context, u *user.User) context.Context {
|
||||
return authn.SetInfo(ctx, u)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
}{
|
||||
{
|
||||
name: "Stat file directory",
|
||||
path: testFile,
|
||||
},
|
||||
{
|
||||
name: "Stat symlink to file",
|
||||
path: linkedFile,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := injectUser(t.Context(), u)
|
||||
req := connect.NewRequest(&filesystem.StatRequest{
|
||||
Path: tt.path,
|
||||
})
|
||||
resp, err := svc.Stat(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, resp.Msg)
|
||||
require.NotNil(t, resp.Msg.GetEntry())
|
||||
assert.Equal(t, tt.path, resp.Msg.GetEntry().GetPath())
|
||||
assert.Equal(t, filesystem.FileType_FILE_TYPE_FILE, resp.Msg.GetEntry().GetType())
|
||||
assert.Equal(t, u.Username, resp.Msg.GetEntry().GetOwner())
|
||||
assert.Equal(t, group.Name, resp.Msg.GetEntry().GetGroup())
|
||||
assert.Equal(t, uint32(0o644), resp.Msg.GetEntry().GetMode())
|
||||
if tt.path == linkedFile {
|
||||
require.NotNil(t, resp.Msg.GetEntry().GetSymlinkTarget())
|
||||
assert.Equal(t, testFile, resp.Msg.GetEntry().GetSymlinkTarget())
|
||||
} else {
|
||||
assert.Empty(t, resp.Msg.GetEntry().GetSymlinkTarget())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatMissingPathReturnsNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := mockService()
|
||||
ctx := authn.SetInfo(t.Context(), u)
|
||||
|
||||
req := connect.NewRequest(&filesystem.StatRequest{
|
||||
Path: filepath.Join(t.TempDir(), "missing.txt"),
|
||||
})
|
||||
|
||||
_, err = svc.Stat(ctx, req)
|
||||
require.Error(t, err)
|
||||
|
||||
var connectErr *connect.Error
|
||||
require.ErrorAs(t, err, &connectErr)
|
||||
assert.Equal(t, connect.CodeNotFound, connectErr.Code())
|
||||
}
|
||||
107
envd/internal/services/filesystem/utils.go
Normal file
107
envd/internal/services/filesystem/utils.go
Normal file
@ -0,0 +1,107 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/filesystem"
|
||||
)
|
||||
|
||||
// Filesystem magic numbers from Linux kernel (include/uapi/linux/magic.h)
|
||||
const (
|
||||
nfsSuperMagic = 0x6969
|
||||
cifsMagic = 0xFF534D42
|
||||
smbSuperMagic = 0x517B
|
||||
smb2MagicNumber = 0xFE534D42
|
||||
fuseSuperMagic = 0x65735546
|
||||
)
|
||||
|
||||
// IsPathOnNetworkMount checks if the given path is on a network filesystem mount.
|
||||
// Returns true if the path is on NFS, CIFS, SMB, or FUSE filesystem.
|
||||
func IsPathOnNetworkMount(path string) (bool, error) {
|
||||
var statfs syscall.Statfs_t
|
||||
if err := syscall.Statfs(path, &statfs); err != nil {
|
||||
return false, fmt.Errorf("failed to statfs %s: %w", path, err)
|
||||
}
|
||||
|
||||
switch statfs.Type {
|
||||
case nfsSuperMagic, cifsMagic, smbSuperMagic, smb2MagicNumber, fuseSuperMagic:
|
||||
return true, nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func entryInfo(path string) (*rpc.EntryInfo, error) {
|
||||
info, err := filesystem.GetEntryFromPath(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("file not found: %w", err))
|
||||
}
|
||||
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
|
||||
}
|
||||
|
||||
owner, group := getFileOwnership(info)
|
||||
|
||||
return &rpc.EntryInfo{
|
||||
Name: info.Name,
|
||||
Type: getEntryType(info.Type),
|
||||
Path: info.Path,
|
||||
Size: info.Size,
|
||||
Mode: uint32(info.Mode),
|
||||
Permissions: info.Permissions,
|
||||
Owner: owner,
|
||||
Group: group,
|
||||
ModifiedTime: toTimestamp(info.ModifiedTime),
|
||||
SymlinkTarget: info.SymlinkTarget,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func toTimestamp(time time.Time) *timestamppb.Timestamp {
|
||||
if time.IsZero() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return timestamppb.New(time)
|
||||
}
|
||||
|
||||
// getFileOwnership returns the owner and group names for a file.
|
||||
// If the lookup fails, it returns the numeric UID and GID as strings.
|
||||
func getFileOwnership(fileInfo filesystem.EntryInfo) (owner, group string) {
|
||||
// Look up username
|
||||
owner = fmt.Sprintf("%d", fileInfo.UID)
|
||||
if u, err := user.LookupId(owner); err == nil {
|
||||
owner = u.Username
|
||||
}
|
||||
|
||||
// Look up group name
|
||||
group = fmt.Sprintf("%d", fileInfo.GID)
|
||||
if g, err := user.LookupGroupId(group); err == nil {
|
||||
group = g.Name
|
||||
}
|
||||
|
||||
return owner, group
|
||||
}
|
||||
|
||||
// getEntryType determines the type of file entry based on its mode and path.
|
||||
// If the file is a symlink, it follows the symlink to determine the actual type.
|
||||
func getEntryType(fileType filesystem.FileType) rpc.FileType {
|
||||
switch fileType {
|
||||
case filesystem.FileFileType:
|
||||
return rpc.FileType_FILE_TYPE_FILE
|
||||
case filesystem.DirectoryFileType:
|
||||
return rpc.FileType_FILE_TYPE_DIRECTORY
|
||||
case filesystem.SymlinkFileType:
|
||||
return rpc.FileType_FILE_TYPE_SYMLINK
|
||||
default:
|
||||
return rpc.FileType_FILE_TYPE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
149
envd/internal/services/filesystem/utils_test.go
Normal file
149
envd/internal/services/filesystem/utils_test.go
Normal file
@ -0,0 +1,149 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
osuser "os/user"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fsmodel "git.omukk.dev/wrenn/sandbox/envd/internal/shared/filesystem"
|
||||
)
|
||||
|
||||
func TestIsPathOnNetworkMount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test with a regular directory (should not be on network mount)
|
||||
tempDir := t.TempDir()
|
||||
isNetwork, err := IsPathOnNetworkMount(tempDir)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, isNetwork, "temp directory should not be on a network mount")
|
||||
}
|
||||
|
||||
func TestIsPathOnNetworkMount_FuseMount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Require bindfs to be available
|
||||
_, err := exec.LookPath("bindfs")
|
||||
require.NoError(t, err, "bindfs must be installed for this test")
|
||||
|
||||
// Require fusermount to be available (needed for unmounting)
|
||||
_, err = exec.LookPath("fusermount")
|
||||
require.NoError(t, err, "fusermount must be installed for this test")
|
||||
|
||||
// Create source and mount directories
|
||||
sourceDir := t.TempDir()
|
||||
mountDir := t.TempDir()
|
||||
|
||||
// Mount sourceDir onto mountDir using bindfs (FUSE)
|
||||
ctx := context.Background()
|
||||
cmd := exec.CommandContext(ctx, "bindfs", sourceDir, mountDir)
|
||||
require.NoError(t, cmd.Run(), "failed to mount bindfs")
|
||||
|
||||
// Ensure we unmount on cleanup
|
||||
t.Cleanup(func() {
|
||||
_ = exec.CommandContext(context.Background(), "fusermount", "-u", mountDir).Run()
|
||||
})
|
||||
|
||||
// Test that the FUSE mount is detected
|
||||
isNetwork, err := IsPathOnNetworkMount(mountDir)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, isNetwork, "FUSE mount should be detected as network filesystem")
|
||||
|
||||
// Test that the source directory is NOT detected as network mount
|
||||
isNetworkSource, err := IsPathOnNetworkMount(sourceDir)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, isNetworkSource, "source directory should not be detected as network filesystem")
|
||||
}
|
||||
|
||||
func TestGetFileOwnership_CurrentUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("current user", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Get current user running the tests
|
||||
cur, err := osuser.Current()
|
||||
if err != nil {
|
||||
t.Skipf("unable to determine current user: %v", err)
|
||||
}
|
||||
|
||||
// Determine expected owner/group using the same lookup logic
|
||||
expectedOwner := cur.Uid
|
||||
if u, err := osuser.LookupId(cur.Uid); err == nil {
|
||||
expectedOwner = u.Username
|
||||
}
|
||||
|
||||
expectedGroup := cur.Gid
|
||||
if g, err := osuser.LookupGroupId(cur.Gid); err == nil {
|
||||
expectedGroup = g.Name
|
||||
}
|
||||
|
||||
// Parse UID/GID strings to uint32 for EntryInfo
|
||||
uid64, err := strconv.ParseUint(cur.Uid, 10, 32)
|
||||
require.NoError(t, err)
|
||||
gid64, err := strconv.ParseUint(cur.Gid, 10, 32)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build a minimal EntryInfo with current UID/GID
|
||||
info := fsmodel.EntryInfo{ // from shared pkg
|
||||
UID: uint32(uid64),
|
||||
GID: uint32(gid64),
|
||||
}
|
||||
|
||||
owner, group := getFileOwnership(info)
|
||||
assert.Equal(t, expectedOwner, owner)
|
||||
assert.Equal(t, expectedGroup, group)
|
||||
})
|
||||
|
||||
t.Run("no user", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Find a UID that does not exist on this system
|
||||
var unknownUIDStr string
|
||||
for i := 60001; i < 70000; i++ { // search a high range typically unused
|
||||
idStr := strconv.Itoa(i)
|
||||
if _, err := osuser.LookupId(idStr); err != nil {
|
||||
unknownUIDStr = idStr
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
if unknownUIDStr == "" {
|
||||
t.Skip("could not find a non-existent UID in the probed range")
|
||||
}
|
||||
|
||||
// Find a GID that does not exist on this system
|
||||
var unknownGIDStr string
|
||||
for i := 60001; i < 70000; i++ { // search a high range typically unused
|
||||
idStr := strconv.Itoa(i)
|
||||
if _, err := osuser.LookupGroupId(idStr); err != nil {
|
||||
unknownGIDStr = idStr
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
if unknownGIDStr == "" {
|
||||
t.Skip("could not find a non-existent GID in the probed range")
|
||||
}
|
||||
|
||||
// Parse to uint32 for EntryInfo construction
|
||||
uid64, err := strconv.ParseUint(unknownUIDStr, 10, 32)
|
||||
require.NoError(t, err)
|
||||
gid64, err := strconv.ParseUint(unknownGIDStr, 10, 32)
|
||||
require.NoError(t, err)
|
||||
|
||||
info := fsmodel.EntryInfo{
|
||||
UID: uint32(uid64),
|
||||
GID: uint32(gid64),
|
||||
}
|
||||
|
||||
owner, group := getFileOwnership(info)
|
||||
// Expect numeric fallbacks because lookups should fail for unknown IDs
|
||||
assert.Equal(t, unknownUIDStr, owner)
|
||||
assert.Equal(t, unknownGIDStr, group)
|
||||
})
|
||||
}
|
||||
159
envd/internal/services/filesystem/watch.go
Normal file
159
envd/internal/services/filesystem/watch.go
Normal file
@ -0,0 +1,159 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/e2b-dev/fsnotify"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
func (s Service) WatchDir(ctx context.Context, req *connect.Request[rpc.WatchDirRequest], stream *connect.ServerStream[rpc.WatchDirResponse]) error {
|
||||
return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.watchHandler)
|
||||
}
|
||||
|
||||
func (s Service) watchHandler(ctx context.Context, req *connect.Request[rpc.WatchDirRequest], stream *connect.ServerStream[rpc.WatchDirResponse]) error {
|
||||
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
watchPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(watchPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return connect.NewError(connect.CodeNotFound, fmt.Errorf("path %s not found: %w", watchPath, err))
|
||||
}
|
||||
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("error statting path %s: %w", watchPath, err))
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path %s not a directory: %w", watchPath, err))
|
||||
}
|
||||
|
||||
// Check if path is on a network filesystem mount
|
||||
isNetworkMount, err := IsPathOnNetworkMount(watchPath)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("error checking mount status: %w", err))
|
||||
}
|
||||
if isNetworkMount {
|
||||
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cannot watch path on network filesystem: %s", watchPath))
|
||||
}
|
||||
|
||||
w, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("error creating watcher: %w", err))
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
err = w.Add(utils.FsnotifyPath(watchPath, req.Msg.GetRecursive()))
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("error adding path %s to watcher: %w", watchPath, err))
|
||||
}
|
||||
|
||||
err = stream.Send(&rpc.WatchDirResponse{
|
||||
Event: &rpc.WatchDirResponse_Start{
|
||||
Start: &rpc.WatchDirResponse_StartEvent{},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", err))
|
||||
}
|
||||
|
||||
keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
|
||||
defer keepaliveTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-keepaliveTicker.C:
|
||||
streamErr := stream.Send(&rpc.WatchDirResponse{
|
||||
Event: &rpc.WatchDirResponse_Keepalive{
|
||||
Keepalive: &rpc.WatchDirResponse_KeepAlive{},
|
||||
},
|
||||
})
|
||||
if streamErr != nil {
|
||||
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr))
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case chErr, ok := <-w.Errors:
|
||||
if !ok {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error channel closed"))
|
||||
}
|
||||
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error: %w", chErr))
|
||||
case e, ok := <-w.Events:
|
||||
if !ok {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher event channel closed"))
|
||||
}
|
||||
|
||||
// One event can have multiple operations.
|
||||
ops := []rpc.EventType{}
|
||||
|
||||
if fsnotify.Create.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_CREATE)
|
||||
}
|
||||
|
||||
if fsnotify.Rename.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_RENAME)
|
||||
}
|
||||
|
||||
if fsnotify.Chmod.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_CHMOD)
|
||||
}
|
||||
|
||||
if fsnotify.Write.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_WRITE)
|
||||
}
|
||||
|
||||
if fsnotify.Remove.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_REMOVE)
|
||||
}
|
||||
|
||||
for _, op := range ops {
|
||||
name, nameErr := filepath.Rel(watchPath, e.Name)
|
||||
if nameErr != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("error getting relative path: %w", nameErr))
|
||||
}
|
||||
|
||||
filesystemEvent := &rpc.WatchDirResponse_Filesystem{
|
||||
Filesystem: &rpc.FilesystemEvent{
|
||||
Name: name,
|
||||
Type: op,
|
||||
},
|
||||
}
|
||||
|
||||
event := &rpc.WatchDirResponse{
|
||||
Event: filesystemEvent,
|
||||
}
|
||||
|
||||
streamErr := stream.Send(event)
|
||||
|
||||
s.logger.
|
||||
Debug().
|
||||
Str("event_type", "filesystem_event").
|
||||
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
|
||||
Interface("filesystem_event", event).
|
||||
Msg("Streaming filesystem event")
|
||||
|
||||
if streamErr != nil {
|
||||
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending filesystem event: %w", streamErr))
|
||||
}
|
||||
|
||||
resetKeepalive()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
224
envd/internal/services/filesystem/watch_sync.go
Normal file
224
envd/internal/services/filesystem/watch_sync.go
Normal file
@ -0,0 +1,224 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/e2b-dev/fsnotify"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/id"
|
||||
)
|
||||
|
||||
type FileWatcher struct {
|
||||
watcher *fsnotify.Watcher
|
||||
Events []*rpc.FilesystemEvent
|
||||
cancel func()
|
||||
Error error
|
||||
|
||||
Lock sync.Mutex
|
||||
}
|
||||
|
||||
func CreateFileWatcher(ctx context.Context, watchPath string, recursive bool, operationID string, logger *zerolog.Logger) (*FileWatcher, error) {
|
||||
w, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating watcher: %w", err))
|
||||
}
|
||||
|
||||
// We don't want to cancel the context when the request is finished
|
||||
ctx, cancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||
|
||||
err = w.Add(utils.FsnotifyPath(watchPath, recursive))
|
||||
if err != nil {
|
||||
_ = w.Close()
|
||||
cancel()
|
||||
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error adding path %s to watcher: %w", watchPath, err))
|
||||
}
|
||||
fw := &FileWatcher{
|
||||
watcher: w,
|
||||
cancel: cancel,
|
||||
Events: []*rpc.FilesystemEvent{},
|
||||
Error: nil,
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case chErr, ok := <-w.Errors:
|
||||
if !ok {
|
||||
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error channel closed"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error: %w", chErr))
|
||||
|
||||
return
|
||||
case e, ok := <-w.Events:
|
||||
if !ok {
|
||||
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher event channel closed"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// One event can have multiple operations.
|
||||
ops := []rpc.EventType{}
|
||||
|
||||
if fsnotify.Create.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_CREATE)
|
||||
}
|
||||
|
||||
if fsnotify.Rename.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_RENAME)
|
||||
}
|
||||
|
||||
if fsnotify.Chmod.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_CHMOD)
|
||||
}
|
||||
|
||||
if fsnotify.Write.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_WRITE)
|
||||
}
|
||||
|
||||
if fsnotify.Remove.Has(e.Op) {
|
||||
ops = append(ops, rpc.EventType_EVENT_TYPE_REMOVE)
|
||||
}
|
||||
|
||||
for _, op := range ops {
|
||||
name, nameErr := filepath.Rel(watchPath, e.Name)
|
||||
if nameErr != nil {
|
||||
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("error getting relative path: %w", nameErr))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
fw.Lock.Lock()
|
||||
fw.Events = append(fw.Events, &rpc.FilesystemEvent{
|
||||
Name: name,
|
||||
Type: op,
|
||||
})
|
||||
fw.Lock.Unlock()
|
||||
|
||||
// these are only used for logging
|
||||
filesystemEvent := &rpc.WatchDirResponse_Filesystem{
|
||||
Filesystem: &rpc.FilesystemEvent{
|
||||
Name: name,
|
||||
Type: op,
|
||||
},
|
||||
}
|
||||
event := &rpc.WatchDirResponse{
|
||||
Event: filesystemEvent,
|
||||
}
|
||||
|
||||
logger.
|
||||
Debug().
|
||||
Str("event_type", "filesystem_event").
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Interface("filesystem_event", event).
|
||||
Msg("Streaming filesystem event")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return fw, nil
|
||||
}
|
||||
|
||||
func (fw *FileWatcher) Close() {
|
||||
_ = fw.watcher.Close()
|
||||
fw.cancel()
|
||||
}
|
||||
|
||||
func (s Service) CreateWatcher(ctx context.Context, req *connect.Request[rpc.CreateWatcherRequest]) (*connect.Response[rpc.CreateWatcherResponse], error) {
|
||||
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
watchPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(watchPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("path %s not found: %w", watchPath, err))
|
||||
}
|
||||
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error statting path %s: %w", watchPath, err))
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path %s not a directory: %w", watchPath, err))
|
||||
}
|
||||
|
||||
// Check if path is on a network filesystem mount
|
||||
isNetworkMount, err := IsPathOnNetworkMount(watchPath)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error checking mount status: %w", err))
|
||||
}
|
||||
if isNetworkMount {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cannot watch path on network filesystem: %s", watchPath))
|
||||
}
|
||||
|
||||
watcherId := "w" + id.Generate()
|
||||
|
||||
w, err := CreateFileWatcher(ctx, watchPath, req.Msg.GetRecursive(), watcherId, s.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.watchers.Store(watcherId, w)
|
||||
|
||||
return connect.NewResponse(&rpc.CreateWatcherResponse{
|
||||
WatcherId: watcherId,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s Service) GetWatcherEvents(_ context.Context, req *connect.Request[rpc.GetWatcherEventsRequest]) (*connect.Response[rpc.GetWatcherEventsResponse], error) {
|
||||
watcherId := req.Msg.GetWatcherId()
|
||||
|
||||
w, ok := s.watchers.Load(watcherId)
|
||||
if !ok {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("watcher with id %s not found", watcherId))
|
||||
}
|
||||
|
||||
if w.Error != nil {
|
||||
return nil, w.Error
|
||||
}
|
||||
|
||||
w.Lock.Lock()
|
||||
defer w.Lock.Unlock()
|
||||
events := w.Events
|
||||
w.Events = []*rpc.FilesystemEvent{}
|
||||
|
||||
return connect.NewResponse(&rpc.GetWatcherEventsResponse{
|
||||
Events: events,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s Service) RemoveWatcher(_ context.Context, req *connect.Request[rpc.RemoveWatcherRequest]) (*connect.Response[rpc.RemoveWatcherResponse], error) {
|
||||
watcherId := req.Msg.GetWatcherId()
|
||||
|
||||
w, ok := s.watchers.Load(watcherId)
|
||||
if !ok {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("watcher with id %s not found", watcherId))
|
||||
}
|
||||
|
||||
w.Close()
|
||||
s.watchers.Delete(watcherId)
|
||||
|
||||
return connect.NewResponse(&rpc.RemoveWatcherResponse{}), nil
|
||||
}
|
||||
126
envd/internal/services/process/connect.go
Normal file
126
envd/internal/services/process/connect.go
Normal file
@ -0,0 +1,126 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func (s *Service) Connect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
|
||||
return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleConnect)
|
||||
}
|
||||
|
||||
func (s *Service) handleConnect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
proc, err := s.getProcess(req.Msg.GetProcess())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
data, dataCancel := proc.DataEvent.Fork()
|
||||
defer dataCancel()
|
||||
|
||||
end, endCancel := proc.EndEvent.Fork()
|
||||
defer endCancel()
|
||||
|
||||
streamErr := stream.Send(&rpc.ConnectResponse{
|
||||
Event: &rpc.ProcessEvent{
|
||||
Event: &rpc.ProcessEvent_Start{
|
||||
Start: &rpc.ProcessEvent_StartEvent{
|
||||
Pid: proc.Pid(),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if streamErr != nil {
|
||||
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr))
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
|
||||
keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
|
||||
defer keepaliveTicker.Stop()
|
||||
|
||||
dataLoop:
|
||||
for {
|
||||
select {
|
||||
case <-keepaliveTicker.C:
|
||||
streamErr := stream.Send(&rpc.ConnectResponse{
|
||||
Event: &rpc.ProcessEvent{
|
||||
Event: &rpc.ProcessEvent_Keepalive{
|
||||
Keepalive: &rpc.ProcessEvent_KeepAlive{},
|
||||
},
|
||||
},
|
||||
})
|
||||
if streamErr != nil {
|
||||
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))
|
||||
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
cancel(ctx.Err())
|
||||
|
||||
return
|
||||
case event, ok := <-data:
|
||||
if !ok {
|
||||
break dataLoop
|
||||
}
|
||||
|
||||
streamErr := stream.Send(&rpc.ConnectResponse{
|
||||
Event: &rpc.ProcessEvent{
|
||||
Event: &event,
|
||||
},
|
||||
})
|
||||
if streamErr != nil {
|
||||
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resetKeepalive()
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancel(ctx.Err())
|
||||
|
||||
return
|
||||
case event, ok := <-end:
|
||||
if !ok {
|
||||
cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
streamErr := stream.Send(&rpc.ConnectResponse{
|
||||
Event: &rpc.ProcessEvent{
|
||||
Event: &event,
|
||||
},
|
||||
})
|
||||
if streamErr != nil {
|
||||
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-exitChan:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
478
envd/internal/services/process/handler/handler.go
Normal file
478
envd/internal/services/process/handler/handler.go
Normal file
@ -0,0 +1,478 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/creack/pty"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultNice = 0
|
||||
defaultOomScore = 100
|
||||
outputBufferSize = 64
|
||||
stdChunkSize = 2 << 14
|
||||
ptyChunkSize = 2 << 13
|
||||
)
|
||||
|
||||
type ProcessExit struct {
|
||||
Error *string
|
||||
Status string
|
||||
Exited bool
|
||||
Code int32
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
Config *rpc.ProcessConfig
|
||||
|
||||
logger *zerolog.Logger
|
||||
|
||||
Tag *string
|
||||
cmd *exec.Cmd
|
||||
tty *os.File
|
||||
|
||||
cancel context.CancelFunc
|
||||
|
||||
outCtx context.Context //nolint:containedctx // todo: refactor so this can be removed
|
||||
outCancel context.CancelFunc
|
||||
|
||||
stdinMu sync.Mutex
|
||||
stdin io.WriteCloser
|
||||
|
||||
DataEvent *MultiplexedChannel[rpc.ProcessEvent_Data]
|
||||
EndEvent *MultiplexedChannel[rpc.ProcessEvent_End]
|
||||
}
|
||||
|
||||
// This method must be called only after the process has been started
|
||||
func (p *Handler) Pid() uint32 {
|
||||
return uint32(p.cmd.Process.Pid)
|
||||
}
|
||||
|
||||
// userCommand returns a human-readable representation of the user's original command,
|
||||
// without the internal OOM/nice wrapper that is prepended to the actual exec.
|
||||
func (p *Handler) userCommand() string {
|
||||
return strings.Join(append([]string{p.Config.GetCmd()}, p.Config.GetArgs()...), " ")
|
||||
}
|
||||
|
||||
// currentNice returns the nice value of the current process.
|
||||
func currentNice() int {
|
||||
prio, err := syscall.Getpriority(syscall.PRIO_PROCESS, 0)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getpriority returns 20 - nice on Linux.
|
||||
return 20 - prio
|
||||
}
|
||||
|
||||
func New(
|
||||
ctx context.Context,
|
||||
user *user.User,
|
||||
req *rpc.StartRequest,
|
||||
logger *zerolog.Logger,
|
||||
defaults *execcontext.Defaults,
|
||||
cgroupManager cgroups.Manager,
|
||||
cancel context.CancelFunc,
|
||||
) (*Handler, error) {
|
||||
// User command string for logging (without the internal wrapper details).
|
||||
userCmd := strings.Join(append([]string{req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...), " ")
|
||||
|
||||
// Wrap the command in a shell that sets the OOM score and nice value before exec-ing the actual command.
|
||||
// This eliminates the race window where grandchildren could inherit the parent's protected OOM score (-1000)
|
||||
// or high CPU priority (nice -20) before the post-start calls had a chance to correct them.
|
||||
// nice(1) applies a relative adjustment, so we compute the delta from the current (inherited) nice to the target.
|
||||
niceDelta := defaultNice - currentNice()
|
||||
oomWrapperScript := fmt.Sprintf(`echo %d > /proc/$$/oom_score_adj && exec /usr/bin/nice -n %d "${@}"`, defaultOomScore, niceDelta)
|
||||
wrapperArgs := append([]string{"-c", oomWrapperScript, "--", req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...)
|
||||
cmd := exec.CommandContext(ctx, "/bin/sh", wrapperArgs...)
|
||||
|
||||
uid, gid, err := permissions.GetUserIdUints(user)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
|
||||
groups := []uint32{gid}
|
||||
if gids, err := user.GroupIds(); err != nil {
|
||||
logger.Warn().Err(err).Str("user", user.Username).Msg("failed to get supplementary groups")
|
||||
} else {
|
||||
for _, g := range gids {
|
||||
if parsed, err := strconv.ParseUint(g, 10, 32); err == nil {
|
||||
groups = append(groups, uint32(parsed))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cgroupFD, ok := cgroupManager.GetFileDescriptor(getProcType(req))
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
UseCgroupFD: ok,
|
||||
CgroupFD: cgroupFD,
|
||||
Credential: &syscall.Credential{
|
||||
Uid: uid,
|
||||
Gid: gid,
|
||||
Groups: groups,
|
||||
},
|
||||
}
|
||||
|
||||
resolvedPath, err := permissions.ExpandAndResolve(req.GetProcess().GetCwd(), user, defaults.Workdir)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
// Check if the cwd resolved path exists
|
||||
if _, err := os.Stat(resolvedPath); errors.Is(err, os.ErrNotExist) {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cwd '%s' does not exist", resolvedPath))
|
||||
}
|
||||
|
||||
cmd.Dir = resolvedPath
|
||||
|
||||
var formattedVars []string
|
||||
|
||||
// Take only 'PATH' variable from the current environment
|
||||
// The 'PATH' should ideally be set in the environment
|
||||
formattedVars = append(formattedVars, "PATH="+os.Getenv("PATH"))
|
||||
formattedVars = append(formattedVars, "HOME="+user.HomeDir)
|
||||
formattedVars = append(formattedVars, "USER="+user.Username)
|
||||
formattedVars = append(formattedVars, "LOGNAME="+user.Username)
|
||||
|
||||
// Add the environment variables from the global environment
|
||||
if defaults.EnvVars != nil {
|
||||
defaults.EnvVars.Range(func(key string, value string) bool {
|
||||
formattedVars = append(formattedVars, key+"="+value)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Only the last values of the env vars are used - this allows for overwriting defaults
|
||||
for key, value := range req.GetProcess().GetEnvs() {
|
||||
formattedVars = append(formattedVars, key+"="+value)
|
||||
}
|
||||
|
||||
cmd.Env = formattedVars
|
||||
|
||||
outMultiplex := NewMultiplexedChannel[rpc.ProcessEvent_Data](outputBufferSize)
|
||||
|
||||
var outWg sync.WaitGroup
|
||||
|
||||
// Create a context for waiting for and cancelling output pipes.
|
||||
// Cancellation of the process via timeout will propagate and cancel this context too.
|
||||
outCtx, outCancel := context.WithCancel(ctx)
|
||||
|
||||
h := &Handler{
|
||||
Config: req.GetProcess(),
|
||||
cmd: cmd,
|
||||
Tag: req.Tag,
|
||||
DataEvent: outMultiplex,
|
||||
cancel: cancel,
|
||||
outCtx: outCtx,
|
||||
outCancel: outCancel,
|
||||
EndEvent: NewMultiplexedChannel[rpc.ProcessEvent_End](0),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if req.GetPty() != nil {
|
||||
// The pty should ideally start only in the Start method, but the package does not support that and we would have to code it manually.
|
||||
// The output of the pty should correctly be passed though.
|
||||
tty, err := pty.StartWithSize(cmd, &pty.Winsize{
|
||||
Cols: uint16(req.GetPty().GetSize().GetCols()),
|
||||
Rows: uint16(req.GetPty().GetSize().GetRows()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error starting pty with command '%s' in dir '%s' with '%d' cols and '%d' rows: %w", userCmd, cmd.Dir, req.GetPty().GetSize().GetCols(), req.GetPty().GetSize().GetRows(), err))
|
||||
}
|
||||
|
||||
outWg.Go(func() {
|
||||
for {
|
||||
buf := make([]byte, ptyChunkSize)
|
||||
|
||||
n, readErr := tty.Read(buf)
|
||||
|
||||
if n > 0 {
|
||||
outMultiplex.Source <- rpc.ProcessEvent_Data{
|
||||
Data: &rpc.ProcessEvent_DataEvent{
|
||||
Output: &rpc.ProcessEvent_DataEvent_Pty{
|
||||
Pty: buf[:n],
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "error reading from pty: %s\n", readErr)
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
h.tty = tty
|
||||
} else {
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdout pipe for command '%s': %w", userCmd, err))
|
||||
}
|
||||
|
||||
outWg.Go(func() {
|
||||
stdoutLogs := make(chan []byte, outputBufferSize)
|
||||
defer close(stdoutLogs)
|
||||
|
||||
stdoutLogger := logger.With().Str("event_type", "stdout").Logger()
|
||||
|
||||
go logs.LogBufferedDataEvents(stdoutLogs, &stdoutLogger, "data")
|
||||
|
||||
for {
|
||||
buf := make([]byte, stdChunkSize)
|
||||
|
||||
n, readErr := stdout.Read(buf)
|
||||
|
||||
if n > 0 {
|
||||
outMultiplex.Source <- rpc.ProcessEvent_Data{
|
||||
Data: &rpc.ProcessEvent_DataEvent{
|
||||
Output: &rpc.ProcessEvent_DataEvent_Stdout{
|
||||
Stdout: buf[:n],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stdoutLogs <- buf[:n]
|
||||
}
|
||||
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "error reading from stdout: %s\n", readErr)
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stderr pipe for command '%s': %w", userCmd, err))
|
||||
}
|
||||
|
||||
outWg.Go(func() {
|
||||
stderrLogs := make(chan []byte, outputBufferSize)
|
||||
defer close(stderrLogs)
|
||||
|
||||
stderrLogger := logger.With().Str("event_type", "stderr").Logger()
|
||||
|
||||
go logs.LogBufferedDataEvents(stderrLogs, &stderrLogger, "data")
|
||||
|
||||
for {
|
||||
buf := make([]byte, stdChunkSize)
|
||||
|
||||
n, readErr := stderr.Read(buf)
|
||||
|
||||
if n > 0 {
|
||||
outMultiplex.Source <- rpc.ProcessEvent_Data{
|
||||
Data: &rpc.ProcessEvent_DataEvent{
|
||||
Output: &rpc.ProcessEvent_DataEvent_Stderr{
|
||||
Stderr: buf[:n],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stderrLogs <- buf[:n]
|
||||
}
|
||||
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "error reading from stderr: %s\n", readErr)
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// For backwards compatibility we still set the stdin if not explicitly disabled
|
||||
// If stdin is disabled, the process will use /dev/null as stdin
|
||||
if req.Stdin == nil || req.GetStdin() == true {
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdin pipe for command '%s': %w", userCmd, err))
|
||||
}
|
||||
|
||||
h.stdin = stdin
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
outWg.Wait()
|
||||
|
||||
close(outMultiplex.Source)
|
||||
|
||||
outCancel()
|
||||
}()
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func getProcType(req *rpc.StartRequest) cgroups.ProcessType {
|
||||
if req != nil && req.GetPty() != nil {
|
||||
return cgroups.ProcessTypePTY
|
||||
}
|
||||
|
||||
return cgroups.ProcessTypeUser
|
||||
}
|
||||
|
||||
func (p *Handler) SendSignal(signal syscall.Signal) error {
|
||||
if p.cmd.Process == nil {
|
||||
return fmt.Errorf("process not started")
|
||||
}
|
||||
|
||||
if signal == syscall.SIGKILL || signal == syscall.SIGTERM {
|
||||
p.outCancel()
|
||||
}
|
||||
|
||||
return p.cmd.Process.Signal(signal)
|
||||
}
|
||||
|
||||
func (p *Handler) ResizeTty(size *pty.Winsize) error {
|
||||
if p.tty == nil {
|
||||
return fmt.Errorf("tty not assigned to process")
|
||||
}
|
||||
|
||||
return pty.Setsize(p.tty, size)
|
||||
}
|
||||
|
||||
func (p *Handler) WriteStdin(data []byte) error {
|
||||
if p.tty != nil {
|
||||
return fmt.Errorf("tty assigned to process — input should be written to the pty, not the stdin")
|
||||
}
|
||||
|
||||
p.stdinMu.Lock()
|
||||
defer p.stdinMu.Unlock()
|
||||
|
||||
if p.stdin == nil {
|
||||
return fmt.Errorf("stdin not enabled or closed")
|
||||
}
|
||||
|
||||
_, err := p.stdin.Write(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing to stdin of process '%d': %w", p.cmd.Process.Pid, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseStdin closes the stdin pipe to signal EOF to the process.
|
||||
// Only works for non-PTY processes.
|
||||
func (p *Handler) CloseStdin() error {
|
||||
if p.tty != nil {
|
||||
return fmt.Errorf("cannot close stdin for PTY process — send Ctrl+D (0x04) instead")
|
||||
}
|
||||
|
||||
p.stdinMu.Lock()
|
||||
defer p.stdinMu.Unlock()
|
||||
|
||||
if p.stdin == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.stdin.Close()
|
||||
// We still set the stdin to nil even on error as there are no errors,
|
||||
// for which it is really safe to retry close across all distributions.
|
||||
p.stdin = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Handler) WriteTty(data []byte) error {
|
||||
if p.tty == nil {
|
||||
return fmt.Errorf("tty not assigned to process — input should be written to the stdin, not the tty")
|
||||
}
|
||||
|
||||
_, err := p.tty.Write(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing to tty of process '%d': %w", p.cmd.Process.Pid, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Handler) Start() (uint32, error) {
|
||||
// Pty is already started in the New method
|
||||
if p.tty == nil {
|
||||
err := p.cmd.Start()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error starting process '%s': %w", p.userCommand(), err)
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.
|
||||
Info().
|
||||
Str("event_type", "process_start").
|
||||
Int("pid", p.cmd.Process.Pid).
|
||||
Str("command", p.userCommand()).
|
||||
Msg(fmt.Sprintf("Process with pid %d started", p.cmd.Process.Pid))
|
||||
|
||||
return uint32(p.cmd.Process.Pid), nil
|
||||
}
|
||||
|
||||
func (p *Handler) Wait() {
|
||||
// Wait for the output pipes to be closed or cancelled.
|
||||
<-p.outCtx.Done()
|
||||
|
||||
err := p.cmd.Wait()
|
||||
|
||||
p.tty.Close()
|
||||
|
||||
var errMsg *string
|
||||
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
errMsg = &msg
|
||||
}
|
||||
|
||||
endEvent := &rpc.ProcessEvent_EndEvent{
|
||||
Error: errMsg,
|
||||
ExitCode: int32(p.cmd.ProcessState.ExitCode()),
|
||||
Exited: p.cmd.ProcessState.Exited(),
|
||||
Status: p.cmd.ProcessState.String(),
|
||||
}
|
||||
|
||||
event := rpc.ProcessEvent_End{
|
||||
End: endEvent,
|
||||
}
|
||||
|
||||
p.EndEvent.Source <- event
|
||||
|
||||
p.logger.
|
||||
Info().
|
||||
Str("event_type", "process_end").
|
||||
Interface("process_result", endEvent).
|
||||
Msg(fmt.Sprintf("Process with pid %d ended", p.cmd.Process.Pid))
|
||||
|
||||
// Ensure the process cancel is called to cleanup resources.
|
||||
// As it is called after end event and Wait, it should not affect command execution or returned events.
|
||||
p.cancel()
|
||||
}
|
||||
73
envd/internal/services/process/handler/multiplex.go
Normal file
73
envd/internal/services/process/handler/multiplex.go
Normal file
@ -0,0 +1,73 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type MultiplexedChannel[T any] struct {
|
||||
Source chan T
|
||||
channels []chan T
|
||||
mu sync.RWMutex
|
||||
exited atomic.Bool
|
||||
}
|
||||
|
||||
func NewMultiplexedChannel[T any](buffer int) *MultiplexedChannel[T] {
|
||||
c := &MultiplexedChannel[T]{
|
||||
channels: nil,
|
||||
Source: make(chan T, buffer),
|
||||
}
|
||||
|
||||
go func() {
|
||||
for v := range c.Source {
|
||||
c.mu.RLock()
|
||||
|
||||
for _, cons := range c.channels {
|
||||
cons <- v
|
||||
}
|
||||
|
||||
c.mu.RUnlock()
|
||||
}
|
||||
|
||||
c.exited.Store(true)
|
||||
|
||||
for _, cons := range c.channels {
|
||||
close(cons)
|
||||
}
|
||||
}()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (m *MultiplexedChannel[T]) Fork() (chan T, func()) {
|
||||
if m.exited.Load() {
|
||||
ch := make(chan T)
|
||||
close(ch)
|
||||
|
||||
return ch, func() {}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
consumer := make(chan T)
|
||||
|
||||
m.channels = append(m.channels, consumer)
|
||||
|
||||
return consumer, func() {
|
||||
m.remove(consumer)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MultiplexedChannel[T]) remove(consumer chan T) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for i, ch := range m.channels {
|
||||
if ch == consumer {
|
||||
m.channels = append(m.channels[:i], m.channels[i+1:]...)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
107
envd/internal/services/process/input.go
Normal file
107
envd/internal/services/process/input.go
Normal file
@ -0,0 +1,107 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func handleInput(ctx context.Context, process *handler.Handler, in *rpc.ProcessInput, logger *zerolog.Logger) error {
|
||||
switch in.GetInput().(type) {
|
||||
case *rpc.ProcessInput_Pty:
|
||||
err := process.WriteTty(in.GetPty())
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to tty: %w", err))
|
||||
}
|
||||
|
||||
case *rpc.ProcessInput_Stdin:
|
||||
err := process.WriteStdin(in.GetStdin())
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to stdin: %w", err))
|
||||
}
|
||||
|
||||
logger.Debug().
|
||||
Str("event_type", "stdin").
|
||||
Interface("stdin", in.GetStdin()).
|
||||
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
|
||||
Msg("Streaming input to process")
|
||||
|
||||
default:
|
||||
return connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", in.GetInput()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) SendInput(ctx context.Context, req *connect.Request[rpc.SendInputRequest]) (*connect.Response[rpc.SendInputResponse], error) {
|
||||
proc, err := s.getProcess(req.Msg.GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = handleInput(ctx, proc, req.Msg.GetInput(), s.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.SendInputResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Service) StreamInput(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
|
||||
return logs.LogClientStreamWithoutEvents(ctx, s.logger, stream, s.streamInputHandler)
|
||||
}
|
||||
|
||||
func (s *Service) streamInputHandler(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
|
||||
var proc *handler.Handler
|
||||
|
||||
for stream.Receive() {
|
||||
req := stream.Msg()
|
||||
|
||||
switch req.GetEvent().(type) {
|
||||
case *rpc.StreamInputRequest_Start:
|
||||
p, err := s.getProcess(req.GetStart().GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proc = p
|
||||
case *rpc.StreamInputRequest_Data:
|
||||
err := handleInput(ctx, proc, req.GetData().GetInput(), s.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case *rpc.StreamInputRequest_Keepalive:
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid event type %T", req.GetEvent()))
|
||||
}
|
||||
}
|
||||
|
||||
err := stream.Err()
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error streaming input: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.StreamInputResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Service) CloseStdin(
|
||||
_ context.Context,
|
||||
req *connect.Request[rpc.CloseStdinRequest],
|
||||
) (*connect.Response[rpc.CloseStdinResponse], error) {
|
||||
handler, err := s.getProcess(req.Msg.GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := handler.CloseStdin(); err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error closing stdin: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.CloseStdinResponse{}), nil
|
||||
}
|
||||
28
envd/internal/services/process/list.go
Normal file
28
envd/internal/services/process/list.go
Normal file
@ -0,0 +1,28 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func (s *Service) List(context.Context, *connect.Request[rpc.ListRequest]) (*connect.Response[rpc.ListResponse], error) {
|
||||
processes := make([]*rpc.ProcessInfo, 0)
|
||||
|
||||
s.processes.Range(func(pid uint32, value *handler.Handler) bool {
|
||||
processes = append(processes, &rpc.ProcessInfo{
|
||||
Pid: pid,
|
||||
Tag: value.Tag,
|
||||
Config: value.Config,
|
||||
})
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return connect.NewResponse(&rpc.ListResponse{
|
||||
Processes: processes,
|
||||
}), nil
|
||||
}
|
||||
84
envd/internal/services/process/service.go
Normal file
84
envd/internal/services/process/service.go
Normal file
@ -0,0 +1,84 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process/processconnect"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
processes *utils.Map[uint32, *handler.Handler]
|
||||
logger *zerolog.Logger
|
||||
defaults *execcontext.Defaults
|
||||
cgroupManager cgroups.Manager
|
||||
}
|
||||
|
||||
func newService(l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
|
||||
return &Service{
|
||||
logger: l,
|
||||
processes: utils.NewMap[uint32, *handler.Handler](),
|
||||
defaults: defaults,
|
||||
cgroupManager: cgroupManager,
|
||||
}
|
||||
}
|
||||
|
||||
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
|
||||
service := newService(l, defaults, cgroupManager)
|
||||
|
||||
interceptors := connect.WithInterceptors(logs.NewUnaryLogInterceptor(l))
|
||||
|
||||
path, h := spec.NewProcessHandler(service, interceptors)
|
||||
|
||||
server.Mount(path, h)
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *Service) getProcess(selector *rpc.ProcessSelector) (*handler.Handler, error) {
|
||||
var proc *handler.Handler
|
||||
|
||||
switch selector.GetSelector().(type) {
|
||||
case *rpc.ProcessSelector_Pid:
|
||||
p, ok := s.processes.Load(selector.GetPid())
|
||||
if !ok {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with pid %d not found", selector.GetPid()))
|
||||
}
|
||||
|
||||
proc = p
|
||||
case *rpc.ProcessSelector_Tag:
|
||||
tag := selector.GetTag()
|
||||
|
||||
s.processes.Range(func(_ uint32, value *handler.Handler) bool {
|
||||
if value.Tag == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if *value.Tag == tag {
|
||||
proc = value
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
if proc == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with tag %s not found", tag))
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", selector))
|
||||
}
|
||||
|
||||
return proc, nil
|
||||
}
|
||||
38
envd/internal/services/process/signal.go
Normal file
38
envd/internal/services/process/signal.go
Normal file
@ -0,0 +1,38 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func (s *Service) SendSignal(
|
||||
_ context.Context,
|
||||
req *connect.Request[rpc.SendSignalRequest],
|
||||
) (*connect.Response[rpc.SendSignalResponse], error) {
|
||||
handler, err := s.getProcess(req.Msg.GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var signal syscall.Signal
|
||||
switch req.Msg.GetSignal() {
|
||||
case rpc.Signal_SIGNAL_SIGKILL:
|
||||
signal = syscall.SIGKILL
|
||||
case rpc.Signal_SIGNAL_SIGTERM:
|
||||
signal = syscall.SIGTERM
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid signal: %s", req.Msg.GetSignal()))
|
||||
}
|
||||
|
||||
err = handler.SendSignal(signal)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error sending signal: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.SendSignalResponse{}), nil
|
||||
}
|
||||
247
envd/internal/services/process/start.go
Normal file
247
envd/internal/services/process/start.go
Normal file
@ -0,0 +1,247 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func (s *Service) InitializeStartProcess(ctx context.Context, user *user.User, req *rpc.StartRequest) error {
|
||||
var err error
|
||||
|
||||
ctx = logs.AddRequestIDToContext(ctx)
|
||||
|
||||
defer s.logger.
|
||||
Err(err).
|
||||
Interface("request", req).
|
||||
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
|
||||
Msg("Initialized startCmd")
|
||||
|
||||
handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
|
||||
|
||||
startProcCtx, startProcCancel := context.WithCancel(ctx)
|
||||
proc, err := handler.New(startProcCtx, user, req, &handlerL, s.defaults, s.cgroupManager, startProcCancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pid, err := proc.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.processes.Store(pid, proc)
|
||||
|
||||
go func() {
|
||||
defer s.processes.Delete(pid)
|
||||
|
||||
proc.Wait()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Start(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
|
||||
return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleStart)
|
||||
}
|
||||
|
||||
func (s *Service) handleStart(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
|
||||
|
||||
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeout, err := determineTimeoutFromHeader(stream.Conn().RequestHeader())
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
// Create a new context with a timeout if provided.
|
||||
// We do not want the command to be killed if the request context is cancelled
|
||||
procCtx, cancelProc := context.Background(), func() {}
|
||||
if timeout > 0 { // zero timeout means no timeout
|
||||
procCtx, cancelProc = context.WithTimeout(procCtx, timeout)
|
||||
}
|
||||
|
||||
proc, err := handler.New( //nolint:contextcheck // TODO: fix this later
|
||||
procCtx,
|
||||
u,
|
||||
req.Msg,
|
||||
&handlerL,
|
||||
s.defaults,
|
||||
s.cgroupManager,
|
||||
cancelProc,
|
||||
)
|
||||
if err != nil {
|
||||
// Ensure the process cancel is called to cleanup resources.
|
||||
cancelProc()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
startMultiplexer := handler.NewMultiplexedChannel[rpc.ProcessEvent_Start](0)
|
||||
defer close(startMultiplexer.Source)
|
||||
|
||||
start, startCancel := startMultiplexer.Fork()
|
||||
defer startCancel()
|
||||
|
||||
data, dataCancel := proc.DataEvent.Fork()
|
||||
defer dataCancel()
|
||||
|
||||
end, endCancel := proc.EndEvent.Fork()
|
||||
defer endCancel()
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancel(ctx.Err())
|
||||
|
||||
return
|
||||
case event, ok := <-start:
|
||||
if !ok {
|
||||
cancel(connect.NewError(connect.CodeUnknown, errors.New("start event channel closed before sending start event")))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
streamErr := stream.Send(&rpc.StartResponse{
|
||||
Event: &rpc.ProcessEvent{
|
||||
Event: &event,
|
||||
},
|
||||
})
|
||||
if streamErr != nil {
|
||||
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr)))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
|
||||
defer keepaliveTicker.Stop()
|
||||
|
||||
dataLoop:
|
||||
for {
|
||||
select {
|
||||
case <-keepaliveTicker.C:
|
||||
streamErr := stream.Send(&rpc.StartResponse{
|
||||
Event: &rpc.ProcessEvent{
|
||||
Event: &rpc.ProcessEvent_Keepalive{
|
||||
Keepalive: &rpc.ProcessEvent_KeepAlive{},
|
||||
},
|
||||
},
|
||||
})
|
||||
if streamErr != nil {
|
||||
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))
|
||||
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
cancel(ctx.Err())
|
||||
|
||||
return
|
||||
case event, ok := <-data:
|
||||
if !ok {
|
||||
break dataLoop
|
||||
}
|
||||
|
||||
streamErr := stream.Send(&rpc.StartResponse{
|
||||
Event: &rpc.ProcessEvent{
|
||||
Event: &event,
|
||||
},
|
||||
})
|
||||
if streamErr != nil {
|
||||
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resetKeepalive()
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancel(ctx.Err())
|
||||
|
||||
return
|
||||
case event, ok := <-end:
|
||||
if !ok {
|
||||
cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
streamErr := stream.Send(&rpc.StartResponse{
|
||||
Event: &rpc.ProcessEvent{
|
||||
Event: &event,
|
||||
},
|
||||
})
|
||||
if streamErr != nil {
|
||||
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
pid, err := proc.Start()
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
s.processes.Store(pid, proc)
|
||||
|
||||
start <- rpc.ProcessEvent_Start{
|
||||
Start: &rpc.ProcessEvent_StartEvent{
|
||||
Pid: pid,
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.processes.Delete(pid)
|
||||
|
||||
proc.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-exitChan:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func determineTimeoutFromHeader(header http.Header) (time.Duration, error) {
|
||||
timeoutHeader := header.Get("Connect-Timeout-Ms")
|
||||
|
||||
if timeoutHeader == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
timeout, err := strconv.Atoi(timeoutHeader)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return time.Duration(timeout) * time.Millisecond, nil
|
||||
}
|
||||
30
envd/internal/services/process/update.go
Normal file
30
envd/internal/services/process/update.go
Normal file
@ -0,0 +1,30 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/creack/pty"
|
||||
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func (s *Service) Update(_ context.Context, req *connect.Request[rpc.UpdateRequest]) (*connect.Response[rpc.UpdateResponse], error) {
|
||||
proc, err := s.getProcess(req.Msg.GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if req.Msg.GetPty() != nil {
|
||||
err := proc.ResizeTty(&pty.Winsize{
|
||||
Rows: uint16(req.Msg.GetPty().GetSize().GetRows()),
|
||||
Cols: uint16(req.Msg.GetPty().GetSize().GetCols()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error resizing tty: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.UpdateResponse{}), nil
|
||||
}
|
||||
1444
envd/internal/services/spec/filesystem/filesystem.pb.go
Normal file
1444
envd/internal/services/spec/filesystem/filesystem.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,337 @@
|
||||
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
|
||||
//
|
||||
// Source: filesystem/filesystem.proto
|
||||
|
||||
package filesystemconnect
|
||||
|
||||
import (
|
||||
connect "connectrpc.com/connect"
|
||||
context "context"
|
||||
errors "errors"
|
||||
filesystem "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||
http "net/http"
|
||||
strings "strings"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file and the connect package are
|
||||
// compatible. If you get a compiler error that this constant is not defined, this code was
|
||||
// generated with a version of connect newer than the one compiled into your binary. You can fix the
|
||||
// problem by either regenerating this code with an older version of connect or updating the connect
|
||||
// version compiled into your binary.
|
||||
const _ = connect.IsAtLeastVersion1_13_0
|
||||
|
||||
const (
|
||||
// FilesystemName is the fully-qualified name of the Filesystem service.
|
||||
FilesystemName = "filesystem.Filesystem"
|
||||
)
|
||||
|
||||
// These constants are the fully-qualified names of the RPCs defined in this package. They're
|
||||
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
|
||||
//
|
||||
// Note that these are different from the fully-qualified method names used by
|
||||
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
|
||||
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
|
||||
// period.
|
||||
const (
|
||||
// FilesystemStatProcedure is the fully-qualified name of the Filesystem's Stat RPC.
|
||||
FilesystemStatProcedure = "/filesystem.Filesystem/Stat"
|
||||
// FilesystemMakeDirProcedure is the fully-qualified name of the Filesystem's MakeDir RPC.
|
||||
FilesystemMakeDirProcedure = "/filesystem.Filesystem/MakeDir"
|
||||
// FilesystemMoveProcedure is the fully-qualified name of the Filesystem's Move RPC.
|
||||
FilesystemMoveProcedure = "/filesystem.Filesystem/Move"
|
||||
// FilesystemListDirProcedure is the fully-qualified name of the Filesystem's ListDir RPC.
|
||||
FilesystemListDirProcedure = "/filesystem.Filesystem/ListDir"
|
||||
// FilesystemRemoveProcedure is the fully-qualified name of the Filesystem's Remove RPC.
|
||||
FilesystemRemoveProcedure = "/filesystem.Filesystem/Remove"
|
||||
// FilesystemWatchDirProcedure is the fully-qualified name of the Filesystem's WatchDir RPC.
|
||||
FilesystemWatchDirProcedure = "/filesystem.Filesystem/WatchDir"
|
||||
// FilesystemCreateWatcherProcedure is the fully-qualified name of the Filesystem's CreateWatcher
|
||||
// RPC.
|
||||
FilesystemCreateWatcherProcedure = "/filesystem.Filesystem/CreateWatcher"
|
||||
// FilesystemGetWatcherEventsProcedure is the fully-qualified name of the Filesystem's
|
||||
// GetWatcherEvents RPC.
|
||||
FilesystemGetWatcherEventsProcedure = "/filesystem.Filesystem/GetWatcherEvents"
|
||||
// FilesystemRemoveWatcherProcedure is the fully-qualified name of the Filesystem's RemoveWatcher
|
||||
// RPC.
|
||||
FilesystemRemoveWatcherProcedure = "/filesystem.Filesystem/RemoveWatcher"
|
||||
)
|
||||
|
||||
// FilesystemClient is a client for the filesystem.Filesystem service.
|
||||
type FilesystemClient interface {
|
||||
Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error)
|
||||
MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error)
|
||||
Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error)
|
||||
ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error)
|
||||
Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error)
|
||||
WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest]) (*connect.ServerStreamForClient[filesystem.WatchDirResponse], error)
|
||||
// Non-streaming versions of WatchDir
|
||||
CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error)
|
||||
GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error)
|
||||
RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error)
|
||||
}
|
||||
|
||||
// NewFilesystemClient constructs a client for the filesystem.Filesystem service. By default, it
|
||||
// uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
|
||||
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
|
||||
// connect.WithGRPCWeb() options.
|
||||
//
|
||||
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
|
||||
// http://api.acme.com or https://acme.com/grpc).
|
||||
func NewFilesystemClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) FilesystemClient {
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
filesystemMethods := filesystem.File_filesystem_filesystem_proto.Services().ByName("Filesystem").Methods()
|
||||
return &filesystemClient{
|
||||
stat: connect.NewClient[filesystem.StatRequest, filesystem.StatResponse](
|
||||
httpClient,
|
||||
baseURL+FilesystemStatProcedure,
|
||||
connect.WithSchema(filesystemMethods.ByName("Stat")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
makeDir: connect.NewClient[filesystem.MakeDirRequest, filesystem.MakeDirResponse](
|
||||
httpClient,
|
||||
baseURL+FilesystemMakeDirProcedure,
|
||||
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
move: connect.NewClient[filesystem.MoveRequest, filesystem.MoveResponse](
|
||||
httpClient,
|
||||
baseURL+FilesystemMoveProcedure,
|
||||
connect.WithSchema(filesystemMethods.ByName("Move")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
listDir: connect.NewClient[filesystem.ListDirRequest, filesystem.ListDirResponse](
|
||||
httpClient,
|
||||
baseURL+FilesystemListDirProcedure,
|
||||
connect.WithSchema(filesystemMethods.ByName("ListDir")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
remove: connect.NewClient[filesystem.RemoveRequest, filesystem.RemoveResponse](
|
||||
httpClient,
|
||||
baseURL+FilesystemRemoveProcedure,
|
||||
connect.WithSchema(filesystemMethods.ByName("Remove")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
watchDir: connect.NewClient[filesystem.WatchDirRequest, filesystem.WatchDirResponse](
|
||||
httpClient,
|
||||
baseURL+FilesystemWatchDirProcedure,
|
||||
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
createWatcher: connect.NewClient[filesystem.CreateWatcherRequest, filesystem.CreateWatcherResponse](
|
||||
httpClient,
|
||||
baseURL+FilesystemCreateWatcherProcedure,
|
||||
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
getWatcherEvents: connect.NewClient[filesystem.GetWatcherEventsRequest, filesystem.GetWatcherEventsResponse](
|
||||
httpClient,
|
||||
baseURL+FilesystemGetWatcherEventsProcedure,
|
||||
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
removeWatcher: connect.NewClient[filesystem.RemoveWatcherRequest, filesystem.RemoveWatcherResponse](
|
||||
httpClient,
|
||||
baseURL+FilesystemRemoveWatcherProcedure,
|
||||
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// filesystemClient implements FilesystemClient.
|
||||
type filesystemClient struct {
|
||||
stat *connect.Client[filesystem.StatRequest, filesystem.StatResponse]
|
||||
makeDir *connect.Client[filesystem.MakeDirRequest, filesystem.MakeDirResponse]
|
||||
move *connect.Client[filesystem.MoveRequest, filesystem.MoveResponse]
|
||||
listDir *connect.Client[filesystem.ListDirRequest, filesystem.ListDirResponse]
|
||||
remove *connect.Client[filesystem.RemoveRequest, filesystem.RemoveResponse]
|
||||
watchDir *connect.Client[filesystem.WatchDirRequest, filesystem.WatchDirResponse]
|
||||
createWatcher *connect.Client[filesystem.CreateWatcherRequest, filesystem.CreateWatcherResponse]
|
||||
getWatcherEvents *connect.Client[filesystem.GetWatcherEventsRequest, filesystem.GetWatcherEventsResponse]
|
||||
removeWatcher *connect.Client[filesystem.RemoveWatcherRequest, filesystem.RemoveWatcherResponse]
|
||||
}
|
||||
|
||||
// Stat calls filesystem.Filesystem.Stat.
|
||||
func (c *filesystemClient) Stat(ctx context.Context, req *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error) {
|
||||
return c.stat.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// MakeDir calls filesystem.Filesystem.MakeDir.
|
||||
func (c *filesystemClient) MakeDir(ctx context.Context, req *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error) {
|
||||
return c.makeDir.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// Move calls filesystem.Filesystem.Move.
|
||||
func (c *filesystemClient) Move(ctx context.Context, req *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error) {
|
||||
return c.move.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// ListDir calls filesystem.Filesystem.ListDir.
|
||||
func (c *filesystemClient) ListDir(ctx context.Context, req *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error) {
|
||||
return c.listDir.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// Remove calls filesystem.Filesystem.Remove.
|
||||
func (c *filesystemClient) Remove(ctx context.Context, req *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error) {
|
||||
return c.remove.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// WatchDir calls filesystem.Filesystem.WatchDir.
|
||||
func (c *filesystemClient) WatchDir(ctx context.Context, req *connect.Request[filesystem.WatchDirRequest]) (*connect.ServerStreamForClient[filesystem.WatchDirResponse], error) {
|
||||
return c.watchDir.CallServerStream(ctx, req)
|
||||
}
|
||||
|
||||
// CreateWatcher calls filesystem.Filesystem.CreateWatcher.
|
||||
func (c *filesystemClient) CreateWatcher(ctx context.Context, req *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error) {
|
||||
return c.createWatcher.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// GetWatcherEvents calls filesystem.Filesystem.GetWatcherEvents.
|
||||
func (c *filesystemClient) GetWatcherEvents(ctx context.Context, req *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error) {
|
||||
return c.getWatcherEvents.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// RemoveWatcher calls filesystem.Filesystem.RemoveWatcher.
|
||||
func (c *filesystemClient) RemoveWatcher(ctx context.Context, req *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error) {
|
||||
return c.removeWatcher.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// FilesystemHandler is an implementation of the filesystem.Filesystem service.
|
||||
type FilesystemHandler interface {
|
||||
Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error)
|
||||
MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error)
|
||||
Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error)
|
||||
ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error)
|
||||
Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error)
|
||||
WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest], *connect.ServerStream[filesystem.WatchDirResponse]) error
|
||||
// Non-streaming versions of WatchDir
|
||||
CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error)
|
||||
GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error)
|
||||
RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error)
|
||||
}
|
||||
|
||||
// NewFilesystemHandler builds an HTTP handler from the service implementation. It returns the path
|
||||
// on which to mount the handler and the handler itself.
|
||||
//
|
||||
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
|
||||
// and JSON codecs. They also support gzip compression.
|
||||
func NewFilesystemHandler(svc FilesystemHandler, opts ...connect.HandlerOption) (string, http.Handler) {
|
||||
filesystemMethods := filesystem.File_filesystem_filesystem_proto.Services().ByName("Filesystem").Methods()
|
||||
filesystemStatHandler := connect.NewUnaryHandler(
|
||||
FilesystemStatProcedure,
|
||||
svc.Stat,
|
||||
connect.WithSchema(filesystemMethods.ByName("Stat")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
filesystemMakeDirHandler := connect.NewUnaryHandler(
|
||||
FilesystemMakeDirProcedure,
|
||||
svc.MakeDir,
|
||||
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
filesystemMoveHandler := connect.NewUnaryHandler(
|
||||
FilesystemMoveProcedure,
|
||||
svc.Move,
|
||||
connect.WithSchema(filesystemMethods.ByName("Move")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
filesystemListDirHandler := connect.NewUnaryHandler(
|
||||
FilesystemListDirProcedure,
|
||||
svc.ListDir,
|
||||
connect.WithSchema(filesystemMethods.ByName("ListDir")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
filesystemRemoveHandler := connect.NewUnaryHandler(
|
||||
FilesystemRemoveProcedure,
|
||||
svc.Remove,
|
||||
connect.WithSchema(filesystemMethods.ByName("Remove")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
filesystemWatchDirHandler := connect.NewServerStreamHandler(
|
||||
FilesystemWatchDirProcedure,
|
||||
svc.WatchDir,
|
||||
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
filesystemCreateWatcherHandler := connect.NewUnaryHandler(
|
||||
FilesystemCreateWatcherProcedure,
|
||||
svc.CreateWatcher,
|
||||
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
filesystemGetWatcherEventsHandler := connect.NewUnaryHandler(
|
||||
FilesystemGetWatcherEventsProcedure,
|
||||
svc.GetWatcherEvents,
|
||||
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
filesystemRemoveWatcherHandler := connect.NewUnaryHandler(
|
||||
FilesystemRemoveWatcherProcedure,
|
||||
svc.RemoveWatcher,
|
||||
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
return "/filesystem.Filesystem/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case FilesystemStatProcedure:
|
||||
filesystemStatHandler.ServeHTTP(w, r)
|
||||
case FilesystemMakeDirProcedure:
|
||||
filesystemMakeDirHandler.ServeHTTP(w, r)
|
||||
case FilesystemMoveProcedure:
|
||||
filesystemMoveHandler.ServeHTTP(w, r)
|
||||
case FilesystemListDirProcedure:
|
||||
filesystemListDirHandler.ServeHTTP(w, r)
|
||||
case FilesystemRemoveProcedure:
|
||||
filesystemRemoveHandler.ServeHTTP(w, r)
|
||||
case FilesystemWatchDirProcedure:
|
||||
filesystemWatchDirHandler.ServeHTTP(w, r)
|
||||
case FilesystemCreateWatcherProcedure:
|
||||
filesystemCreateWatcherHandler.ServeHTTP(w, r)
|
||||
case FilesystemGetWatcherEventsProcedure:
|
||||
filesystemGetWatcherEventsHandler.ServeHTTP(w, r)
|
||||
case FilesystemRemoveWatcherProcedure:
|
||||
filesystemRemoveWatcherHandler.ServeHTTP(w, r)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// UnimplementedFilesystemHandler returns CodeUnimplemented from all methods.
|
||||
type UnimplementedFilesystemHandler struct{}
|
||||
|
||||
func (UnimplementedFilesystemHandler) Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Stat is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedFilesystemHandler) MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.MakeDir is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedFilesystemHandler) Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Move is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedFilesystemHandler) ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.ListDir is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedFilesystemHandler) Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Remove is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedFilesystemHandler) WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest], *connect.ServerStream[filesystem.WatchDirResponse]) error {
|
||||
return connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.WatchDir is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedFilesystemHandler) CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.CreateWatcher is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedFilesystemHandler) GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.GetWatcherEvents is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedFilesystemHandler) RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.RemoveWatcher is not implemented"))
|
||||
}
|
||||
1970
envd/internal/services/spec/process/process.pb.go
Normal file
1970
envd/internal/services/spec/process/process.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,310 @@
|
||||
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
|
||||
//
|
||||
// Source: process/process.proto
|
||||
|
||||
package processconnect
|
||||
|
||||
import (
|
||||
connect "connectrpc.com/connect"
|
||||
context "context"
|
||||
errors "errors"
|
||||
process "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
http "net/http"
|
||||
strings "strings"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file and the connect package are
|
||||
// compatible. If you get a compiler error that this constant is not defined, this code was
|
||||
// generated with a version of connect newer than the one compiled into your binary. You can fix the
|
||||
// problem by either regenerating this code with an older version of connect or updating the connect
|
||||
// version compiled into your binary.
|
||||
const _ = connect.IsAtLeastVersion1_13_0
|
||||
|
||||
const (
|
||||
// ProcessName is the fully-qualified name of the Process service.
|
||||
ProcessName = "process.Process"
|
||||
)
|
||||
|
||||
// These constants are the fully-qualified names of the RPCs defined in this package. They're
|
||||
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
|
||||
//
|
||||
// Note that these are different from the fully-qualified method names used by
|
||||
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
|
||||
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
|
||||
// period.
|
||||
const (
|
||||
// ProcessListProcedure is the fully-qualified name of the Process's List RPC.
|
||||
ProcessListProcedure = "/process.Process/List"
|
||||
// ProcessConnectProcedure is the fully-qualified name of the Process's Connect RPC.
|
||||
ProcessConnectProcedure = "/process.Process/Connect"
|
||||
// ProcessStartProcedure is the fully-qualified name of the Process's Start RPC.
|
||||
ProcessStartProcedure = "/process.Process/Start"
|
||||
// ProcessUpdateProcedure is the fully-qualified name of the Process's Update RPC.
|
||||
ProcessUpdateProcedure = "/process.Process/Update"
|
||||
// ProcessStreamInputProcedure is the fully-qualified name of the Process's StreamInput RPC.
|
||||
ProcessStreamInputProcedure = "/process.Process/StreamInput"
|
||||
// ProcessSendInputProcedure is the fully-qualified name of the Process's SendInput RPC.
|
||||
ProcessSendInputProcedure = "/process.Process/SendInput"
|
||||
// ProcessSendSignalProcedure is the fully-qualified name of the Process's SendSignal RPC.
|
||||
ProcessSendSignalProcedure = "/process.Process/SendSignal"
|
||||
// ProcessCloseStdinProcedure is the fully-qualified name of the Process's CloseStdin RPC.
|
||||
ProcessCloseStdinProcedure = "/process.Process/CloseStdin"
|
||||
)
|
||||
|
||||
// ProcessClient is a client for the process.Process service.
|
||||
type ProcessClient interface {
|
||||
List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error)
|
||||
Connect(context.Context, *connect.Request[process.ConnectRequest]) (*connect.ServerStreamForClient[process.ConnectResponse], error)
|
||||
Start(context.Context, *connect.Request[process.StartRequest]) (*connect.ServerStreamForClient[process.StartResponse], error)
|
||||
Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error)
|
||||
// Client input stream ensures ordering of messages
|
||||
StreamInput(context.Context) *connect.ClientStreamForClient[process.StreamInputRequest, process.StreamInputResponse]
|
||||
SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error)
|
||||
SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error)
|
||||
// Close stdin to signal EOF to the process.
|
||||
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
|
||||
CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error)
|
||||
}
|
||||
|
||||
// NewProcessClient constructs a client for the process.Process service. By default, it uses the
|
||||
// Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
|
||||
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
|
||||
// connect.WithGRPCWeb() options.
|
||||
//
|
||||
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
|
||||
// http://api.acme.com or https://acme.com/grpc).
|
||||
func NewProcessClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) ProcessClient {
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
processMethods := process.File_process_process_proto.Services().ByName("Process").Methods()
|
||||
return &processClient{
|
||||
list: connect.NewClient[process.ListRequest, process.ListResponse](
|
||||
httpClient,
|
||||
baseURL+ProcessListProcedure,
|
||||
connect.WithSchema(processMethods.ByName("List")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
connect: connect.NewClient[process.ConnectRequest, process.ConnectResponse](
|
||||
httpClient,
|
||||
baseURL+ProcessConnectProcedure,
|
||||
connect.WithSchema(processMethods.ByName("Connect")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
start: connect.NewClient[process.StartRequest, process.StartResponse](
|
||||
httpClient,
|
||||
baseURL+ProcessStartProcedure,
|
||||
connect.WithSchema(processMethods.ByName("Start")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
update: connect.NewClient[process.UpdateRequest, process.UpdateResponse](
|
||||
httpClient,
|
||||
baseURL+ProcessUpdateProcedure,
|
||||
connect.WithSchema(processMethods.ByName("Update")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
streamInput: connect.NewClient[process.StreamInputRequest, process.StreamInputResponse](
|
||||
httpClient,
|
||||
baseURL+ProcessStreamInputProcedure,
|
||||
connect.WithSchema(processMethods.ByName("StreamInput")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
sendInput: connect.NewClient[process.SendInputRequest, process.SendInputResponse](
|
||||
httpClient,
|
||||
baseURL+ProcessSendInputProcedure,
|
||||
connect.WithSchema(processMethods.ByName("SendInput")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
sendSignal: connect.NewClient[process.SendSignalRequest, process.SendSignalResponse](
|
||||
httpClient,
|
||||
baseURL+ProcessSendSignalProcedure,
|
||||
connect.WithSchema(processMethods.ByName("SendSignal")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
closeStdin: connect.NewClient[process.CloseStdinRequest, process.CloseStdinResponse](
|
||||
httpClient,
|
||||
baseURL+ProcessCloseStdinProcedure,
|
||||
connect.WithSchema(processMethods.ByName("CloseStdin")),
|
||||
connect.WithClientOptions(opts...),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// processClient implements ProcessClient.
|
||||
type processClient struct {
|
||||
list *connect.Client[process.ListRequest, process.ListResponse]
|
||||
connect *connect.Client[process.ConnectRequest, process.ConnectResponse]
|
||||
start *connect.Client[process.StartRequest, process.StartResponse]
|
||||
update *connect.Client[process.UpdateRequest, process.UpdateResponse]
|
||||
streamInput *connect.Client[process.StreamInputRequest, process.StreamInputResponse]
|
||||
sendInput *connect.Client[process.SendInputRequest, process.SendInputResponse]
|
||||
sendSignal *connect.Client[process.SendSignalRequest, process.SendSignalResponse]
|
||||
closeStdin *connect.Client[process.CloseStdinRequest, process.CloseStdinResponse]
|
||||
}
|
||||
|
||||
// List calls process.Process.List.
|
||||
func (c *processClient) List(ctx context.Context, req *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error) {
|
||||
return c.list.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// Connect calls process.Process.Connect.
|
||||
func (c *processClient) Connect(ctx context.Context, req *connect.Request[process.ConnectRequest]) (*connect.ServerStreamForClient[process.ConnectResponse], error) {
|
||||
return c.connect.CallServerStream(ctx, req)
|
||||
}
|
||||
|
||||
// Start calls process.Process.Start.
|
||||
func (c *processClient) Start(ctx context.Context, req *connect.Request[process.StartRequest]) (*connect.ServerStreamForClient[process.StartResponse], error) {
|
||||
return c.start.CallServerStream(ctx, req)
|
||||
}
|
||||
|
||||
// Update calls process.Process.Update.
|
||||
func (c *processClient) Update(ctx context.Context, req *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error) {
|
||||
return c.update.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// StreamInput calls process.Process.StreamInput.
|
||||
func (c *processClient) StreamInput(ctx context.Context) *connect.ClientStreamForClient[process.StreamInputRequest, process.StreamInputResponse] {
|
||||
return c.streamInput.CallClientStream(ctx)
|
||||
}
|
||||
|
||||
// SendInput calls process.Process.SendInput.
|
||||
func (c *processClient) SendInput(ctx context.Context, req *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error) {
|
||||
return c.sendInput.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// SendSignal calls process.Process.SendSignal.
|
||||
func (c *processClient) SendSignal(ctx context.Context, req *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error) {
|
||||
return c.sendSignal.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// CloseStdin calls process.Process.CloseStdin.
|
||||
func (c *processClient) CloseStdin(ctx context.Context, req *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error) {
|
||||
return c.closeStdin.CallUnary(ctx, req)
|
||||
}
|
||||
|
||||
// ProcessHandler is an implementation of the process.Process service.
|
||||
type ProcessHandler interface {
|
||||
List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error)
|
||||
Connect(context.Context, *connect.Request[process.ConnectRequest], *connect.ServerStream[process.ConnectResponse]) error
|
||||
Start(context.Context, *connect.Request[process.StartRequest], *connect.ServerStream[process.StartResponse]) error
|
||||
Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error)
|
||||
// Client input stream ensures ordering of messages
|
||||
StreamInput(context.Context, *connect.ClientStream[process.StreamInputRequest]) (*connect.Response[process.StreamInputResponse], error)
|
||||
SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error)
|
||||
SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error)
|
||||
// Close stdin to signal EOF to the process.
|
||||
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
|
||||
CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error)
|
||||
}
|
||||
|
||||
// NewProcessHandler builds an HTTP handler from the service implementation. It returns the path on
|
||||
// which to mount the handler and the handler itself.
|
||||
//
|
||||
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
|
||||
// and JSON codecs. They also support gzip compression.
|
||||
func NewProcessHandler(svc ProcessHandler, opts ...connect.HandlerOption) (string, http.Handler) {
|
||||
processMethods := process.File_process_process_proto.Services().ByName("Process").Methods()
|
||||
processListHandler := connect.NewUnaryHandler(
|
||||
ProcessListProcedure,
|
||||
svc.List,
|
||||
connect.WithSchema(processMethods.ByName("List")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
processConnectHandler := connect.NewServerStreamHandler(
|
||||
ProcessConnectProcedure,
|
||||
svc.Connect,
|
||||
connect.WithSchema(processMethods.ByName("Connect")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
processStartHandler := connect.NewServerStreamHandler(
|
||||
ProcessStartProcedure,
|
||||
svc.Start,
|
||||
connect.WithSchema(processMethods.ByName("Start")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
processUpdateHandler := connect.NewUnaryHandler(
|
||||
ProcessUpdateProcedure,
|
||||
svc.Update,
|
||||
connect.WithSchema(processMethods.ByName("Update")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
processStreamInputHandler := connect.NewClientStreamHandler(
|
||||
ProcessStreamInputProcedure,
|
||||
svc.StreamInput,
|
||||
connect.WithSchema(processMethods.ByName("StreamInput")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
processSendInputHandler := connect.NewUnaryHandler(
|
||||
ProcessSendInputProcedure,
|
||||
svc.SendInput,
|
||||
connect.WithSchema(processMethods.ByName("SendInput")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
processSendSignalHandler := connect.NewUnaryHandler(
|
||||
ProcessSendSignalProcedure,
|
||||
svc.SendSignal,
|
||||
connect.WithSchema(processMethods.ByName("SendSignal")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
processCloseStdinHandler := connect.NewUnaryHandler(
|
||||
ProcessCloseStdinProcedure,
|
||||
svc.CloseStdin,
|
||||
connect.WithSchema(processMethods.ByName("CloseStdin")),
|
||||
connect.WithHandlerOptions(opts...),
|
||||
)
|
||||
return "/process.Process/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case ProcessListProcedure:
|
||||
processListHandler.ServeHTTP(w, r)
|
||||
case ProcessConnectProcedure:
|
||||
processConnectHandler.ServeHTTP(w, r)
|
||||
case ProcessStartProcedure:
|
||||
processStartHandler.ServeHTTP(w, r)
|
||||
case ProcessUpdateProcedure:
|
||||
processUpdateHandler.ServeHTTP(w, r)
|
||||
case ProcessStreamInputProcedure:
|
||||
processStreamInputHandler.ServeHTTP(w, r)
|
||||
case ProcessSendInputProcedure:
|
||||
processSendInputHandler.ServeHTTP(w, r)
|
||||
case ProcessSendSignalProcedure:
|
||||
processSendSignalHandler.ServeHTTP(w, r)
|
||||
case ProcessCloseStdinProcedure:
|
||||
processCloseStdinHandler.ServeHTTP(w, r)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// UnimplementedProcessHandler returns CodeUnimplemented from all methods.
|
||||
type UnimplementedProcessHandler struct{}
|
||||
|
||||
func (UnimplementedProcessHandler) List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.List is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedProcessHandler) Connect(context.Context, *connect.Request[process.ConnectRequest], *connect.ServerStream[process.ConnectResponse]) error {
|
||||
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Connect is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedProcessHandler) Start(context.Context, *connect.Request[process.StartRequest], *connect.ServerStream[process.StartResponse]) error {
|
||||
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Start is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedProcessHandler) Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Update is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedProcessHandler) StreamInput(context.Context, *connect.ClientStream[process.StreamInputRequest]) (*connect.Response[process.StreamInputResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.StreamInput is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedProcessHandler) SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendInput is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedProcessHandler) SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendSignal is not implemented"))
|
||||
}
|
||||
|
||||
func (UnimplementedProcessHandler) CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error) {
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.CloseStdin is not implemented"))
|
||||
}
|
||||
108
envd/internal/shared/filesystem/entry.go
Normal file
108
envd/internal/shared/filesystem/entry.go
Normal file
@ -0,0 +1,108 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetEntryFromPath(path string) (EntryInfo, error) {
|
||||
fileInfo, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return EntryInfo{}, err
|
||||
}
|
||||
|
||||
return GetEntryInfo(path, fileInfo), nil
|
||||
}
|
||||
|
||||
func GetEntryInfo(path string, fileInfo os.FileInfo) EntryInfo {
|
||||
fileMode := fileInfo.Mode()
|
||||
|
||||
var symlinkTarget *string
|
||||
if fileMode&os.ModeSymlink != 0 {
|
||||
// If we can't resolve the symlink target, we won't set the target
|
||||
target := followSymlink(path)
|
||||
symlinkTarget = &target
|
||||
}
|
||||
|
||||
var entryType FileType
|
||||
var mode os.FileMode
|
||||
|
||||
if symlinkTarget == nil {
|
||||
entryType = getEntryType(fileMode)
|
||||
mode = fileMode.Perm()
|
||||
} else {
|
||||
// If it's a symlink, we need to determine the type of the target
|
||||
targetInfo, err := os.Stat(*symlinkTarget)
|
||||
if err != nil {
|
||||
entryType = UnknownFileType
|
||||
} else {
|
||||
entryType = getEntryType(targetInfo.Mode())
|
||||
mode = targetInfo.Mode().Perm()
|
||||
}
|
||||
}
|
||||
|
||||
entry := EntryInfo{
|
||||
Name: fileInfo.Name(),
|
||||
Path: path,
|
||||
Type: entryType,
|
||||
Size: fileInfo.Size(),
|
||||
Mode: mode,
|
||||
Permissions: fileMode.String(),
|
||||
ModifiedTime: fileInfo.ModTime(),
|
||||
SymlinkTarget: symlinkTarget,
|
||||
}
|
||||
|
||||
if base := getBase(fileInfo.Sys()); base != nil {
|
||||
entry.AccessedTime = toTimestamp(base.Atim)
|
||||
entry.CreatedTime = toTimestamp(base.Ctim)
|
||||
entry.ModifiedTime = toTimestamp(base.Mtim)
|
||||
entry.UID = base.Uid
|
||||
entry.GID = base.Gid
|
||||
} else if !fileInfo.ModTime().IsZero() {
|
||||
entry.ModifiedTime = fileInfo.ModTime()
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
// getEntryType determines the type of file entry based on its mode and path.
|
||||
// If the file is a symlink, it follows the symlink to determine the actual type.
|
||||
func getEntryType(mode os.FileMode) FileType {
|
||||
switch {
|
||||
case mode.IsRegular():
|
||||
return FileFileType
|
||||
case mode.IsDir():
|
||||
return DirectoryFileType
|
||||
case mode&os.ModeSymlink == os.ModeSymlink:
|
||||
return SymlinkFileType
|
||||
default:
|
||||
return UnknownFileType
|
||||
}
|
||||
}
|
||||
|
||||
// followSymlink resolves a symbolic link to its target path.
|
||||
func followSymlink(path string) string {
|
||||
// Resolve symlinks
|
||||
resolvedPath, err := filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return path
|
||||
}
|
||||
|
||||
return resolvedPath
|
||||
}
|
||||
|
||||
func toTimestamp(spec syscall.Timespec) time.Time {
|
||||
if spec.Sec == 0 && spec.Nsec == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
return time.Unix(spec.Sec, spec.Nsec)
|
||||
}
|
||||
|
||||
func getBase(sys any) *syscall.Stat_t {
|
||||
st, _ := sys.(*syscall.Stat_t)
|
||||
|
||||
return st
|
||||
}
|
||||
264
envd/internal/shared/filesystem/entry_test.go
Normal file
264
envd/internal/shared/filesystem/entry_test.go
Normal file
@ -0,0 +1,264 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetEntryType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create test files
|
||||
regularFile := filepath.Join(tempDir, "regular.txt")
|
||||
require.NoError(t, os.WriteFile(regularFile, []byte("test content"), 0o644))
|
||||
|
||||
testDir := filepath.Join(tempDir, "testdir")
|
||||
require.NoError(t, os.MkdirAll(testDir, 0o755))
|
||||
|
||||
symlink := filepath.Join(tempDir, "symlink")
|
||||
require.NoError(t, os.Symlink(regularFile, symlink))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected FileType
|
||||
}{
|
||||
{
|
||||
name: "regular file",
|
||||
path: regularFile,
|
||||
expected: FileFileType,
|
||||
},
|
||||
{
|
||||
name: "directory",
|
||||
path: testDir,
|
||||
expected: DirectoryFileType,
|
||||
},
|
||||
{
|
||||
name: "symlink to file",
|
||||
path: symlink,
|
||||
expected: SymlinkFileType,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
info, err := os.Lstat(tt.path)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := getEntryType(info.Mode())
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_SymlinkChain(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Base temporary directory. On macOS this lives under /var/folders/…
|
||||
// which itself is a symlink to /private/var/folders/….
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create final target
|
||||
target := filepath.Join(tempDir, "target")
|
||||
require.NoError(t, os.MkdirAll(target, 0o755))
|
||||
|
||||
// Create a chain: link1 → link2 → target
|
||||
link2 := filepath.Join(tempDir, "link2")
|
||||
require.NoError(t, os.Symlink(target, link2))
|
||||
|
||||
link1 := filepath.Join(tempDir, "link1")
|
||||
require.NoError(t, os.Symlink(link2, link1))
|
||||
|
||||
// run the test
|
||||
result, err := GetEntryFromPath(link1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify the results
|
||||
assert.Equal(t, "link1", result.Name)
|
||||
assert.Equal(t, link1, result.Path)
|
||||
assert.Equal(t, DirectoryFileType, result.Type) // Should resolve to final target type
|
||||
assert.Contains(t, result.Permissions, "L")
|
||||
|
||||
// Canonicalize the expected target path to handle macOS symlink indirections
|
||||
expectedTarget, err := filepath.EvalSymlinks(link1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedTarget, *result.SymlinkTarget)
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_DifferentPermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
permissions os.FileMode
|
||||
expectedMode os.FileMode
|
||||
expectedString string
|
||||
}{
|
||||
{"read-only", 0o444, 0o444, "-r--r--r--"},
|
||||
{"executable", 0o755, 0o755, "-rwxr-xr-x"},
|
||||
{"write-only", 0o200, 0o200, "--w-------"},
|
||||
{"no permissions", 0o000, 0o000, "----------"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testFile := filepath.Join(tempDir, tc.name+".txt")
|
||||
require.NoError(t, os.WriteFile(testFile, []byte("test"), tc.permissions))
|
||||
|
||||
result, err := GetEntryFromPath(testFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedMode, result.Mode)
|
||||
assert.Equal(t, tc.expectedString, result.Permissions)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_EmptyFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
emptyFile := filepath.Join(tempDir, "empty.txt")
|
||||
require.NoError(t, os.WriteFile(emptyFile, []byte{}, 0o600))
|
||||
|
||||
result, err := GetEntryFromPath(emptyFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "empty.txt", result.Name)
|
||||
assert.Equal(t, int64(0), result.Size)
|
||||
assert.Equal(t, os.FileMode(0o600), result.Mode)
|
||||
assert.Equal(t, FileFileType, result.Type)
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_CyclicSymlink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create cyclic symlink
|
||||
cyclicSymlink := filepath.Join(tempDir, "cyclic")
|
||||
require.NoError(t, os.Symlink(cyclicSymlink, cyclicSymlink))
|
||||
|
||||
result, err := GetEntryFromPath(cyclicSymlink)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "cyclic", result.Name)
|
||||
assert.Equal(t, cyclicSymlink, result.Path)
|
||||
assert.Equal(t, UnknownFileType, result.Type)
|
||||
assert.Contains(t, result.Permissions, "L")
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_BrokenSymlink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create broken symlink
|
||||
brokenSymlink := filepath.Join(tempDir, "broken")
|
||||
require.NoError(t, os.Symlink("/nonexistent", brokenSymlink))
|
||||
|
||||
result, err := GetEntryFromPath(brokenSymlink)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "broken", result.Name)
|
||||
assert.Equal(t, brokenSymlink, result.Path)
|
||||
assert.Equal(t, UnknownFileType, result.Type)
|
||||
assert.Contains(t, result.Permissions, "L")
|
||||
// SymlinkTarget might be empty if followSymlink fails
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a regular file with known content and permissions
|
||||
testFile := filepath.Join(tempDir, "test.txt")
|
||||
testContent := []byte("Hello, World!")
|
||||
require.NoError(t, os.WriteFile(testFile, testContent, 0o644))
|
||||
|
||||
// Get current user for ownership comparison
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := GetEntryFromPath(testFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Basic assertions
|
||||
assert.Equal(t, "test.txt", result.Name)
|
||||
assert.Equal(t, testFile, result.Path)
|
||||
assert.Equal(t, int64(len(testContent)), result.Size)
|
||||
assert.Equal(t, FileFileType, result.Type)
|
||||
assert.Equal(t, os.FileMode(0o644), result.Mode)
|
||||
assert.Contains(t, result.Permissions, "-rw-r--r--")
|
||||
assert.Equal(t, currentUser.Uid, strconv.Itoa(int(result.UID)))
|
||||
assert.Equal(t, currentUser.Gid, strconv.Itoa(int(result.GID)))
|
||||
assert.NotNil(t, result.ModifiedTime)
|
||||
assert.Empty(t, result.SymlinkTarget)
|
||||
|
||||
// Check that modified time is reasonable (within last minute)
|
||||
modTime := result.ModifiedTime
|
||||
assert.WithinDuration(t, time.Now(), modTime, time.Minute)
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_Directory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
testDir := filepath.Join(tempDir, "testdir")
|
||||
require.NoError(t, os.MkdirAll(testDir, 0o755))
|
||||
|
||||
result, err := GetEntryFromPath(testDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "testdir", result.Name)
|
||||
assert.Equal(t, testDir, result.Path)
|
||||
assert.Equal(t, DirectoryFileType, result.Type)
|
||||
assert.Equal(t, os.FileMode(0o755), result.Mode)
|
||||
assert.Equal(t, "drwxr-xr-x", result.Permissions)
|
||||
assert.Empty(t, result.SymlinkTarget)
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_Symlink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Base temporary directory. On macOS this lives under /var/folders/…
|
||||
// which itself is a symlink to /private/var/folders/….
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create target file
|
||||
targetFile := filepath.Join(tempDir, "target.txt")
|
||||
require.NoError(t, os.WriteFile(targetFile, []byte("target content"), 0o644))
|
||||
|
||||
// Create symlink
|
||||
symlinkPath := filepath.Join(tempDir, "symlink")
|
||||
require.NoError(t, os.Symlink(targetFile, symlinkPath))
|
||||
|
||||
// Use Lstat to get symlink info (not the target)
|
||||
result, err := GetEntryFromPath(symlinkPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "symlink", result.Name)
|
||||
assert.Equal(t, symlinkPath, result.Path)
|
||||
assert.Equal(t, FileFileType, result.Type) // Should resolve to target type
|
||||
assert.Contains(t, result.Permissions, "L") // Should show as symlink in permissions
|
||||
|
||||
// Canonicalize the expected target path to handle macOS /var → /private/var symlink
|
||||
expectedTarget, err := filepath.EvalSymlinks(symlinkPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedTarget, *result.SymlinkTarget)
|
||||
}
|
||||
30
envd/internal/shared/filesystem/model.go
Normal file
30
envd/internal/shared/filesystem/model.go
Normal file
@ -0,0 +1,30 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EntryInfo struct {
|
||||
Name string
|
||||
Type FileType
|
||||
Path string
|
||||
Size int64
|
||||
Mode os.FileMode
|
||||
Permissions string
|
||||
UID uint32
|
||||
GID uint32
|
||||
AccessedTime time.Time
|
||||
CreatedTime time.Time
|
||||
ModifiedTime time.Time
|
||||
SymlinkTarget *string
|
||||
}
|
||||
|
||||
type FileType int32
|
||||
|
||||
const (
|
||||
UnknownFileType FileType = 0
|
||||
FileFileType FileType = 1
|
||||
DirectoryFileType FileType = 2
|
||||
SymlinkFileType FileType = 3
|
||||
)
|
||||
164
envd/internal/shared/id/id.go
Normal file
164
envd/internal/shared/id/id.go
Normal file
@ -0,0 +1,164 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
caseInsensitiveAlphabet = []byte("abcdefghijklmnopqrstuvwxyz1234567890")
|
||||
identifierRegex = regexp.MustCompile(`^[a-z0-9-_]+$`)
|
||||
tagRegex = regexp.MustCompile(`^[a-z0-9-_.]+$`)
|
||||
sandboxIDRegex = regexp.MustCompile(`^[a-z0-9]+$`)
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultTag = "default"
|
||||
TagSeparator = ":"
|
||||
NamespaceSeparator = "/"
|
||||
)
|
||||
|
||||
func Generate() string {
|
||||
return uniuri.NewLenChars(uniuri.UUIDLen, caseInsensitiveAlphabet)
|
||||
}
|
||||
|
||||
// ValidateSandboxID checks that a sandbox ID contains only lowercase alphanumeric characters.
|
||||
func ValidateSandboxID(sandboxID string) error {
|
||||
if !sandboxIDRegex.MatchString(sandboxID) {
|
||||
return fmt.Errorf("invalid sandbox ID: %q", sandboxID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanAndValidate(value, name string, re *regexp.Regexp) (string, error) {
|
||||
cleaned := strings.ToLower(strings.TrimSpace(value))
|
||||
if !re.MatchString(cleaned) {
|
||||
return "", fmt.Errorf("invalid %s: %s", name, value)
|
||||
}
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
func validateTag(tag string) (string, error) {
|
||||
cleanedTag, err := cleanAndValidate(tag, "tag", tagRegex)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prevent tags from being a UUID
|
||||
_, err = uuid.Parse(cleanedTag)
|
||||
if err == nil {
|
||||
return "", errors.New("tag cannot be a UUID")
|
||||
}
|
||||
|
||||
return cleanedTag, nil
|
||||
}
|
||||
|
||||
func ValidateAndDeduplicateTags(tags []string) ([]string, error) {
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
for _, tag := range tags {
|
||||
cleanedTag, err := validateTag(tag)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid tag '%s': %w", tag, err)
|
||||
}
|
||||
|
||||
seen[cleanedTag] = struct{}{}
|
||||
}
|
||||
|
||||
return slices.Collect(maps.Keys(seen)), nil
|
||||
}
|
||||
|
||||
// SplitIdentifier splits "namespace/alias" into its parts.
|
||||
// Returns nil namespace for bare aliases, pointer for explicit namespace.
|
||||
func SplitIdentifier(identifier string) (namespace *string, alias string) {
|
||||
before, after, found := strings.Cut(identifier, NamespaceSeparator)
|
||||
if !found {
|
||||
return nil, before
|
||||
}
|
||||
|
||||
return &before, after
|
||||
}
|
||||
|
||||
// ParseName parses and validates "namespace/alias:tag" or "alias:tag".
|
||||
// Returns the cleaned identifier (namespace/alias or alias) and optional tag.
|
||||
// All components are validated and normalized (lowercase, trimmed).
|
||||
func ParseName(input string) (identifier string, tag *string, err error) {
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
// Extract raw parts
|
||||
identifierPart, tagPart, hasTag := strings.Cut(input, TagSeparator)
|
||||
namespacePart, aliasPart := SplitIdentifier(identifierPart)
|
||||
|
||||
// Validate tag
|
||||
if hasTag {
|
||||
validated, err := cleanAndValidate(tagPart, "tag", tagRegex)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if !strings.EqualFold(validated, DefaultTag) {
|
||||
tag = &validated
|
||||
}
|
||||
}
|
||||
|
||||
// Validate namespace
|
||||
if namespacePart != nil {
|
||||
validated, err := cleanAndValidate(*namespacePart, "namespace", identifierRegex)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
namespacePart = &validated
|
||||
}
|
||||
|
||||
// Validate alias
|
||||
aliasPart, err = cleanAndValidate(aliasPart, "template ID", identifierRegex)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Build identifier
|
||||
if namespacePart != nil {
|
||||
identifier = WithNamespace(*namespacePart, aliasPart)
|
||||
} else {
|
||||
identifier = aliasPart
|
||||
}
|
||||
|
||||
return identifier, tag, nil
|
||||
}
|
||||
|
||||
// WithTag returns the identifier with the given tag appended (e.g. "templateID:tag").
|
||||
func WithTag(identifier, tag string) string {
|
||||
return identifier + TagSeparator + tag
|
||||
}
|
||||
|
||||
// WithNamespace returns identifier with the given namespace prefix.
|
||||
func WithNamespace(namespace, alias string) string {
|
||||
return namespace + NamespaceSeparator + alias
|
||||
}
|
||||
|
||||
// ExtractAlias returns just the alias portion from an identifier (namespace/alias or alias).
|
||||
func ExtractAlias(identifier string) string {
|
||||
_, alias := SplitIdentifier(identifier)
|
||||
|
||||
return alias
|
||||
}
|
||||
|
||||
// ValidateNamespaceMatchesTeam checks if an explicit namespace in the identifier matches the team's slug.
|
||||
// Returns an error if the namespace doesn't match.
|
||||
// If the identifier has no explicit namespace, returns nil (valid).
|
||||
func ValidateNamespaceMatchesTeam(identifier, teamSlug string) error {
|
||||
namespace, _ := SplitIdentifier(identifier)
|
||||
if namespace != nil && *namespace != teamSlug {
|
||||
return fmt.Errorf("namespace '%s' must match your team '%s'", *namespace, teamSlug)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
380
envd/internal/shared/id/id_test.go
Normal file
380
envd/internal/shared/id/id_test.go
Normal file
@ -0,0 +1,380 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
|
||||
)
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantIdentifier string
|
||||
wantTag *string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "bare alias only",
|
||||
input: "my-template",
|
||||
wantIdentifier: "my-template",
|
||||
wantTag: nil,
|
||||
},
|
||||
{
|
||||
name: "alias with tag",
|
||||
input: "my-template:v1",
|
||||
wantIdentifier: "my-template",
|
||||
wantTag: utils.ToPtr("v1"),
|
||||
},
|
||||
{
|
||||
name: "namespace and alias",
|
||||
input: "acme/my-template",
|
||||
wantIdentifier: "acme/my-template",
|
||||
wantTag: nil,
|
||||
},
|
||||
{
|
||||
name: "namespace, alias and tag",
|
||||
input: "acme/my-template:v1",
|
||||
wantIdentifier: "acme/my-template",
|
||||
wantTag: utils.ToPtr("v1"),
|
||||
},
|
||||
{
|
||||
name: "namespace with hyphens",
|
||||
input: "my-team/my-template:prod",
|
||||
wantIdentifier: "my-team/my-template",
|
||||
wantTag: utils.ToPtr("prod"),
|
||||
},
|
||||
{
|
||||
name: "default tag normalized to nil",
|
||||
input: "my-template:default",
|
||||
wantIdentifier: "my-template",
|
||||
wantTag: nil,
|
||||
},
|
||||
{
|
||||
name: "uppercase converted to lowercase",
|
||||
input: "MyTemplate:Prod",
|
||||
wantIdentifier: "mytemplate",
|
||||
wantTag: utils.ToPtr("prod"),
|
||||
},
|
||||
{
|
||||
name: "whitespace trimmed",
|
||||
input: " my-template : v1 ",
|
||||
wantIdentifier: "my-template",
|
||||
wantTag: utils.ToPtr("v1"),
|
||||
},
|
||||
{
|
||||
name: "invalid - empty namespace",
|
||||
input: "/my-template",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - empty tag after colon",
|
||||
input: "my-template:",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - special characters in alias",
|
||||
input: "my template!",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - special characters in namespace",
|
||||
input: "my team!/my-template",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotIdentifier, gotTag, err := ParseName(tt.input)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err, "Expected ParseName() to return error, got")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "Expected ParseName() not to return error, got: %v", err)
|
||||
assert.Equal(t, tt.wantIdentifier, gotIdentifier, "ParseName() identifier = %v, want %v", gotIdentifier, tt.wantIdentifier)
|
||||
assert.Equal(t, tt.wantTag, gotTag, "ParseName() tag = %v, want %v", utils.Sprintp(gotTag), utils.Sprintp(tt.wantTag))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithNamespace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := WithNamespace("acme", "my-template")
|
||||
want := "acme/my-template"
|
||||
assert.Equal(t, want, got, "WithNamespace() = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
func TestSplitIdentifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
identifier string
|
||||
wantNamespace *string
|
||||
wantAlias string
|
||||
}{
|
||||
{
|
||||
name: "bare alias",
|
||||
identifier: "my-template",
|
||||
wantNamespace: nil,
|
||||
wantAlias: "my-template",
|
||||
},
|
||||
{
|
||||
name: "with namespace",
|
||||
identifier: "acme/my-template",
|
||||
wantNamespace: ptrStr("acme"),
|
||||
wantAlias: "my-template",
|
||||
},
|
||||
{
|
||||
name: "empty namespace prefix",
|
||||
identifier: "/my-template",
|
||||
wantNamespace: ptrStr(""),
|
||||
wantAlias: "my-template",
|
||||
},
|
||||
{
|
||||
name: "multiple slashes - only first split",
|
||||
identifier: "a/b/c",
|
||||
wantNamespace: ptrStr("a"),
|
||||
wantAlias: "b/c",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotNamespace, gotAlias := SplitIdentifier(tt.identifier)
|
||||
|
||||
if tt.wantNamespace == nil {
|
||||
assert.Nil(t, gotNamespace)
|
||||
} else {
|
||||
require.NotNil(t, gotNamespace)
|
||||
assert.Equal(t, *tt.wantNamespace, *gotNamespace)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.wantAlias, gotAlias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ptrStr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func TestValidateAndDeduplicateTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tags []string
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "single valid tag",
|
||||
tags: []string{"v1"},
|
||||
want: []string{"v1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple unique tags",
|
||||
tags: []string{"v1", "prod", "latest"},
|
||||
want: []string{"v1", "prod", "latest"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate tags deduplicated",
|
||||
tags: []string{"v1", "V1", "v1"},
|
||||
want: []string{"v1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tags with dots and underscores",
|
||||
tags: []string{"v1.0", "v1_1"},
|
||||
want: []string{"v1.0", "v1_1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid - UUID tag rejected",
|
||||
tags: []string{"550e8400-e29b-41d4-a716-446655440000"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - special characters",
|
||||
tags: []string{"v1!", "v2@"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty list returns empty",
|
||||
tags: []string{},
|
||||
want: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := ValidateAndDeduplicateTags(tt.tags)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSandboxID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "canonical sandbox ID",
|
||||
input: "i1a2b3c4d5e6f7g8h9j0k",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "short alphanumeric",
|
||||
input: "abc123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "all digits",
|
||||
input: "1234567890",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "all lowercase letters",
|
||||
input: "abcdefghijklmnopqrst",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid - empty",
|
||||
input: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains colon (Redis separator)",
|
||||
input: "abc:def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains open brace (Redis hash slot)",
|
||||
input: "abc{def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains close brace (Redis hash slot)",
|
||||
input: "abc}def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains newline",
|
||||
input: "abc\ndef",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains space",
|
||||
input: "abc def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains hyphen",
|
||||
input: "abc-def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains uppercase",
|
||||
input: "abcDEF",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains slash",
|
||||
input: "abc/def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains null byte",
|
||||
input: "abc\x00def",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := ValidateSandboxID(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNamespaceMatchesTeam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
identifier string
|
||||
teamSlug string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "bare alias - no namespace",
|
||||
identifier: "my-template",
|
||||
teamSlug: "acme",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "matching namespace",
|
||||
identifier: "acme/my-template",
|
||||
teamSlug: "acme",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mismatched namespace",
|
||||
identifier: "other-team/my-template",
|
||||
teamSlug: "acme",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := ValidateNamespaceMatchesTeam(tt.identifier, tt.teamSlug)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
6
envd/internal/shared/keys/constants.go
Normal file
6
envd/internal/shared/keys/constants.go
Normal file
@ -0,0 +1,6 @@
|
||||
package keys
|
||||
|
||||
const (
|
||||
ApiKeyPrefix = "wrn_"
|
||||
AccessTokenPrefix = "sk_wrn_"
|
||||
)
|
||||
5
envd/internal/shared/keys/hashing.go
Normal file
5
envd/internal/shared/keys/hashing.go
Normal file
@ -0,0 +1,5 @@
|
||||
package keys
|
||||
|
||||
type Hasher interface {
|
||||
Hash(key []byte) string
|
||||
}
|
||||
25
envd/internal/shared/keys/hmac_sha256.go
Normal file
25
envd/internal/shared/keys/hmac_sha256.go
Normal file
@ -0,0 +1,25 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
type HMACSha256Hashing struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
func NewHMACSHA256Hashing(key []byte) *HMACSha256Hashing {
|
||||
return &HMACSha256Hashing{key: key}
|
||||
}
|
||||
|
||||
func (h *HMACSha256Hashing) Hash(content []byte) (string, error) {
|
||||
mac := hmac.New(sha256.New, h.key)
|
||||
_, err := mac.Write(content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return hex.EncodeToString(mac.Sum(nil)), nil
|
||||
}
|
||||
74
envd/internal/shared/keys/hmac_sha256_test.go
Normal file
74
envd/internal/shared/keys/hmac_sha256_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHMACSha256Hashing_ValidHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := []byte("test-key")
|
||||
hasher := NewHMACSHA256Hashing(key)
|
||||
content := []byte("hello world")
|
||||
expectedHash := "18c4b268f0bbf8471eda56af3e70b1d4613d734dc538b4940b59931c412a1591"
|
||||
actualHash, err := hasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
if actualHash != expectedHash {
|
||||
t.Errorf("expected %s, got %s", expectedHash, actualHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHMACSha256Hashing_EmptyContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := []byte("test-key")
|
||||
hasher := NewHMACSHA256Hashing(key)
|
||||
content := []byte("")
|
||||
expectedHash := "2711cc23e9ab1b8a9bc0fe991238da92671624a9ebdaf1c1abec06e7e9a14f9b"
|
||||
actualHash, err := hasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
if actualHash != expectedHash {
|
||||
t.Errorf("expected %s, got %s", expectedHash, actualHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHMACSha256Hashing_DifferentKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := []byte("test-key")
|
||||
hasher := NewHMACSHA256Hashing(key)
|
||||
differentKeyHasher := NewHMACSHA256Hashing([]byte("different-key"))
|
||||
content := []byte("hello world")
|
||||
|
||||
hashWithOriginalKey, err := hasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
hashWithDifferentKey, err := differentKeyHasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
if hashWithOriginalKey == hashWithDifferentKey {
|
||||
t.Errorf("hashes with different keys should not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHMACSha256Hashing_IdenticalResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := []byte("placeholder-hashing-key")
|
||||
content := []byte("test content for hashing")
|
||||
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(content)
|
||||
expectedResult := hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
hasher := NewHMACSHA256Hashing(key)
|
||||
actualResult, err := hasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
if actualResult != expectedResult {
|
||||
t.Errorf("expected %s, got %s", expectedResult, actualResult)
|
||||
}
|
||||
}
|
||||
99
envd/internal/shared/keys/key.go
Normal file
99
envd/internal/shared/keys/key.go
Normal file
@ -0,0 +1,99 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
identifierValueSuffixLength = 4
|
||||
identifierValuePrefixLength = 2
|
||||
|
||||
keyLength = 20
|
||||
)
|
||||
|
||||
var hasher Hasher = NewSHA256Hashing()
|
||||
|
||||
type Key struct {
|
||||
PrefixedRawValue string
|
||||
HashedValue string
|
||||
Masked MaskedIdentifier
|
||||
}
|
||||
|
||||
type MaskedIdentifier struct {
|
||||
Prefix string
|
||||
ValueLength int
|
||||
MaskedValuePrefix string
|
||||
MaskedValueSuffix string
|
||||
}
|
||||
|
||||
// MaskKey returns identifier masking properties in accordance to the OpenAPI response spec
|
||||
func MaskKey(prefix, value string) (MaskedIdentifier, error) {
|
||||
valueLength := len(value)
|
||||
|
||||
suffixOffset := valueLength - identifierValueSuffixLength
|
||||
prefixOffset := identifierValuePrefixLength
|
||||
|
||||
if suffixOffset < 0 {
|
||||
return MaskedIdentifier{}, fmt.Errorf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength)
|
||||
}
|
||||
|
||||
if suffixOffset == 0 {
|
||||
return MaskedIdentifier{}, fmt.Errorf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength)
|
||||
}
|
||||
|
||||
// cap prefixOffset by suffixOffset to prevent overlap with the suffix.
|
||||
if prefixOffset > suffixOffset {
|
||||
prefixOffset = suffixOffset
|
||||
}
|
||||
|
||||
maskPrefix := value[:prefixOffset]
|
||||
maskSuffix := value[suffixOffset:]
|
||||
|
||||
maskedIdentifierProperties := MaskedIdentifier{
|
||||
Prefix: prefix,
|
||||
ValueLength: valueLength,
|
||||
MaskedValuePrefix: maskPrefix,
|
||||
MaskedValueSuffix: maskSuffix,
|
||||
}
|
||||
|
||||
return maskedIdentifierProperties, nil
|
||||
}
|
||||
|
||||
func GenerateKey(prefix string) (Key, error) {
|
||||
keyBytes := make([]byte, keyLength)
|
||||
|
||||
_, err := rand.Read(keyBytes)
|
||||
if err != nil {
|
||||
return Key{}, err
|
||||
}
|
||||
|
||||
generatedIdentifier := hex.EncodeToString(keyBytes)
|
||||
|
||||
mask, err := MaskKey(prefix, generatedIdentifier)
|
||||
if err != nil {
|
||||
return Key{}, err
|
||||
}
|
||||
|
||||
return Key{
|
||||
PrefixedRawValue: prefix + generatedIdentifier,
|
||||
HashedValue: hasher.Hash(keyBytes),
|
||||
Masked: mask,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func VerifyKey(prefix string, key string) (string, error) {
|
||||
if !strings.HasPrefix(key, prefix) {
|
||||
return "", fmt.Errorf("invalid key prefix")
|
||||
}
|
||||
|
||||
keyValue := key[len(prefix):]
|
||||
keyBytes, err := hex.DecodeString(keyValue)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid key")
|
||||
}
|
||||
|
||||
return hasher.Hash(keyBytes), nil
|
||||
}
|
||||
160
envd/internal/shared/keys/key_test.go
Normal file
160
envd/internal/shared/keys/key_test.go
Normal file
@ -0,0 +1,160 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMaskKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("succeeds: value longer than suffix length", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
masked, err := MaskKey("test_", "1234567890")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test_", masked.Prefix)
|
||||
assert.Equal(t, "12", masked.MaskedValuePrefix)
|
||||
assert.Equal(t, "7890", masked.MaskedValueSuffix)
|
||||
})
|
||||
|
||||
t.Run("succeeds: empty prefix, value longer than suffix length", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
masked, err := MaskKey("", "1234567890")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, masked.Prefix)
|
||||
assert.Equal(t, "12", masked.MaskedValuePrefix)
|
||||
assert.Equal(t, "7890", masked.MaskedValueSuffix)
|
||||
})
|
||||
|
||||
t.Run("error: value length less than suffix length", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := MaskKey("test", "123")
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength))
|
||||
})
|
||||
|
||||
t.Run("error: value length equals suffix length", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := MaskKey("test", "1234")
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, fmt.Sprintf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength))
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
keyLength := 40
|
||||
|
||||
t.Run("succeeds", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := GenerateKey("test_")
|
||||
require.NoError(t, err)
|
||||
assert.Regexp(t, "^test_.*", key.PrefixedRawValue)
|
||||
assert.Equal(t, "test_", key.Masked.Prefix)
|
||||
assert.Equal(t, keyLength, key.Masked.ValueLength)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValuePrefixLength)+"}$", key.Masked.MaskedValuePrefix)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValueSuffixLength)+"}$", key.Masked.MaskedValueSuffix)
|
||||
assert.Regexp(t, "^\\$sha256\\$.*", key.HashedValue)
|
||||
})
|
||||
|
||||
t.Run("no prefix", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := GenerateKey("")
|
||||
require.NoError(t, err)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(keyLength)+"}$", key.PrefixedRawValue)
|
||||
assert.Empty(t, key.Masked.Prefix)
|
||||
assert.Equal(t, keyLength, key.Masked.ValueLength)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValuePrefixLength)+"}$", key.Masked.MaskedValuePrefix)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValueSuffixLength)+"}$", key.Masked.MaskedValueSuffix)
|
||||
assert.Regexp(t, "^\\$sha256\\$.*", key.HashedValue)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetMaskedIdentifierProperties(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
prefix string
|
||||
value string
|
||||
expectedResult MaskedIdentifier
|
||||
expectedErrString string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
// --- ERROR CASES (value's length <= identifierValueSuffixLength) ---
|
||||
{
|
||||
name: "error: value length < suffix length (3 vs 4)",
|
||||
prefix: "pk_",
|
||||
value: "abc",
|
||||
expectedResult: MaskedIdentifier{},
|
||||
expectedErrString: fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength),
|
||||
},
|
||||
{
|
||||
name: "error: value length == suffix length (4 vs 4)",
|
||||
prefix: "sk_",
|
||||
value: "abcd",
|
||||
expectedResult: MaskedIdentifier{},
|
||||
expectedErrString: fmt.Sprintf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength),
|
||||
},
|
||||
{
|
||||
name: "error: value length < suffix length (0 vs 4, empty value)",
|
||||
prefix: "err_",
|
||||
value: "",
|
||||
expectedResult: MaskedIdentifier{},
|
||||
expectedErrString: fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength),
|
||||
},
|
||||
|
||||
// --- SUCCESS CASES (value's length > identifierValueSuffixLength) ---
|
||||
{
|
||||
name: "success: value long (10), prefix val len fully used",
|
||||
prefix: "pk_",
|
||||
value: "abcdefghij",
|
||||
expectedResult: MaskedIdentifier{
|
||||
Prefix: "pk_",
|
||||
ValueLength: 10,
|
||||
MaskedValuePrefix: "ab",
|
||||
MaskedValueSuffix: "ghij",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success: value medium (5), prefix val len truncated by overlap",
|
||||
prefix: "",
|
||||
value: "abcde",
|
||||
expectedResult: MaskedIdentifier{
|
||||
Prefix: "",
|
||||
ValueLength: 5,
|
||||
MaskedValuePrefix: "a",
|
||||
MaskedValueSuffix: "bcde",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success: value medium (6), prefix val len fits exactly",
|
||||
prefix: "pk_",
|
||||
value: "abcdef",
|
||||
expectedResult: MaskedIdentifier{
|
||||
Prefix: "pk_",
|
||||
ValueLength: 6,
|
||||
MaskedValuePrefix: "ab",
|
||||
MaskedValueSuffix: "cdef",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result, err := MaskKey(tc.prefix, tc.value)
|
||||
|
||||
if tc.expectedErrString != "" {
|
||||
require.EqualError(t, err, tc.expectedErrString)
|
||||
assert.Equal(t, tc.expectedResult, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
30
envd/internal/shared/keys/sha256.go
Normal file
30
envd/internal/shared/keys/sha256.go
Normal file
@ -0,0 +1,30 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Sha256Hashing struct{}
|
||||
|
||||
func NewSHA256Hashing() *Sha256Hashing {
|
||||
return &Sha256Hashing{}
|
||||
}
|
||||
|
||||
func (h *Sha256Hashing) Hash(key []byte) string {
|
||||
hashBytes := sha256.Sum256(key)
|
||||
|
||||
hash64 := base64.RawStdEncoding.EncodeToString(hashBytes[:])
|
||||
|
||||
return fmt.Sprintf(
|
||||
"$sha256$%s",
|
||||
hash64,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *Sha256Hashing) HashWithoutPrefix(key []byte) string {
|
||||
hashBytes := sha256.Sum256(key)
|
||||
|
||||
return base64.RawStdEncoding.EncodeToString(hashBytes[:])
|
||||
}
|
||||
15
envd/internal/shared/keys/sha256_test.go
Normal file
15
envd/internal/shared/keys/sha256_test.go
Normal file
@ -0,0 +1,15 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSHA256Hashing(t *testing.T) {
|
||||
t.Parallel()
|
||||
hasher := NewSHA256Hashing()
|
||||
|
||||
hashed := hasher.Hash([]byte("test"))
|
||||
assert.Regexp(t, "^\\$sha256\\$.*", hashed)
|
||||
}
|
||||
20
envd/internal/shared/keys/sha512.go
Normal file
20
envd/internal/shared/keys/sha512.go
Normal file
@ -0,0 +1,20 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// HashAccessToken computes the SHA-512 hash of an access token.
|
||||
func HashAccessToken(token string) string {
|
||||
h := sha512.Sum512([]byte(token))
|
||||
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// HashAccessTokenBytes computes the SHA-512 hash of an access token from bytes.
|
||||
func HashAccessTokenBytes(token []byte) string {
|
||||
h := sha512.Sum512(token)
|
||||
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
47
envd/internal/shared/smap/smap.go
Normal file
47
envd/internal/shared/smap/smap.go
Normal file
@ -0,0 +1,47 @@
|
||||
package smap
|
||||
|
||||
import (
|
||||
cmap "github.com/orcaman/concurrent-map/v2"
|
||||
)
|
||||
|
||||
type Map[V any] struct {
|
||||
m cmap.ConcurrentMap[string, V]
|
||||
}
|
||||
|
||||
func New[V any]() *Map[V] {
|
||||
return &Map[V]{
|
||||
m: cmap.New[V](),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Map[V]) Remove(key string) {
|
||||
m.m.Remove(key)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Get(key string) (V, bool) {
|
||||
return m.m.Get(key)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Insert(key string, value V) {
|
||||
m.m.Set(key, value)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Upsert(key string, value V, cb cmap.UpsertCb[V]) V {
|
||||
return m.m.Upsert(key, value, cb)
|
||||
}
|
||||
|
||||
func (m *Map[V]) InsertIfAbsent(key string, value V) bool {
|
||||
return m.m.SetIfAbsent(key, value)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Items() map[string]V {
|
||||
return m.m.Items()
|
||||
}
|
||||
|
||||
func (m *Map[V]) RemoveCb(key string, cb func(key string, v V, exists bool) bool) bool {
|
||||
return m.m.RemoveCb(key, cb)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Count() int {
|
||||
return m.m.Count()
|
||||
}
|
||||
43
envd/internal/shared/utils/ptr.go
Normal file
43
envd/internal/shared/utils/ptr.go
Normal file
@ -0,0 +1,43 @@
|
||||
package utils
|
||||
|
||||
import "fmt"
|
||||
|
||||
func ToPtr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func FromPtr[T any](s *T) T {
|
||||
if s == nil {
|
||||
var zero T
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
return *s
|
||||
}
|
||||
|
||||
func Sprintp[T any](s *T) string {
|
||||
if s == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", *s)
|
||||
}
|
||||
|
||||
func DerefOrDefault[T any](s *T, defaultValue T) T {
|
||||
if s == nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return *s
|
||||
}
|
||||
|
||||
func CastPtr[S any, T any](s *S, castFunc func(S) T) *T {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
t := castFunc(*s)
|
||||
|
||||
return &t
|
||||
}
|
||||
27
envd/internal/utils/atomic.go
Normal file
27
envd/internal/utils/atomic.go
Normal file
@ -0,0 +1,27 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type AtomicMax struct {
|
||||
val int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAtomicMax() *AtomicMax {
|
||||
return &AtomicMax{}
|
||||
}
|
||||
|
||||
func (a *AtomicMax) SetToGreater(newValue int64) bool {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if a.val > newValue {
|
||||
return false
|
||||
}
|
||||
|
||||
a.val = newValue
|
||||
|
||||
return true
|
||||
}
|
||||
76
envd/internal/utils/atomic_test.go
Normal file
76
envd/internal/utils/atomic_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAtomicMax_NewAtomicMax(t *testing.T) {
|
||||
t.Parallel()
|
||||
am := NewAtomicMax()
|
||||
require.NotNil(t, am)
|
||||
require.Equal(t, int64(0), am.val)
|
||||
}
|
||||
|
||||
func TestAtomicMax_SetToGreater_InitialValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
am := NewAtomicMax()
|
||||
|
||||
// Should succeed when newValue > current
|
||||
assert.True(t, am.SetToGreater(10))
|
||||
assert.Equal(t, int64(10), am.val)
|
||||
}
|
||||
|
||||
func TestAtomicMax_SetToGreater_EqualValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
am := NewAtomicMax()
|
||||
am.val = 10
|
||||
|
||||
// Should succeed when newValue > current
|
||||
assert.True(t, am.SetToGreater(20))
|
||||
assert.Equal(t, int64(20), am.val)
|
||||
}
|
||||
|
||||
func TestAtomicMax_SetToGreater_GreaterValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
am := NewAtomicMax()
|
||||
am.val = 10
|
||||
|
||||
// Should fail when newValue < current, keeping the max value
|
||||
assert.False(t, am.SetToGreater(5))
|
||||
assert.Equal(t, int64(10), am.val)
|
||||
}
|
||||
|
||||
func TestAtomicMax_SetToGreater_NegativeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
am := NewAtomicMax()
|
||||
am.val = -5
|
||||
|
||||
assert.True(t, am.SetToGreater(-2))
|
||||
assert.Equal(t, int64(-2), am.val)
|
||||
}
|
||||
|
||||
func TestAtomicMax_SetToGreater_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
am := NewAtomicMax()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Run 100 goroutines trying to update the value concurrently
|
||||
numGoroutines := 100
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := range numGoroutines {
|
||||
go func(val int64) {
|
||||
defer wg.Done()
|
||||
am.SetToGreater(val)
|
||||
}(int64(i))
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// The final value should be 99 (the maximum value)
|
||||
assert.Equal(t, int64(99), am.val)
|
||||
}
|
||||
51
envd/internal/utils/map.go
Normal file
51
envd/internal/utils/map.go
Normal file
@ -0,0 +1,51 @@
|
||||
package utils
|
||||
|
||||
import "sync"
|
||||
|
||||
type Map[K comparable, V any] struct {
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func NewMap[K comparable, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{
|
||||
m: sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Delete(key K) {
|
||||
m.m.Delete(key)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
|
||||
v, ok := m.m.Load(key)
|
||||
if !ok {
|
||||
return value, ok
|
||||
}
|
||||
|
||||
return v.(V), ok
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
|
||||
v, loaded := m.m.LoadAndDelete(key)
|
||||
if !loaded {
|
||||
return value, loaded
|
||||
}
|
||||
|
||||
return v.(V), loaded
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
|
||||
a, loaded := m.m.LoadOrStore(key, value)
|
||||
|
||||
return a.(V), loaded
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
|
||||
m.m.Range(func(key, value any) bool {
|
||||
return f(key.(K), value.(V))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Store(key K, value V) {
|
||||
m.m.Store(key, value)
|
||||
}
|
||||
43
envd/internal/utils/multipart.go
Normal file
43
envd/internal/utils/multipart.go
Normal file
@ -0,0 +1,43 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
)
|
||||
|
||||
// CustomPart is a wrapper around multipart.Part that overloads the FileName method
|
||||
type CustomPart struct {
|
||||
*multipart.Part
|
||||
}
|
||||
|
||||
// FileNameWithPath returns the filename parameter of the Part's Content-Disposition header.
|
||||
// This method borrows from the original FileName method implementation but returns the full
|
||||
// filename without using `filepath.Base`.
|
||||
func (p *CustomPart) FileNameWithPath() (string, error) {
|
||||
dispositionParams, err := p.parseContentDisposition()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
filename, ok := dispositionParams["filename"]
|
||||
if !ok {
|
||||
return "", errors.New("filename not found in Content-Disposition header")
|
||||
}
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
func (p *CustomPart) parseContentDisposition() (map[string]string, error) {
|
||||
v := p.Header.Get("Content-Disposition")
|
||||
_, dispositionParams, err := mime.ParseMediaType(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dispositionParams, nil
|
||||
}
|
||||
|
||||
// NewCustomPart creates a new CustomPart from a multipart.Part
|
||||
func NewCustomPart(part *multipart.Part) *CustomPart {
|
||||
return &CustomPart{Part: part}
|
||||
}
|
||||
12
envd/internal/utils/rfsnotify.go
Normal file
12
envd/internal/utils/rfsnotify.go
Normal file
@ -0,0 +1,12 @@
|
||||
package utils
|
||||
|
||||
import "path/filepath"
|
||||
|
||||
// FsnotifyPath creates an optionally recursive path for fsnotify/fsnotify internal implementation
|
||||
func FsnotifyPath(path string, recursive bool) string {
|
||||
if recursive {
|
||||
return filepath.Join(path, "...")
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
290
envd/main.go
290
envd/main.go
@ -0,0 +1,290 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/authn"
|
||||
connectcors "connectrpc.com/cors"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/rs/cors"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/api"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
publicport "git.omukk.dev/wrenn/sandbox/envd/internal/port"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||
filesystemRpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/filesystem"
|
||||
processRpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/process"
|
||||
processSpec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
// Downstream timeout should be greater than upstream (in orchestrator proxy).
|
||||
idleTimeout = 640 * time.Second
|
||||
maxAge = 2 * time.Hour
|
||||
|
||||
defaultPort = 49983
|
||||
|
||||
portScannerInterval = 1000 * time.Millisecond
|
||||
|
||||
// This is the default user used in the container if not specified otherwise.
|
||||
// It should be always overridden by the user in /init when building the template.
|
||||
defaultUser = "root"
|
||||
|
||||
kilobyte = 1024
|
||||
megabyte = 1024 * kilobyte
|
||||
)
|
||||
|
||||
var (
|
||||
Version = "0.5.4"
|
||||
|
||||
commitSHA string
|
||||
|
||||
isNotFC bool
|
||||
port int64
|
||||
|
||||
versionFlag bool
|
||||
commitFlag bool
|
||||
startCmdFlag string
|
||||
cgroupRoot string
|
||||
)
|
||||
|
||||
func parseFlags() {
|
||||
flag.BoolVar(
|
||||
&isNotFC,
|
||||
"isnotfc",
|
||||
false,
|
||||
"isNotFCmode prints all logs to stdout",
|
||||
)
|
||||
|
||||
flag.BoolVar(
|
||||
&versionFlag,
|
||||
"version",
|
||||
false,
|
||||
"print envd version",
|
||||
)
|
||||
|
||||
flag.BoolVar(
|
||||
&commitFlag,
|
||||
"commit",
|
||||
false,
|
||||
"print envd source commit",
|
||||
)
|
||||
|
||||
flag.Int64Var(
|
||||
&port,
|
||||
"port",
|
||||
defaultPort,
|
||||
"a port on which the daemon should run",
|
||||
)
|
||||
|
||||
flag.StringVar(
|
||||
&startCmdFlag,
|
||||
"cmd",
|
||||
"",
|
||||
"a command to run on the daemon start",
|
||||
)
|
||||
|
||||
flag.StringVar(
|
||||
&cgroupRoot,
|
||||
"cgroup-root",
|
||||
"/sys/fs/cgroup",
|
||||
"cgroup root directory",
|
||||
)
|
||||
|
||||
flag.Parse()
|
||||
}
|
||||
|
||||
func withCORS(h http.Handler) http.Handler {
|
||||
middleware := cors.New(cors.Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{
|
||||
http.MethodHead,
|
||||
http.MethodGet,
|
||||
http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodPatch,
|
||||
http.MethodDelete,
|
||||
},
|
||||
AllowedHeaders: []string{"*"},
|
||||
ExposedHeaders: append(
|
||||
connectcors.ExposedHeaders(),
|
||||
"Location",
|
||||
"Cache-Control",
|
||||
"X-Content-Type-Options",
|
||||
),
|
||||
MaxAge: int(maxAge.Seconds()),
|
||||
})
|
||||
|
||||
return middleware.Handler(h)
|
||||
}
|
||||
|
||||
func main() {
|
||||
parseFlags()
|
||||
|
||||
if versionFlag {
|
||||
fmt.Printf("%s\n", Version)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if commitFlag {
|
||||
fmt.Printf("%s\n", commitSHA)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := os.MkdirAll(host.WrennRunDir, 0o755); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error creating wrenn run directory: %v\n", err)
|
||||
}
|
||||
|
||||
defaults := &execcontext.Defaults{
|
||||
User: defaultUser,
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
}
|
||||
isFCBoolStr := strconv.FormatBool(!isNotFC)
|
||||
defaults.EnvVars.Store("WRENN_SANDBOX", isFCBoolStr)
|
||||
if err := os.WriteFile(filepath.Join(host.WrennRunDir, ".WRENN_SANDBOX"), []byte(isFCBoolStr), 0o444); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error writing sandbox file: %v\n", err)
|
||||
}
|
||||
|
||||
mmdsChan := make(chan *host.MMDSOpts, 1)
|
||||
defer close(mmdsChan)
|
||||
if !isNotFC {
|
||||
go host.PollForMMDSOpts(ctx, mmdsChan, defaults.EnvVars)
|
||||
}
|
||||
|
||||
l := logs.NewLogger(ctx, isNotFC, mmdsChan)
|
||||
|
||||
m := chi.NewRouter()
|
||||
|
||||
envLogger := l.With().Str("logger", "envd").Logger()
|
||||
fsLogger := l.With().Str("logger", "filesystem").Logger()
|
||||
filesystemRpc.Handle(m, &fsLogger, defaults)
|
||||
|
||||
cgroupManager := createCgroupManager()
|
||||
defer func() {
|
||||
err := cgroupManager.Close()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to close cgroup manager: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
processLogger := l.With().Str("logger", "process").Logger()
|
||||
processService := processRpc.Handle(m, &processLogger, defaults, cgroupManager)
|
||||
|
||||
service := api.New(&envLogger, defaults, mmdsChan, isNotFC)
|
||||
handler := api.HandlerFromMux(service, m)
|
||||
middleware := authn.NewMiddleware(permissions.AuthenticateUsername)
|
||||
|
||||
s := &http.Server{
|
||||
Handler: withCORS(
|
||||
service.WithAuthorization(
|
||||
middleware.Wrap(handler),
|
||||
),
|
||||
),
|
||||
Addr: fmt.Sprintf("0.0.0.0:%d", port),
|
||||
// We remove the timeouts as the connection is terminated by closing of the sandbox and keepalive close.
|
||||
ReadTimeout: 0,
|
||||
WriteTimeout: 0,
|
||||
IdleTimeout: idleTimeout,
|
||||
}
|
||||
|
||||
// TODO: Not used anymore in template build, replaced by direct envd command call.
|
||||
if startCmdFlag != "" {
|
||||
tag := "startCmd"
|
||||
cwd := "/home/user"
|
||||
user, err := permissions.GetUser("root")
|
||||
if err != nil {
|
||||
log.Fatalf("error getting user: %v", err) //nolint:gocritic // probably fine to bail if we're done?
|
||||
}
|
||||
|
||||
if err = processService.InitializeStartProcess(ctx, user, &processSpec.StartRequest{
|
||||
Tag: &tag,
|
||||
Process: &processSpec.ProcessConfig{
|
||||
Envs: make(map[string]string),
|
||||
Cmd: "/bin/bash",
|
||||
Args: []string{"-l", "-c", startCmdFlag},
|
||||
Cwd: &cwd,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Fatalf("error starting process: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Bind all open ports on 127.0.0.1 and localhost to the eth0 interface
|
||||
portScanner := publicport.NewScanner(portScannerInterval)
|
||||
defer portScanner.Destroy()
|
||||
|
||||
portLogger := l.With().Str("logger", "port-forwarder").Logger()
|
||||
portForwarder := publicport.NewForwarder(&portLogger, portScanner, cgroupManager)
|
||||
go portForwarder.StartForwarding(ctx)
|
||||
|
||||
go portScanner.ScanAndBroadcast()
|
||||
|
||||
err := s.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Fatalf("error starting server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createCgroupManager() (m cgroups.Manager) {
|
||||
defer func() {
|
||||
if m == nil {
|
||||
fmt.Fprintf(os.Stderr, "falling back to no-op cgroup manager\n")
|
||||
m = cgroups.NewNoopManager()
|
||||
}
|
||||
}()
|
||||
|
||||
metrics, err := host.GetMetrics()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to calculate host metrics: %v\n", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// try to keep 1/8 of the memory free, but no more than 128 MB
|
||||
maxMemoryReserved := uint64(float64(metrics.MemTotal) * .125)
|
||||
maxMemoryReserved = min(maxMemoryReserved, uint64(128)*megabyte)
|
||||
|
||||
opts := []cgroups.Cgroup2ManagerOption{
|
||||
cgroups.WithCgroup2ProcessType(cgroups.ProcessTypePTY, "ptys", map[string]string{
|
||||
"cpu.weight": "200", // gets much preferred cpu access, to help keep these real time
|
||||
}),
|
||||
cgroups.WithCgroup2ProcessType(cgroups.ProcessTypeSocat, "socats", map[string]string{
|
||||
"cpu.weight": "150", // gets slightly preferred cpu access
|
||||
"memory.min": fmt.Sprintf("%d", 5*megabyte),
|
||||
"memory.low": fmt.Sprintf("%d", 8*megabyte),
|
||||
}),
|
||||
cgroups.WithCgroup2ProcessType(cgroups.ProcessTypeUser, "user", map[string]string{
|
||||
"memory.high": fmt.Sprintf("%d", metrics.MemTotal-maxMemoryReserved),
|
||||
"cpu.weight": "50", // less than envd, and less than core processes that default to 100
|
||||
}),
|
||||
}
|
||||
if cgroupRoot != "" {
|
||||
opts = append(opts, cgroups.WithCgroup2RootSysFSPath(cgroupRoot))
|
||||
}
|
||||
|
||||
mgr, err := cgroups.NewCgroup2Manager(opts...)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to create cgroup2 manager: %v\n", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return mgr
|
||||
}
|
||||
|
||||
14
envd/spec/buf.gen.yaml
Normal file
14
envd/spec/buf.gen.yaml
Normal file
@ -0,0 +1,14 @@
|
||||
version: v1
|
||||
plugins:
|
||||
- plugin: go
|
||||
out: ../internal/services/spec
|
||||
opt: paths=source_relative
|
||||
- plugin: connect-go
|
||||
out: ../internal/services/spec
|
||||
opt: paths=source_relative
|
||||
|
||||
managed:
|
||||
enabled: true
|
||||
optimize_for: SPEED
|
||||
go_package_prefix:
|
||||
default: git.omukk.dev/wrenn/sandbox/envd/internal/services/spec
|
||||
303
envd/spec/envd.yaml
Normal file
303
envd/spec/envd.yaml
Normal file
@ -0,0 +1,303 @@
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
title: envd
|
||||
version: 0.1.1
|
||||
description: API for managing files' content and controlling envd
|
||||
|
||||
tags:
|
||||
- name: files
|
||||
|
||||
paths:
|
||||
/health:
|
||||
get:
|
||||
summary: Check the health of the service
|
||||
responses:
|
||||
"204":
|
||||
description: The service is healthy
|
||||
|
||||
/metrics:
|
||||
get:
|
||||
summary: Get the stats of the service
|
||||
security:
|
||||
- AccessTokenAuth: []
|
||||
- {}
|
||||
responses:
|
||||
"200":
|
||||
description: The resource usage metrics of the service
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Metrics"
|
||||
|
||||
/init:
|
||||
post:
|
||||
summary: Set initial vars, ensure the time and metadata is synced with the host
|
||||
security:
|
||||
- AccessTokenAuth: []
|
||||
- {}
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
volumeMounts:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/VolumeMount"
|
||||
hyperloopIP:
|
||||
type: string
|
||||
description: IP address of the hyperloop server to connect to
|
||||
envVars:
|
||||
$ref: "#/components/schemas/EnvVars"
|
||||
accessToken:
|
||||
type: string
|
||||
description: Access token for secure access to envd service
|
||||
x-go-type: SecureToken
|
||||
timestamp:
|
||||
type: string
|
||||
format: date-time
|
||||
description: The current timestamp in RFC3339 format
|
||||
defaultUser:
|
||||
type: string
|
||||
description: The default user to use for operations
|
||||
defaultWorkdir:
|
||||
type: string
|
||||
description: The default working directory to use for operations
|
||||
responses:
|
||||
"204":
|
||||
description: Env vars set, the time and metadata is synced with the host
|
||||
|
||||
/envs:
|
||||
get:
|
||||
summary: Get the environment variables
|
||||
security:
|
||||
- AccessTokenAuth: []
|
||||
- {}
|
||||
responses:
|
||||
"200":
|
||||
description: Environment variables
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/EnvVars"
|
||||
|
||||
/files:
|
||||
get:
|
||||
summary: Download a file
|
||||
tags: [files]
|
||||
security:
|
||||
- AccessTokenAuth: []
|
||||
- {}
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/FilePath"
|
||||
- $ref: "#/components/parameters/User"
|
||||
- $ref: "#/components/parameters/Signature"
|
||||
- $ref: "#/components/parameters/SignatureExpiration"
|
||||
responses:
|
||||
"200":
|
||||
$ref: "#/components/responses/DownloadSuccess"
|
||||
"401":
|
||||
$ref: "#/components/responses/InvalidUser"
|
||||
"400":
|
||||
$ref: "#/components/responses/InvalidPath"
|
||||
"404":
|
||||
$ref: "#/components/responses/FileNotFound"
|
||||
"500":
|
||||
$ref: "#/components/responses/InternalServerError"
|
||||
post:
|
||||
summary: Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
|
||||
tags: [files]
|
||||
security:
|
||||
- AccessTokenAuth: []
|
||||
- {}
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/FilePath"
|
||||
- $ref: "#/components/parameters/User"
|
||||
- $ref: "#/components/parameters/Signature"
|
||||
- $ref: "#/components/parameters/SignatureExpiration"
|
||||
requestBody:
|
||||
$ref: "#/components/requestBodies/File"
|
||||
responses:
|
||||
"200":
|
||||
$ref: "#/components/responses/UploadSuccess"
|
||||
"400":
|
||||
$ref: "#/components/responses/InvalidPath"
|
||||
"401":
|
||||
$ref: "#/components/responses/InvalidUser"
|
||||
"500":
|
||||
$ref: "#/components/responses/InternalServerError"
|
||||
"507":
|
||||
$ref: "#/components/responses/NotEnoughDiskSpace"
|
||||
|
||||
components:
|
||||
securitySchemes:
|
||||
AccessTokenAuth:
|
||||
type: apiKey
|
||||
in: header
|
||||
name: X-Access-Token
|
||||
|
||||
parameters:
|
||||
FilePath:
|
||||
name: path
|
||||
in: query
|
||||
required: false
|
||||
description: Path to the file, URL encoded. Can be relative to user's home directory.
|
||||
schema:
|
||||
type: string
|
||||
User:
|
||||
name: username
|
||||
in: query
|
||||
required: false
|
||||
description: User used for setting the owner, or resolving relative paths.
|
||||
schema:
|
||||
type: string
|
||||
Signature:
|
||||
name: signature
|
||||
in: query
|
||||
required: false
|
||||
description: Signature used for file access permission verification.
|
||||
schema:
|
||||
type: string
|
||||
SignatureExpiration:
|
||||
name: signature_expiration
|
||||
in: query
|
||||
required: false
|
||||
description: Signature expiration used for defining the expiration time of the signature.
|
||||
schema:
|
||||
type: integer
|
||||
|
||||
requestBodies:
|
||||
File:
|
||||
required: true
|
||||
content:
|
||||
multipart/form-data:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
file:
|
||||
type: string
|
||||
format: binary
|
||||
|
||||
responses:
|
||||
UploadSuccess:
|
||||
description: The file was uploaded successfully.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/EntryInfo"
|
||||
|
||||
DownloadSuccess:
|
||||
description: Entire file downloaded successfully.
|
||||
content:
|
||||
application/octet-stream:
|
||||
schema:
|
||||
type: string
|
||||
format: binary
|
||||
description: The file content
|
||||
InvalidPath:
|
||||
description: Invalid path
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
InternalServerError:
|
||||
description: Internal server error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
FileNotFound:
|
||||
description: File not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
InvalidUser:
|
||||
description: Invalid user
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
NotEnoughDiskSpace:
|
||||
description: Not enough disk space
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
schemas:
|
||||
Error:
|
||||
required:
|
||||
- message
|
||||
- code
|
||||
properties:
|
||||
message:
|
||||
type: string
|
||||
description: Error message
|
||||
code:
|
||||
type: integer
|
||||
description: Error code
|
||||
EntryInfo:
|
||||
required:
|
||||
- path
|
||||
- name
|
||||
- type
|
||||
properties:
|
||||
path:
|
||||
type: string
|
||||
description: Path to the file
|
||||
name:
|
||||
type: string
|
||||
description: Name of the file
|
||||
type:
|
||||
type: string
|
||||
description: Type of the file
|
||||
enum:
|
||||
- file
|
||||
EnvVars:
|
||||
type: object
|
||||
description: Environment variables to set
|
||||
additionalProperties:
|
||||
type: string
|
||||
Metrics:
|
||||
type: object
|
||||
description: Resource usage metrics
|
||||
properties:
|
||||
ts:
|
||||
type: integer
|
||||
format: int64
|
||||
description: Unix timestamp in UTC for current sandbox time
|
||||
cpu_count:
|
||||
type: integer
|
||||
description: Number of CPU cores
|
||||
cpu_used_pct:
|
||||
type: number
|
||||
format: float
|
||||
description: CPU usage percentage
|
||||
mem_total:
|
||||
type: integer
|
||||
description: Total virtual memory in bytes
|
||||
mem_used:
|
||||
type: integer
|
||||
description: Used virtual memory in bytes
|
||||
disk_used:
|
||||
type: integer
|
||||
description: Used disk space in bytes
|
||||
disk_total:
|
||||
type: integer
|
||||
description: Total disk space in bytes
|
||||
VolumeMount:
|
||||
type: object
|
||||
description: Volume
|
||||
additionalProperties: false
|
||||
properties:
|
||||
nfs_target:
|
||||
type: string
|
||||
path:
|
||||
type: string
|
||||
required:
|
||||
- nfs_target
|
||||
- path
|
||||
135
envd/spec/filesystem/filesystem.proto
Normal file
135
envd/spec/filesystem/filesystem.proto
Normal file
@ -0,0 +1,135 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package filesystem;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
service Filesystem {
|
||||
rpc Stat(StatRequest) returns (StatResponse);
|
||||
rpc MakeDir(MakeDirRequest) returns (MakeDirResponse);
|
||||
rpc Move(MoveRequest) returns (MoveResponse);
|
||||
rpc ListDir(ListDirRequest) returns (ListDirResponse);
|
||||
rpc Remove(RemoveRequest) returns (RemoveResponse);
|
||||
|
||||
rpc WatchDir(WatchDirRequest) returns (stream WatchDirResponse);
|
||||
|
||||
// Non-streaming versions of WatchDir
|
||||
rpc CreateWatcher(CreateWatcherRequest) returns (CreateWatcherResponse);
|
||||
rpc GetWatcherEvents(GetWatcherEventsRequest) returns (GetWatcherEventsResponse);
|
||||
rpc RemoveWatcher(RemoveWatcherRequest) returns (RemoveWatcherResponse);
|
||||
}
|
||||
|
||||
message MoveRequest {
|
||||
string source = 1;
|
||||
string destination = 2;
|
||||
}
|
||||
|
||||
message MoveResponse {
|
||||
EntryInfo entry = 1;
|
||||
}
|
||||
|
||||
message MakeDirRequest {
|
||||
string path = 1;
|
||||
}
|
||||
|
||||
message MakeDirResponse {
|
||||
EntryInfo entry = 1;
|
||||
}
|
||||
|
||||
message RemoveRequest {
|
||||
string path = 1;
|
||||
}
|
||||
|
||||
message RemoveResponse {}
|
||||
|
||||
message StatRequest {
|
||||
string path = 1;
|
||||
}
|
||||
|
||||
message StatResponse {
|
||||
EntryInfo entry = 1;
|
||||
}
|
||||
|
||||
message EntryInfo {
|
||||
string name = 1;
|
||||
FileType type = 2;
|
||||
string path = 3;
|
||||
int64 size = 4;
|
||||
uint32 mode = 5;
|
||||
string permissions = 6;
|
||||
string owner = 7;
|
||||
string group = 8;
|
||||
google.protobuf.Timestamp modified_time = 9;
|
||||
// If the entry is a symlink, this field contains the target of the symlink.
|
||||
optional string symlink_target = 10;
|
||||
}
|
||||
|
||||
enum FileType {
|
||||
FILE_TYPE_UNSPECIFIED = 0;
|
||||
FILE_TYPE_FILE = 1;
|
||||
FILE_TYPE_DIRECTORY = 2;
|
||||
FILE_TYPE_SYMLINK = 3;
|
||||
}
|
||||
|
||||
message ListDirRequest {
|
||||
string path = 1;
|
||||
uint32 depth = 2;
|
||||
}
|
||||
|
||||
message ListDirResponse {
|
||||
repeated EntryInfo entries = 1;
|
||||
}
|
||||
|
||||
message WatchDirRequest {
|
||||
string path = 1;
|
||||
bool recursive = 2;
|
||||
}
|
||||
|
||||
message FilesystemEvent {
|
||||
string name = 1;
|
||||
EventType type = 2;
|
||||
}
|
||||
|
||||
message WatchDirResponse {
|
||||
oneof event {
|
||||
StartEvent start = 1;
|
||||
FilesystemEvent filesystem = 2;
|
||||
KeepAlive keepalive = 3;
|
||||
}
|
||||
|
||||
message StartEvent {}
|
||||
|
||||
message KeepAlive {}
|
||||
}
|
||||
|
||||
message CreateWatcherRequest {
|
||||
string path = 1;
|
||||
bool recursive = 2;
|
||||
}
|
||||
|
||||
message CreateWatcherResponse {
|
||||
string watcher_id = 1;
|
||||
}
|
||||
|
||||
message GetWatcherEventsRequest {
|
||||
string watcher_id = 1;
|
||||
}
|
||||
|
||||
message GetWatcherEventsResponse {
|
||||
repeated FilesystemEvent events = 1;
|
||||
}
|
||||
|
||||
message RemoveWatcherRequest {
|
||||
string watcher_id = 1;
|
||||
}
|
||||
|
||||
message RemoveWatcherResponse {}
|
||||
|
||||
enum EventType {
|
||||
EVENT_TYPE_UNSPECIFIED = 0;
|
||||
EVENT_TYPE_CREATE = 1;
|
||||
EVENT_TYPE_WRITE = 2;
|
||||
EVENT_TYPE_REMOVE = 3;
|
||||
EVENT_TYPE_RENAME = 4;
|
||||
EVENT_TYPE_CHMOD = 5;
|
||||
}
|
||||
3
envd/spec/generate.go
Normal file
3
envd/spec/generate.go
Normal file
@ -0,0 +1,3 @@
|
||||
package spec
|
||||
|
||||
//go:generate buf generate --template buf.gen.yaml
|
||||
171
envd/spec/process/process.proto
Normal file
171
envd/spec/process/process.proto
Normal file
@ -0,0 +1,171 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package process;
|
||||
|
||||
service Process {
|
||||
rpc List(ListRequest) returns (ListResponse);
|
||||
|
||||
rpc Connect(ConnectRequest) returns (stream ConnectResponse);
|
||||
rpc Start(StartRequest) returns (stream StartResponse);
|
||||
|
||||
rpc Update(UpdateRequest) returns (UpdateResponse);
|
||||
|
||||
// Client input stream ensures ordering of messages
|
||||
rpc StreamInput(stream StreamInputRequest) returns (StreamInputResponse);
|
||||
rpc SendInput(SendInputRequest) returns (SendInputResponse);
|
||||
rpc SendSignal(SendSignalRequest) returns (SendSignalResponse);
|
||||
|
||||
// Close stdin to signal EOF to the process.
|
||||
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
|
||||
rpc CloseStdin(CloseStdinRequest) returns (CloseStdinResponse);
|
||||
}
|
||||
|
||||
message PTY {
|
||||
Size size = 1;
|
||||
|
||||
message Size {
|
||||
uint32 cols = 1;
|
||||
uint32 rows = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message ProcessConfig {
|
||||
string cmd = 1;
|
||||
repeated string args = 2;
|
||||
|
||||
map<string, string> envs = 3;
|
||||
optional string cwd = 4;
|
||||
}
|
||||
|
||||
message ListRequest {}
|
||||
|
||||
message ProcessInfo {
|
||||
ProcessConfig config = 1;
|
||||
uint32 pid = 2;
|
||||
optional string tag = 3;
|
||||
}
|
||||
|
||||
message ListResponse {
|
||||
repeated ProcessInfo processes = 1;
|
||||
}
|
||||
|
||||
message StartRequest {
|
||||
ProcessConfig process = 1;
|
||||
optional PTY pty = 2;
|
||||
optional string tag = 3;
|
||||
// This is optional for backwards compatibility.
|
||||
// We default to true. New SDK versions will set this to false by default.
|
||||
optional bool stdin = 4;
|
||||
}
|
||||
|
||||
message UpdateRequest {
|
||||
ProcessSelector process = 1;
|
||||
|
||||
optional PTY pty = 2;
|
||||
}
|
||||
|
||||
message UpdateResponse {}
|
||||
|
||||
message ProcessEvent {
|
||||
oneof event {
|
||||
StartEvent start = 1;
|
||||
DataEvent data = 2;
|
||||
EndEvent end = 3;
|
||||
KeepAlive keepalive = 4;
|
||||
}
|
||||
|
||||
message StartEvent {
|
||||
uint32 pid = 1;
|
||||
}
|
||||
|
||||
message DataEvent {
|
||||
oneof output {
|
||||
bytes stdout = 1;
|
||||
bytes stderr = 2;
|
||||
bytes pty = 3;
|
||||
}
|
||||
}
|
||||
|
||||
message EndEvent {
|
||||
sint32 exit_code = 1;
|
||||
bool exited = 2;
|
||||
string status = 3;
|
||||
optional string error = 4;
|
||||
}
|
||||
|
||||
message KeepAlive {}
|
||||
}
|
||||
|
||||
message StartResponse {
|
||||
ProcessEvent event = 1;
|
||||
}
|
||||
|
||||
message ConnectResponse {
|
||||
ProcessEvent event = 1;
|
||||
}
|
||||
|
||||
message SendInputRequest {
|
||||
ProcessSelector process = 1;
|
||||
|
||||
ProcessInput input = 2;
|
||||
}
|
||||
|
||||
message SendInputResponse {}
|
||||
|
||||
message ProcessInput {
|
||||
oneof input {
|
||||
bytes stdin = 1;
|
||||
bytes pty = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message StreamInputRequest {
|
||||
oneof event {
|
||||
StartEvent start = 1;
|
||||
DataEvent data = 2;
|
||||
KeepAlive keepalive = 3;
|
||||
}
|
||||
|
||||
message StartEvent {
|
||||
ProcessSelector process = 1;
|
||||
}
|
||||
|
||||
message DataEvent {
|
||||
ProcessInput input = 2;
|
||||
}
|
||||
|
||||
message KeepAlive {}
|
||||
}
|
||||
|
||||
message StreamInputResponse {}
|
||||
|
||||
enum Signal {
|
||||
SIGNAL_UNSPECIFIED = 0;
|
||||
SIGNAL_SIGTERM = 15;
|
||||
SIGNAL_SIGKILL = 9;
|
||||
}
|
||||
|
||||
message SendSignalRequest {
|
||||
ProcessSelector process = 1;
|
||||
|
||||
Signal signal = 2;
|
||||
}
|
||||
|
||||
message SendSignalResponse {}
|
||||
|
||||
message CloseStdinRequest {
|
||||
ProcessSelector process = 1;
|
||||
}
|
||||
|
||||
message CloseStdinResponse {}
|
||||
|
||||
message ConnectRequest {
|
||||
ProcessSelector process = 1;
|
||||
}
|
||||
|
||||
message ProcessSelector {
|
||||
oneof selector {
|
||||
uint32 pid = 1;
|
||||
string tag = 2;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user