1
0
forked from wrenn/wrenn

Prototype with single host server and no admin panel (#2)

Reviewed-on: wrenn/sandbox#2
Co-authored-by: pptx704 <rafeed@omukk.dev>
Co-committed-by: pptx704 <rafeed@omukk.dev>
This commit is contained in:
2026-03-22 21:01:23 +00:00
committed by Rafeed M. Bhuiyan
parent bd78cc068c
commit 32e5a5a715
293 changed files with 46885 additions and 1033 deletions

201
envd/LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2023 FoundryLabs, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -1,17 +1,62 @@
LDFLAGS := -s -w
BUILD := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
LDFLAGS := -s -w -X=main.commitSHA=$(BUILD)
BUILDS := ../builds
.PHONY: build clean fmt vet
# ═══════════════════════════════════════════════════
# Build
# ═══════════════════════════════════════════════════
.PHONY: build build-debug
build:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o envd .
@file envd | grep -q "statically linked" || \
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o $(BUILDS)/envd .
@file $(BUILDS)/envd | grep -q "statically linked" || \
(echo "ERROR: envd is not statically linked!" && exit 1)
clean:
rm -f envd
build-debug:
CGO_ENABLED=1 go build -race -gcflags=all="-N -l" -ldflags="-X=main.commitSHA=$(BUILD)" -o $(BUILDS)/debug/envd .
# ═══════════════════════════════════════════════════
# Run (debug mode, not inside a VM)
# ═══════════════════════════════════════════════════
.PHONY: run-debug
run-debug: build-debug
$(BUILDS)/debug/envd -isnotfc -port 49983
# ═══════════════════════════════════════════════════
# Code Generation
# ═══════════════════════════════════════════════════
.PHONY: generate proto openapi
generate: proto openapi
proto:
cd spec && buf generate --template buf.gen.yaml
openapi:
go generate ./internal/api/...
# ═══════════════════════════════════════════════════
# Quality
# ═══════════════════════════════════════════════════
.PHONY: fmt vet test tidy
fmt:
gofmt -w .
vet:
go vet ./...
test:
go test -race -v ./...
tidy:
go mod tidy
# ═══════════════════════════════════════════════════
# Clean
# ═══════════════════════════════════════════════════
.PHONY: clean
clean:
rm -f $(BUILDS)/envd $(BUILDS)/debug/envd

View File

@ -1,9 +1,42 @@
module github.com/wrenn-dev/envd
module git.omukk.dev/wrenn/sandbox/envd
go 1.23.0
go 1.25.5
require (
github.com/mdlayher/vsock v1.2.1
google.golang.org/grpc v1.71.0
google.golang.org/protobuf v1.36.5
connectrpc.com/authn v0.1.0
connectrpc.com/connect v1.19.1
connectrpc.com/cors v0.1.0
github.com/awnumar/memguard v0.23.0
github.com/creack/pty v1.1.24
github.com/dchest/uniuri v1.2.0
github.com/e2b-dev/fsnotify v0.0.1
github.com/go-chi/chi/v5 v5.2.5
github.com/google/uuid v1.6.0
github.com/oapi-codegen/runtime v1.2.0
github.com/orcaman/concurrent-map/v2 v2.0.1
github.com/rs/cors v1.11.1
github.com/rs/zerolog v1.34.0
github.com/shirou/gopsutil/v4 v4.26.2
github.com/stretchr/testify v1.11.1
github.com/txn2/txeh v1.8.0
golang.org/x/sys v0.42.0
google.golang.org/protobuf v1.36.11
)
require (
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/awnumar/memcall v0.4.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/ebitengine/purego v0.10.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/tklauser/go-sysconf v0.3.16 // indirect
github.com/tklauser/numcpus v0.11.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
golang.org/x/crypto v0.41.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

92
envd/go.sum Normal file
View File

@ -0,0 +1,92 @@
connectrpc.com/authn v0.1.0 h1:m5weACjLWwgwcjttvUDyTPICJKw74+p2obBVrf8hT9E=
connectrpc.com/authn v0.1.0/go.mod h1:AwNZK/KYbqaJzRYadTuAaoz6sYQSPdORPqh1TOPIkgY=
connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
connectrpc.com/cors v0.1.0 h1:f3gTXJyDZPrDIZCQ567jxfD9PAIpopHiRDnJRt3QuOQ=
connectrpc.com/cors v0.1.0/go.mod h1:v8SJZCPfHtGH1zsm+Ttajpozd4cYIUryl4dFB6QEpfg=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g=
github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY=
github.com/e2b-dev/fsnotify v0.0.1 h1:7j0I98HD6VehAuK/bcslvW4QDynAULtOuMZtImihjVk=
github.com/e2b-dev/fsnotify v0.0.1/go.mod h1:jAuDjregRrUixKneTRQwPI847nNuPFg3+n5QM/ku/JM=
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/oapi-codegen/runtime v1.2.0 h1:RvKc1CVS1QeKSNzO97FBQbSMZyQ8s6rZd+LpmzwHMP4=
github.com/oapi-codegen/runtime v1.2.0/go.mod h1:Y7ZhmmlE8ikZOmuHRRndiIm7nf3xcVv+YMweKgG1DT0=
github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c=
github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/shirou/gopsutil/v4 v4.26.2 h1:X8i6sicvUFih4BmYIGT1m2wwgw2VG9YgrDTi7cIRGUI=
github.com/shirou/gopsutil/v4 v4.26.2/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
github.com/txn2/txeh v1.8.0 h1:G1vZgom6+P/xWwU53AMOpcZgC5ni382ukcPP1TDVYHk=
github.com/txn2/txeh v1.8.0/go.mod h1:rRI3Egi3+AFmEXQjft051YdYbxeCT3nFmBLsNCZZaxM=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=

View File

@ -0,0 +1,568 @@
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT.
package api
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/oapi-codegen/runtime"
openapi_types "github.com/oapi-codegen/runtime/types"
)
const (
AccessTokenAuthScopes = "AccessTokenAuth.Scopes"
)
// Defines values for EntryInfoType.
const (
File EntryInfoType = "file"
)
// EntryInfo defines model for EntryInfo.
type EntryInfo struct {
// Name Name of the file
Name string `json:"name"`
// Path Path to the file
Path string `json:"path"`
// Type Type of the file
Type EntryInfoType `json:"type"`
}
// EntryInfoType Type of the file
type EntryInfoType string
// EnvVars Environment variables to set
type EnvVars map[string]string
// Error defines model for Error.
type Error struct {
// Code Error code
Code int `json:"code"`
// Message Error message
Message string `json:"message"`
}
// Metrics Resource usage metrics
type Metrics struct {
// CpuCount Number of CPU cores
CpuCount *int `json:"cpu_count,omitempty"`
// CpuUsedPct CPU usage percentage
CpuUsedPct *float32 `json:"cpu_used_pct,omitempty"`
// DiskTotal Total disk space in bytes
DiskTotal *int `json:"disk_total,omitempty"`
// DiskUsed Used disk space in bytes
DiskUsed *int `json:"disk_used,omitempty"`
// MemTotal Total virtual memory in bytes
MemTotal *int `json:"mem_total,omitempty"`
// MemUsed Used virtual memory in bytes
MemUsed *int `json:"mem_used,omitempty"`
// Ts Unix timestamp in UTC for current sandbox time
Ts *int64 `json:"ts,omitempty"`
}
// VolumeMount Volume
type VolumeMount struct {
NfsTarget string `json:"nfs_target"`
Path string `json:"path"`
}
// FilePath defines model for FilePath.
type FilePath = string
// Signature defines model for Signature.
type Signature = string
// SignatureExpiration defines model for SignatureExpiration.
type SignatureExpiration = int
// User defines model for User.
type User = string
// FileNotFound defines model for FileNotFound.
type FileNotFound = Error
// InternalServerError defines model for InternalServerError.
type InternalServerError = Error
// InvalidPath defines model for InvalidPath.
type InvalidPath = Error
// InvalidUser defines model for InvalidUser.
type InvalidUser = Error
// NotEnoughDiskSpace defines model for NotEnoughDiskSpace.
type NotEnoughDiskSpace = Error
// UploadSuccess defines model for UploadSuccess.
type UploadSuccess = []EntryInfo
// GetFilesParams defines parameters for GetFiles.
type GetFilesParams struct {
// Path Path to the file, URL encoded. Can be relative to user's home directory.
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
// Username User used for setting the owner, or resolving relative paths.
Username *User `form:"username,omitempty" json:"username,omitempty"`
// Signature Signature used for file access permission verification.
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
}
// PostFilesMultipartBody defines parameters for PostFiles.
type PostFilesMultipartBody struct {
File *openapi_types.File `json:"file,omitempty"`
}
// PostFilesParams defines parameters for PostFiles.
type PostFilesParams struct {
// Path Path to the file, URL encoded. Can be relative to user's home directory.
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
// Username User used for setting the owner, or resolving relative paths.
Username *User `form:"username,omitempty" json:"username,omitempty"`
// Signature Signature used for file access permission verification.
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
}
// PostInitJSONBody defines parameters for PostInit.
type PostInitJSONBody struct {
// AccessToken Access token for secure access to envd service
AccessToken *SecureToken `json:"accessToken,omitempty"`
// DefaultUser The default user to use for operations
DefaultUser *string `json:"defaultUser,omitempty"`
// DefaultWorkdir The default working directory to use for operations
DefaultWorkdir *string `json:"defaultWorkdir,omitempty"`
// EnvVars Environment variables to set
EnvVars *EnvVars `json:"envVars,omitempty"`
// HyperloopIP IP address of the hyperloop server to connect to
HyperloopIP *string `json:"hyperloopIP,omitempty"`
// Timestamp The current timestamp in RFC3339 format
Timestamp *time.Time `json:"timestamp,omitempty"`
VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"`
}
// PostFilesMultipartRequestBody defines body for PostFiles for multipart/form-data ContentType.
type PostFilesMultipartRequestBody PostFilesMultipartBody
// PostInitJSONRequestBody defines body for PostInit for application/json ContentType.
type PostInitJSONRequestBody PostInitJSONBody
// ServerInterface represents all server handlers.
type ServerInterface interface {
// Get the environment variables
// (GET /envs)
GetEnvs(w http.ResponseWriter, r *http.Request)
// Download a file
// (GET /files)
GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams)
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
// (POST /files)
PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams)
// Check the health of the service
// (GET /health)
GetHealth(w http.ResponseWriter, r *http.Request)
// Set initial vars, ensure the time and metadata is synced with the host
// (POST /init)
PostInit(w http.ResponseWriter, r *http.Request)
// Get the stats of the service
// (GET /metrics)
GetMetrics(w http.ResponseWriter, r *http.Request)
}
// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
type Unimplemented struct{}
// Get the environment variables
// (GET /envs)
func (_ Unimplemented) GetEnvs(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// Download a file
// (GET /files)
func (_ Unimplemented) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
w.WriteHeader(http.StatusNotImplemented)
}
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
// (POST /files)
func (_ Unimplemented) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
w.WriteHeader(http.StatusNotImplemented)
}
// Check the health of the service
// (GET /health)
func (_ Unimplemented) GetHealth(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// Set initial vars, ensure the time and metadata is synced with the host
// (POST /init)
func (_ Unimplemented) PostInit(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// Get the stats of the service
// (GET /metrics)
func (_ Unimplemented) GetMetrics(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// ServerInterfaceWrapper converts contexts to parameters.
type ServerInterfaceWrapper struct {
Handler ServerInterface
HandlerMiddlewares []MiddlewareFunc
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}
type MiddlewareFunc func(http.Handler) http.Handler
// GetEnvs operation middleware
func (siw *ServerInterfaceWrapper) GetEnvs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.GetEnvs(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// GetFiles operation middleware
func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Request) {
var err error
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
// Parameter object where we will unmarshal all parameters from the context
var params GetFilesParams
// ------------- Optional query parameter "path" -------------
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), &params.Path)
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
return
}
// ------------- Optional query parameter "username" -------------
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), &params.Username)
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
return
}
// ------------- Optional query parameter "signature" -------------
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), &params.Signature)
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
return
}
// ------------- Optional query parameter "signature_expiration" -------------
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), &params.SignatureExpiration)
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
return
}
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.GetFiles(w, r, params)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// PostFiles operation middleware
func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Request) {
var err error
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
// Parameter object where we will unmarshal all parameters from the context
var params PostFilesParams
// ------------- Optional query parameter "path" -------------
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), &params.Path)
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
return
}
// ------------- Optional query parameter "username" -------------
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), &params.Username)
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
return
}
// ------------- Optional query parameter "signature" -------------
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), &params.Signature)
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
return
}
// ------------- Optional query parameter "signature_expiration" -------------
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), &params.SignatureExpiration)
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
return
}
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.PostFiles(w, r, params)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// GetHealth operation middleware
func (siw *ServerInterfaceWrapper) GetHealth(w http.ResponseWriter, r *http.Request) {
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.GetHealth(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// PostInit operation middleware
func (siw *ServerInterfaceWrapper) PostInit(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.PostInit(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// GetMetrics operation middleware
func (siw *ServerInterfaceWrapper) GetMetrics(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.GetMetrics(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
type UnescapedCookieParamError struct {
ParamName string
Err error
}
func (e *UnescapedCookieParamError) Error() string {
return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
}
func (e *UnescapedCookieParamError) Unwrap() error {
return e.Err
}
type UnmarshalingParamError struct {
ParamName string
Err error
}
func (e *UnmarshalingParamError) Error() string {
return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
}
func (e *UnmarshalingParamError) Unwrap() error {
return e.Err
}
type RequiredParamError struct {
ParamName string
}
func (e *RequiredParamError) Error() string {
return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
}
type RequiredHeaderError struct {
ParamName string
Err error
}
func (e *RequiredHeaderError) Error() string {
return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
}
func (e *RequiredHeaderError) Unwrap() error {
return e.Err
}
type InvalidParamFormatError struct {
ParamName string
Err error
}
func (e *InvalidParamFormatError) Error() string {
return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
}
func (e *InvalidParamFormatError) Unwrap() error {
return e.Err
}
type TooManyValuesForParamError struct {
ParamName string
Count int
}
func (e *TooManyValuesForParamError) Error() string {
return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
}
// Handler creates http.Handler with routing matching OpenAPI spec.
func Handler(si ServerInterface) http.Handler {
return HandlerWithOptions(si, ChiServerOptions{})
}
type ChiServerOptions struct {
BaseURL string
BaseRouter chi.Router
Middlewares []MiddlewareFunc
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}
// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
return HandlerWithOptions(si, ChiServerOptions{
BaseRouter: r,
})
}
func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
return HandlerWithOptions(si, ChiServerOptions{
BaseURL: baseURL,
BaseRouter: r,
})
}
// HandlerWithOptions creates http.Handler with additional options
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
r := options.BaseRouter
if r == nil {
r = chi.NewRouter()
}
if options.ErrorHandlerFunc == nil {
options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, err.Error(), http.StatusBadRequest)
}
}
wrapper := ServerInterfaceWrapper{
Handler: si,
HandlerMiddlewares: options.Middlewares,
ErrorHandlerFunc: options.ErrorHandlerFunc,
}
r.Group(func(r chi.Router) {
r.Get(options.BaseURL+"/envs", wrapper.GetEnvs)
})
r.Group(func(r chi.Router) {
r.Get(options.BaseURL+"/files", wrapper.GetFiles)
})
r.Group(func(r chi.Router) {
r.Post(options.BaseURL+"/files", wrapper.PostFiles)
})
r.Group(func(r chi.Router) {
r.Get(options.BaseURL+"/health", wrapper.GetHealth)
})
r.Group(func(r chi.Router) {
r.Post(options.BaseURL+"/init", wrapper.PostInit)
})
r.Group(func(r chi.Router) {
r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
})
return r
}

131
envd/internal/api/auth.go Normal file
View File

@ -0,0 +1,131 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"errors"
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/awnumar/memguard"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
)
const (
SigningReadOperation = "read"
SigningWriteOperation = "write"
accessTokenHeader = "X-Access-Token"
)
// paths that are always allowed without general authentication
// POST/init is secured via MMDS hash validation instead
var authExcludedPaths = []string{
"GET/health",
"GET/files",
"POST/files",
"POST/init",
}
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if a.accessToken.IsSet() {
authHeader := req.Header.Get(accessTokenHeader)
// check if this path is allowed without authentication (e.g., health check, endpoints supporting signing)
allowedPath := slices.Contains(authExcludedPaths, req.Method+req.URL.Path)
if !a.accessToken.Equals(authHeader) && !allowedPath {
a.logger.Error().Msg("Trying to access secured envd without correct access token")
err := fmt.Errorf("unauthorized access, please provide a valid access token or method signing if supported")
jsonError(w, http.StatusUnauthorized, err)
return
}
}
handler.ServeHTTP(w, req)
})
}
func (a *API) generateSignature(path string, username string, operation string, signatureExpiration *int64) (string, error) {
tokenBytes, err := a.accessToken.Bytes()
if err != nil {
return "", fmt.Errorf("access token is not set: %w", err)
}
defer memguard.WipeBytes(tokenBytes)
var signature string
hasher := keys.NewSHA256Hashing()
if signatureExpiration == nil {
signature = strings.Join([]string{path, operation, username, string(tokenBytes)}, ":")
} else {
signature = strings.Join([]string{path, operation, username, string(tokenBytes), strconv.FormatInt(*signatureExpiration, 10)}, ":")
}
return fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(signature))), nil
}
func (a *API) validateSigning(r *http.Request, signature *string, signatureExpiration *int, username *string, path string, operation string) (err error) {
var expectedSignature string
// no need to validate signing key if access token is not set
if !a.accessToken.IsSet() {
return nil
}
// check if access token is sent in the header
tokenFromHeader := r.Header.Get(accessTokenHeader)
if tokenFromHeader != "" {
if !a.accessToken.Equals(tokenFromHeader) {
return fmt.Errorf("access token present in header but does not match")
}
return nil
}
if signature == nil {
return fmt.Errorf("missing signature query parameter")
}
// Empty string is used when no username is provided and the default user should be used
signatureUsername := ""
if username != nil {
signatureUsername = *username
}
if signatureExpiration == nil {
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, nil)
} else {
exp := int64(*signatureExpiration)
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, &exp)
}
if err != nil {
a.logger.Error().Err(err).Msg("error generating signing key")
return errors.New("invalid signature")
}
// signature validation
if expectedSignature != *signature {
return fmt.Errorf("invalid signature")
}
// signature expiration
if signatureExpiration != nil {
exp := int64(*signatureExpiration)
if exp < time.Now().Unix() {
return fmt.Errorf("signature is already expired")
}
}
return nil
}

View File

@ -0,0 +1,64 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"fmt"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
)
func TestKeyGenerationAlgorithmIsStable(t *testing.T) {
t.Parallel()
apiToken := "secret-access-token"
secureToken := &SecureToken{}
err := secureToken.Set([]byte(apiToken))
require.NoError(t, err)
api := &API{accessToken: secureToken}
path := "/path/to/demo.txt"
username := "root"
operation := "write"
timestamp := time.Now().Unix()
signature, err := api.generateSignature(path, username, operation, &timestamp)
require.NoError(t, err)
assert.NotEmpty(t, signature)
// locally generated signature
hasher := keys.NewSHA256Hashing()
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s:%s", path, operation, username, apiToken, strconv.FormatInt(timestamp, 10))
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
assert.Equal(t, localSignature, signature)
}
func TestKeyGenerationAlgorithmWithoutExpirationIsStable(t *testing.T) {
t.Parallel()
apiToken := "secret-access-token"
secureToken := &SecureToken{}
err := secureToken.Set([]byte(apiToken))
require.NoError(t, err)
api := &API{accessToken: secureToken}
path := "/path/to/resource.txt"
username := "user"
operation := "read"
signature, err := api.generateSignature(path, username, operation, nil)
require.NoError(t, err)
assert.NotEmpty(t, signature)
// locally generated signature
hasher := keys.NewSHA256Hashing()
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s", path, operation, username, apiToken)
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
assert.Equal(t, localSignature, signature)
}

View File

@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# yaml-language-server: $schema=https://raw.githubusercontent.com/deepmap/oapi-codegen/HEAD/configuration-schema.json
package: api
output: api.gen.go
generate:
models: true
chi-server: true
client: false

View File

@ -0,0 +1,175 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"compress/gzip"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"os/user"
"path/filepath"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
)
func (a *API) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
defer r.Body.Close()
var errorCode int
var errMsg error
var path string
if params.Path != nil {
path = *params.Path
}
operationID := logs.AssignOperationID()
// signing authorization if needed
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningReadOperation)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
jsonError(w, http.StatusUnauthorized, err)
return
}
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
jsonError(w, http.StatusBadRequest, err)
return
}
defer func() {
l := a.logger.
Err(errMsg).
Str("method", r.Method+" "+r.URL.Path).
Str(string(logs.OperationIDKey), operationID).
Str("path", path).
Str("username", username)
if errMsg != nil {
l = l.Int("error_code", errorCode)
}
l.Msg("File read")
}()
u, err := user.Lookup(username)
if err != nil {
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
errorCode = http.StatusUnauthorized
jsonError(w, errorCode, errMsg)
return
}
resolvedPath, err := permissions.ExpandAndResolve(path, u, a.defaults.Workdir)
if err != nil {
errMsg = fmt.Errorf("error expanding and resolving path '%s': %w", path, err)
errorCode = http.StatusBadRequest
jsonError(w, errorCode, errMsg)
return
}
stat, err := os.Stat(resolvedPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
errMsg = fmt.Errorf("path '%s' does not exist", resolvedPath)
errorCode = http.StatusNotFound
jsonError(w, errorCode, errMsg)
return
}
errMsg = fmt.Errorf("error checking if path exists '%s': %w", resolvedPath, err)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
return
}
if stat.IsDir() {
errMsg = fmt.Errorf("path '%s' is a directory", resolvedPath)
errorCode = http.StatusBadRequest
jsonError(w, errorCode, errMsg)
return
}
// Validate Accept-Encoding header
encoding, err := parseAcceptEncoding(r)
if err != nil {
errMsg = fmt.Errorf("error parsing Accept-Encoding: %w", err)
errorCode = http.StatusNotAcceptable
jsonError(w, errorCode, errMsg)
return
}
// Tell caches to store separate variants for different Accept-Encoding values
w.Header().Set("Vary", "Accept-Encoding")
// Fall back to identity for Range or conditional requests to preserve http.ServeContent
// behavior (206 Partial Content, 304 Not Modified). However, we must check if identity
// is acceptable per the Accept-Encoding header.
hasRangeOrConditional := r.Header.Get("Range") != "" ||
r.Header.Get("If-Modified-Since") != "" ||
r.Header.Get("If-None-Match") != "" ||
r.Header.Get("If-Range") != ""
if hasRangeOrConditional {
if !isIdentityAcceptable(r) {
errMsg = fmt.Errorf("identity encoding not acceptable for Range or conditional request")
errorCode = http.StatusNotAcceptable
jsonError(w, errorCode, errMsg)
return
}
encoding = EncodingIdentity
}
file, err := os.Open(resolvedPath)
if err != nil {
errMsg = fmt.Errorf("error opening file '%s': %w", resolvedPath, err)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
return
}
defer file.Close()
w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": filepath.Base(resolvedPath)}))
// Serve with gzip encoding if requested.
if encoding == EncodingGzip {
w.Header().Set("Content-Encoding", EncodingGzip)
// Set Content-Type based on file extension, preserving the original type
contentType := mime.TypeByExtension(filepath.Ext(path))
if contentType == "" {
contentType = "application/octet-stream"
}
w.Header().Set("Content-Type", contentType)
gw := gzip.NewWriter(w)
defer gw.Close()
_, err = io.Copy(gw, file)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error writing gzip response")
}
return
}
http.ServeContent(w, r, path, stat.ModTime(), file)
}

View File

@ -0,0 +1,403 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"bytes"
"compress/gzip"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"os/user"
"path/filepath"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
func TestGetFilesContentDisposition(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
tests := []struct {
name string
filename string
expectedHeader string
}{
{
name: "simple filename",
filename: "test.txt",
expectedHeader: `inline; filename=test.txt`,
},
{
name: "filename with extension",
filename: "presentation.pptx",
expectedHeader: `inline; filename=presentation.pptx`,
},
{
name: "filename with multiple dots",
filename: "archive.tar.gz",
expectedHeader: `inline; filename=archive.tar.gz`,
},
{
name: "filename with spaces",
filename: "my document.pdf",
expectedHeader: `inline; filename="my document.pdf"`,
},
{
name: "filename with quotes",
filename: `file"name.txt`,
expectedHeader: `inline; filename="file\"name.txt"`,
},
{
name: "filename with backslash",
filename: `file\name.txt`,
expectedHeader: `inline; filename="file\\name.txt"`,
},
{
name: "unicode filename",
filename: "\u6587\u6863.pdf", // 文档.pdf in Chinese
expectedHeader: "inline; filename*=utf-8''%E6%96%87%E6%A1%A3.pdf",
},
{
name: "dotfile preserved",
filename: ".env",
expectedHeader: `inline; filename=.env`,
},
{
name: "dotfile with extension preserved",
filename: ".gitignore",
expectedHeader: `inline; filename=.gitignore`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Create a temp directory and file
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, tt.filename)
err := os.WriteFile(tempFile, []byte("test content"), 0o644)
require.NoError(t, err)
// Create test API
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false)
// Create request and response recorder
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
w := httptest.NewRecorder()
// Call the handler
params := GetFilesParams{
Path: &tempFile,
Username: &currentUser.Username,
}
api.GetFiles(w, req, params)
// Check response
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify Content-Disposition header
contentDisposition := resp.Header.Get("Content-Disposition")
assert.Equal(t, tt.expectedHeader, contentDisposition, "Content-Disposition header should be set with correct filename")
})
}
}
func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
// Create a temp directory with nested structure
tempDir := t.TempDir()
nestedDir := filepath.Join(tempDir, "subdir", "another")
err = os.MkdirAll(nestedDir, 0o755)
require.NoError(t, err)
filename := "document.pdf"
tempFile := filepath.Join(nestedDir, filename)
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
require.NoError(t, err)
// Create test API
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false)
// Create request and response recorder
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
w := httptest.NewRecorder()
// Call the handler
params := GetFilesParams{
Path: &tempFile,
Username: &currentUser.Username,
}
api.GetFiles(w, req, params)
// Check response
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify Content-Disposition header uses only the base filename, not the full path
contentDisposition := resp.Header.Get("Content-Disposition")
assert.Equal(t, `inline; filename=document.pdf`, contentDisposition, "Content-Disposition should contain only the filename, not the path")
}
func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
// Create a temp directory with a test file
tempDir := t.TempDir()
filename := "document.pdf"
tempFile := filepath.Join(tempDir, filename)
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
require.NoError(t, err)
// Create test API
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false)
// Create request and response recorder
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
req.Header.Set("Accept-Encoding", "gzip; q=1,*; q=0")
req.Header.Set("Range", "bytes=0-4") // Request first 5 bytes
w := httptest.NewRecorder()
// Call the handler
params := GetFilesParams{
Path: &tempFile,
Username: &currentUser.Username,
}
api.GetFiles(w, req, params)
// Check response
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode)
}
func TestGetFiles_GzipDownload(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
originalContent := []byte("hello world, this is a test file for gzip compression")
// Create a temp file with known content
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "test.txt")
err = os.WriteFile(tempFile, originalContent, 0o644)
require.NoError(t, err)
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false)
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
req.Header.Set("Accept-Encoding", "gzip")
w := httptest.NewRecorder()
params := GetFilesParams{
Path: &tempFile,
Username: &currentUser.Username,
}
api.GetFiles(w, req, params)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
// Decompress the gzip response body
gzReader, err := gzip.NewReader(resp.Body)
require.NoError(t, err)
defer gzReader.Close()
decompressed, err := io.ReadAll(gzReader)
require.NoError(t, err)
assert.Equal(t, originalContent, decompressed)
}
func TestPostFiles_GzipUpload(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
originalContent := []byte("hello world, this is a test file uploaded with gzip")
// Build a multipart body
var multipartBuf bytes.Buffer
mpWriter := multipart.NewWriter(&multipartBuf)
part, err := mpWriter.CreateFormFile("file", "uploaded.txt")
require.NoError(t, err)
_, err = part.Write(originalContent)
require.NoError(t, err)
err = mpWriter.Close()
require.NoError(t, err)
// Gzip-compress the entire multipart body
var gzBuf bytes.Buffer
gzWriter := gzip.NewWriter(&gzBuf)
_, err = gzWriter.Write(multipartBuf.Bytes())
require.NoError(t, err)
err = gzWriter.Close()
require.NoError(t, err)
// Create test API
tempDir := t.TempDir()
destPath := filepath.Join(tempDir, "uploaded.txt")
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false)
req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
req.Header.Set("Content-Type", mpWriter.FormDataContentType())
req.Header.Set("Content-Encoding", "gzip")
w := httptest.NewRecorder()
params := PostFilesParams{
Path: &destPath,
Username: &currentUser.Username,
}
api.PostFiles(w, req, params)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify the file was written with the original (decompressed) content
data, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, originalContent, data)
}
func TestGzipUploadThenGzipDownload(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
originalContent := []byte("round-trip gzip test: upload compressed, download compressed, verify match")
// --- Upload with gzip ---
// Build a multipart body
var multipartBuf bytes.Buffer
mpWriter := multipart.NewWriter(&multipartBuf)
part, err := mpWriter.CreateFormFile("file", "roundtrip.txt")
require.NoError(t, err)
_, err = part.Write(originalContent)
require.NoError(t, err)
err = mpWriter.Close()
require.NoError(t, err)
// Gzip-compress the entire multipart body
var gzBuf bytes.Buffer
gzWriter := gzip.NewWriter(&gzBuf)
_, err = gzWriter.Write(multipartBuf.Bytes())
require.NoError(t, err)
err = gzWriter.Close()
require.NoError(t, err)
tempDir := t.TempDir()
destPath := filepath.Join(tempDir, "roundtrip.txt")
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false)
uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
uploadReq.Header.Set("Content-Encoding", "gzip")
uploadW := httptest.NewRecorder()
uploadParams := PostFilesParams{
Path: &destPath,
Username: &currentUser.Username,
}
api.PostFiles(uploadW, uploadReq, uploadParams)
uploadResp := uploadW.Result()
defer uploadResp.Body.Close()
require.Equal(t, http.StatusOK, uploadResp.StatusCode)
// --- Download with gzip ---
downloadReq := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(destPath), nil)
downloadReq.Header.Set("Accept-Encoding", "gzip")
downloadW := httptest.NewRecorder()
downloadParams := GetFilesParams{
Path: &destPath,
Username: &currentUser.Username,
}
api.GetFiles(downloadW, downloadReq, downloadParams)
downloadResp := downloadW.Result()
defer downloadResp.Body.Close()
require.Equal(t, http.StatusOK, downloadResp.StatusCode)
assert.Equal(t, "gzip", downloadResp.Header.Get("Content-Encoding"))
// Decompress and verify content matches original
gzReader, err := gzip.NewReader(downloadResp.Body)
require.NoError(t, err)
defer gzReader.Close()
decompressed, err := io.ReadAll(gzReader)
require.NoError(t, err)
assert.Equal(t, originalContent, decompressed)
}

View File

@ -0,0 +1,229 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"compress/gzip"
"fmt"
"io"
"net/http"
"slices"
"sort"
"strconv"
"strings"
)
const (
// EncodingGzip is the gzip content encoding.
EncodingGzip = "gzip"
// EncodingIdentity means no encoding (passthrough).
EncodingIdentity = "identity"
// EncodingWildcard means any encoding is acceptable.
EncodingWildcard = "*"
)
// SupportedEncodings lists the content encodings supported for file transfer.
// The order matters - encodings are checked in order of preference.
var SupportedEncodings = []string{
EncodingGzip,
}
// encodingWithQuality holds an encoding name and its quality value.
type encodingWithQuality struct {
encoding string
quality float64
}
// isSupportedEncoding checks if the given encoding is in the supported list.
// Per RFC 7231, content-coding values are case-insensitive.
func isSupportedEncoding(encoding string) bool {
return slices.Contains(SupportedEncodings, strings.ToLower(encoding))
}
// parseEncodingWithQuality parses an encoding value and extracts the quality.
// Returns the encoding name (lowercased) and quality value (default 1.0 if not specified).
// Per RFC 7231, content-coding values are case-insensitive.
func parseEncodingWithQuality(value string) encodingWithQuality {
value = strings.TrimSpace(value)
quality := 1.0
if idx := strings.Index(value, ";"); idx != -1 {
params := value[idx+1:]
value = strings.TrimSpace(value[:idx])
// Parse q=X.X parameter
for param := range strings.SplitSeq(params, ";") {
param = strings.TrimSpace(param)
if strings.HasPrefix(strings.ToLower(param), "q=") {
if q, err := strconv.ParseFloat(param[2:], 64); err == nil {
quality = q
}
}
}
}
// Normalize encoding to lowercase per RFC 7231
return encodingWithQuality{encoding: strings.ToLower(value), quality: quality}
}
// parseEncoding extracts the encoding name from a header value, stripping quality.
func parseEncoding(value string) string {
return parseEncodingWithQuality(value).encoding
}
// parseContentEncoding parses the Content-Encoding header and returns the encoding.
// Returns an error if an unsupported encoding is specified.
// If no Content-Encoding header is present, returns empty string.
func parseContentEncoding(r *http.Request) (string, error) {
header := r.Header.Get("Content-Encoding")
if header == "" {
return EncodingIdentity, nil
}
encoding := parseEncoding(header)
if encoding == EncodingIdentity {
return EncodingIdentity, nil
}
if !isSupportedEncoding(encoding) {
return "", fmt.Errorf("unsupported Content-Encoding: %s, supported: %v", header, SupportedEncodings)
}
return encoding, nil
}
// parseAcceptEncodingHeader parses the Accept-Encoding header and returns
// the parsed encodings along with the identity rejection state.
// Per RFC 7231 Section 5.3.4, identity is acceptable unless excluded by
// "identity;q=0" or "*;q=0" without a more specific entry for identity with q>0.
func parseAcceptEncodingHeader(header string) ([]encodingWithQuality, bool) {
if header == "" {
return nil, false // identity not rejected when header is empty
}
// Parse all encodings with their quality values
var encodings []encodingWithQuality
for value := range strings.SplitSeq(header, ",") {
eq := parseEncodingWithQuality(value)
encodings = append(encodings, eq)
}
// Check if identity is rejected per RFC 7231 Section 5.3.4:
// identity is acceptable unless excluded by "identity;q=0" or "*;q=0"
// without a more specific entry for identity with q>0.
identityRejected := false
identityExplicitlyAccepted := false
wildcardRejected := false
for _, eq := range encodings {
switch eq.encoding {
case EncodingIdentity:
if eq.quality == 0 {
identityRejected = true
} else {
identityExplicitlyAccepted = true
}
case EncodingWildcard:
if eq.quality == 0 {
wildcardRejected = true
}
}
}
if wildcardRejected && !identityExplicitlyAccepted {
identityRejected = true
}
return encodings, identityRejected
}
// isIdentityAcceptable checks if identity encoding is acceptable based on the
// Accept-Encoding header. Per RFC 7231 section 5.3.4, identity is always
// implicitly acceptable unless explicitly rejected with q=0.
func isIdentityAcceptable(r *http.Request) bool {
header := r.Header.Get("Accept-Encoding")
_, identityRejected := parseAcceptEncodingHeader(header)
return !identityRejected
}
// parseAcceptEncoding parses the Accept-Encoding header and returns the best
// supported encoding based on quality values. Per RFC 7231 section 5.3.4,
// identity is always implicitly acceptable unless explicitly rejected with q=0.
// If no Accept-Encoding header is present, returns empty string (identity).
func parseAcceptEncoding(r *http.Request) (string, error) {
header := r.Header.Get("Accept-Encoding")
if header == "" {
return EncodingIdentity, nil
}
encodings, identityRejected := parseAcceptEncodingHeader(header)
// Sort by quality value (highest first)
sort.Slice(encodings, func(i, j int) bool {
return encodings[i].quality > encodings[j].quality
})
// Find the best supported encoding
for _, eq := range encodings {
// Skip encodings with q=0 (explicitly rejected)
if eq.quality == 0 {
continue
}
if eq.encoding == EncodingIdentity {
return EncodingIdentity, nil
}
// Wildcard means any encoding is acceptable - return a supported encoding if identity is rejected
if eq.encoding == EncodingWildcard {
if identityRejected && len(SupportedEncodings) > 0 {
return SupportedEncodings[0], nil
}
return EncodingIdentity, nil
}
if isSupportedEncoding(eq.encoding) {
return eq.encoding, nil
}
}
// Per RFC 7231, identity is implicitly acceptable unless rejected
if !identityRejected {
return EncodingIdentity, nil
}
// Identity rejected and no supported encodings found
return "", fmt.Errorf("no acceptable encoding found, supported: %v", SupportedEncodings)
}
// getDecompressedBody returns a reader that decompresses the request body based on
// Content-Encoding header. Returns the original body if no encoding is specified.
// Returns an error if an unsupported encoding is specified.
// The caller is responsible for closing both the returned ReadCloser and the
// original request body (r.Body) separately.
func getDecompressedBody(r *http.Request) (io.ReadCloser, error) {
encoding, err := parseContentEncoding(r)
if err != nil {
return nil, err
}
if encoding == EncodingIdentity {
return r.Body, nil
}
switch encoding {
case EncodingGzip:
gzReader, err := gzip.NewReader(r.Body)
if err != nil {
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
}
return gzReader, nil
default:
// This shouldn't happen if isSupportedEncoding is correct
return nil, fmt.Errorf("encoding %s is supported but not implemented", encoding)
}
}

View File

@ -0,0 +1,496 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsSupportedEncoding(t *testing.T) {
t.Parallel()
t.Run("gzip is supported", func(t *testing.T) {
t.Parallel()
assert.True(t, isSupportedEncoding("gzip"))
})
t.Run("GZIP is supported (case-insensitive)", func(t *testing.T) {
t.Parallel()
assert.True(t, isSupportedEncoding("GZIP"))
})
t.Run("Gzip is supported (case-insensitive)", func(t *testing.T) {
t.Parallel()
assert.True(t, isSupportedEncoding("Gzip"))
})
t.Run("br is not supported", func(t *testing.T) {
t.Parallel()
assert.False(t, isSupportedEncoding("br"))
})
t.Run("deflate is not supported", func(t *testing.T) {
t.Parallel()
assert.False(t, isSupportedEncoding("deflate"))
})
}
func TestParseEncodingWithQuality(t *testing.T) {
t.Parallel()
t.Run("returns encoding with default quality 1.0", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 1.0, eq.quality, 0.001)
})
t.Run("parses quality value", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip;q=0.5")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 0.5, eq.quality, 0.001)
})
t.Run("parses quality value with whitespace", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip ; q=0.8")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 0.8, eq.quality, 0.001)
})
t.Run("handles q=0", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip;q=0")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 0.0, eq.quality, 0.001)
})
t.Run("handles invalid quality value", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip;q=invalid")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 1.0, eq.quality, 0.001) // defaults to 1.0 on parse error
})
t.Run("trims whitespace from encoding", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality(" gzip ")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 1.0, eq.quality, 0.001)
})
t.Run("normalizes encoding to lowercase", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("GZIP")
assert.Equal(t, "gzip", eq.encoding)
})
t.Run("normalizes mixed case encoding", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("Gzip;q=0.5")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 0.5, eq.quality, 0.001)
})
}
func TestParseEncoding(t *testing.T) {
t.Parallel()
t.Run("returns encoding as-is", func(t *testing.T) {
t.Parallel()
assert.Equal(t, "gzip", parseEncoding("gzip"))
})
t.Run("trims whitespace", func(t *testing.T) {
t.Parallel()
assert.Equal(t, "gzip", parseEncoding(" gzip "))
})
t.Run("strips quality value", func(t *testing.T) {
t.Parallel()
assert.Equal(t, "gzip", parseEncoding("gzip;q=1.0"))
})
t.Run("strips quality value with whitespace", func(t *testing.T) {
t.Parallel()
assert.Equal(t, "gzip", parseEncoding("gzip ; q=0.5"))
})
}
func TestParseContentEncoding(t *testing.T) {
t.Parallel()
t.Run("returns identity when no header", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns gzip when Content-Encoding is gzip", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "gzip")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip when Content-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "GZIP")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip when Content-Encoding is Gzip (case-insensitive)", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "Gzip")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns identity for identity encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "identity")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns error for unsupported encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "br")
_, err := parseContentEncoding(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
assert.Contains(t, err.Error(), "supported: [gzip]")
})
t.Run("handles gzip with quality value", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "gzip;q=1.0")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
}
func TestParseAcceptEncoding(t *testing.T) {
t.Parallel()
t.Run("returns identity when no header", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns gzip when Accept-Encoding is gzip", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip when Accept-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "GZIP")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip when gzip is among multiple encodings", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "deflate, gzip, br")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip with quality value", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip;q=1.0")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns identity for identity encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "identity")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns identity for wildcard encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("falls back to identity for unsupported encoding only", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "br")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("falls back to identity when only unsupported encodings", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "deflate, br")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("selects gzip when it has highest quality", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "br;q=0.5, gzip;q=1.0, deflate;q=0.8")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("selects gzip even with lower quality when others unsupported", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "br;q=1.0, gzip;q=0.5")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns identity when it has higher quality than gzip", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip;q=0.5, identity;q=1.0")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("skips encoding with q=0", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip;q=0, identity")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("falls back to identity when gzip rejected and no other supported", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip;q=0, br")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns error when identity explicitly rejected and no supported encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "br, identity;q=0")
_, err := parseAcceptEncoding(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "no acceptable encoding found")
})
t.Run("returns gzip for wildcard when identity rejected", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*, identity;q=0")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding) // wildcard with identity rejected returns supported encoding
})
t.Run("returns error when wildcard rejected and no explicit identity", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*;q=0")
_, err := parseAcceptEncoding(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "no acceptable encoding found")
})
t.Run("returns identity when wildcard rejected but identity explicitly accepted", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*;q=0, identity")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns gzip when wildcard rejected but gzip explicitly accepted", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*;q=0, gzip")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingGzip, encoding)
})
}
func TestGetDecompressedBody(t *testing.T) {
t.Parallel()
t.Run("returns original body when no Content-Encoding header", func(t *testing.T) {
t.Parallel()
content := []byte("test content")
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
body, err := getDecompressedBody(req)
require.NoError(t, err)
assert.Equal(t, req.Body, body, "should return original body")
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, content, data)
})
t.Run("decompresses gzip body when Content-Encoding is gzip", func(t *testing.T) {
t.Parallel()
originalContent := []byte("test content to compress")
var compressed bytes.Buffer
gw := gzip.NewWriter(&compressed)
_, err := gw.Write(originalContent)
require.NoError(t, err)
err = gw.Close()
require.NoError(t, err)
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
req.Header.Set("Content-Encoding", "gzip")
body, err := getDecompressedBody(req)
require.NoError(t, err)
defer body.Close()
assert.NotEqual(t, req.Body, body, "should return a new gzip reader")
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, originalContent, data)
})
t.Run("returns error for invalid gzip data", func(t *testing.T) {
t.Parallel()
invalidGzip := []byte("this is not gzip data")
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(invalidGzip))
req.Header.Set("Content-Encoding", "gzip")
_, err := getDecompressedBody(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to create gzip reader")
})
t.Run("returns original body for identity encoding", func(t *testing.T) {
t.Parallel()
content := []byte("test content")
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
req.Header.Set("Content-Encoding", "identity")
body, err := getDecompressedBody(req)
require.NoError(t, err)
assert.Equal(t, req.Body, body, "should return original body")
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, content, data)
})
t.Run("returns error for unsupported encoding", func(t *testing.T) {
t.Parallel()
content := []byte("test content")
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
req.Header.Set("Content-Encoding", "br")
_, err := getDecompressedBody(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
})
t.Run("handles gzip with quality value", func(t *testing.T) {
t.Parallel()
originalContent := []byte("test content to compress")
var compressed bytes.Buffer
gw := gzip.NewWriter(&compressed)
_, err := gw.Write(originalContent)
require.NoError(t, err)
err = gw.Close()
require.NoError(t, err)
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
req.Header.Set("Content-Encoding", "gzip;q=1.0")
body, err := getDecompressedBody(req)
require.NoError(t, err)
defer body.Close()
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, originalContent, data)
})
}

31
envd/internal/api/envs.go Normal file
View File

@ -0,0 +1,31 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"encoding/json"
"net/http"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
)
func (a *API) GetEnvs(w http.ResponseWriter, _ *http.Request) {
operationID := logs.AssignOperationID()
a.logger.Debug().Str(string(logs.OperationIDKey), operationID).Msg("Getting env vars")
envs := make(EnvVars)
a.defaults.EnvVars.Range(func(key, value string) bool {
envs[key] = value
return true
})
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(envs); err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("Failed to encode env vars")
}
}

View File

@ -0,0 +1,23 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"encoding/json"
"errors"
"net/http"
)
func jsonError(w http.ResponseWriter, code int, err error) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(code)
encodeErr := json.NewEncoder(w).Encode(Error{
Code: code,
Message: err.Error(),
})
if encodeErr != nil {
http.Error(w, errors.Join(encodeErr, err).Error(), http.StatusInternalServerError)
}
}

View File

@ -0,0 +1,5 @@
// SPDX-License-Identifier: Apache-2.0
package api
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config cfg.yaml ../../spec/envd.yaml

317
envd/internal/api/init.go Normal file
View File

@ -0,0 +1,317 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"os/exec"
"time"
"github.com/awnumar/memguard"
"github.com/rs/zerolog"
"github.com/txn2/txeh"
"golang.org/x/sys/unix"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
)
var (
ErrAccessTokenMismatch = errors.New("access token validation failed")
ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
)
const (
maxTimeInPast = 50 * time.Millisecond
maxTimeInFuture = 5 * time.Second
)
// validateInitAccessToken validates the access token for /init requests.
// Token is valid if it matches the existing token OR the MMDS hash.
// If neither exists, first-time setup is allowed.
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
requestTokenSet := requestToken.IsSet()
// Fast path: token matches existing
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
return nil
}
// Check MMDS only if token didn't match existing
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
switch {
case matchesMMDS:
return nil
case !a.accessToken.IsSet() && !mmdsExists:
return nil // first-time setup
case !requestTokenSet:
return ErrAccessTokenResetNotAuthorized
default:
return ErrAccessTokenMismatch
}
}
// checkMMDSHash checks if the request token matches the MMDS hash.
// Returns (matches, mmdsExists).
//
// The MMDS hash is set by the orchestrator during Resume:
// - hash(token): requires this specific token
// - hash(""): explicitly allows nil token (token reset authorized)
// - "": MMDS not properly configured, no authorization granted
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
if a.isNotFC {
return false, false
}
mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
if err != nil {
return false, false
}
if mmdsHash == "" {
return false, false
}
if !requestToken.IsSet() {
return mmdsHash == keys.HashAccessToken(""), true
}
tokenBytes, err := requestToken.Bytes()
if err != nil {
return false, true
}
defer memguard.WipeBytes(tokenBytes)
return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
}
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
ctx := r.Context()
operationID := logs.AssignOperationID()
logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
if r.Body != nil {
// Read raw body so we can wipe it after parsing
body, err := io.ReadAll(r.Body)
// Ensure body is wiped after we're done
defer memguard.WipeBytes(body)
if err != nil {
logger.Error().Msgf("Failed to read request body: %v", err)
w.WriteHeader(http.StatusBadRequest)
return
}
var initRequest PostInitJSONBody
if len(body) > 0 {
err = json.Unmarshal(body, &initRequest)
if err != nil {
logger.Error().Msgf("Failed to decode request: %v", err)
w.WriteHeader(http.StatusBadRequest)
return
}
}
// Ensure request token is destroyed if not transferred via TakeFrom.
// This handles: validation failures, timestamp-based skips, and any early returns.
// Safe because Destroy() is nil-safe and TakeFrom clears the source.
defer initRequest.AccessToken.Destroy()
a.initLock.Lock()
defer a.initLock.Unlock()
// Update data only if the request is newer or if there's no timestamp at all
if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
err = a.SetData(ctx, logger, initRequest)
if err != nil {
switch {
case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
w.WriteHeader(http.StatusUnauthorized)
default:
logger.Error().Msgf("Failed to set data: %v", err)
w.WriteHeader(http.StatusBadRequest)
}
w.Write([]byte(err.Error()))
return
}
}
}
go func() { //nolint:contextcheck // TODO: fix this later
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
}()
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", "")
w.WriteHeader(http.StatusNoContent)
}
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
// Validate access token before proceeding with any action
// The request must provide a token that is either:
// 1. Matches the existing access token (if set), OR
// 2. Matches the MMDS hash (for token change during resume)
if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
return err
}
if data.Timestamp != nil {
// Check if current time differs significantly from the received timestamp
if shouldSetSystemTime(time.Now(), *data.Timestamp) {
logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp)
ts := unix.NsecToTimespec(data.Timestamp.UnixNano())
err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts)
if err != nil {
logger.Error().Msgf("Failed to set system time: %v", err)
}
} else {
logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp)
}
}
if data.EnvVars != nil {
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
for key, value := range *data.EnvVars {
logger.Debug().Msgf("Setting env var for %s", key)
a.defaults.EnvVars.Store(key, value)
}
}
if data.AccessToken.IsSet() {
logger.Debug().Msg("Setting access token")
a.accessToken.TakeFrom(data.AccessToken)
} else if a.accessToken.IsSet() {
logger.Debug().Msg("Clearing access token")
a.accessToken.Destroy()
}
if data.HyperloopIP != nil {
go a.SetupHyperloop(*data.HyperloopIP)
}
if data.DefaultUser != nil && *data.DefaultUser != "" {
logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
a.defaults.User = *data.DefaultUser
}
if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
a.defaults.Workdir = data.DefaultWorkdir
}
if data.VolumeMounts != nil {
for _, volume := range *data.VolumeMounts {
logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
}
}
return nil
}
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
commands := [][]string{
{"mkdir", "-p", path},
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
}
for _, command := range commands {
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
logger := a.getLogger(err)
logger.
Strs("command", command).
Str("output", string(data)).
Msg("Mount NFS")
if err != nil {
return
}
}
}
func (a *API) SetupHyperloop(address string) {
a.hyperloopLock.Lock()
defer a.hyperloopLock.Unlock()
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
a.logger.Error().Err(err).Msg("failed to modify hosts file")
} else {
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
}
}
const eventsHost = "events.wrenn.local"
func rewriteHostsFile(address, path string) error {
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
ReadFilePath: path,
WriteFilePath: path,
})
if err != nil {
return fmt.Errorf("failed to create hosts: %w", err)
}
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
// This will remove any existing entries for events.wrenn.local first
ipFamily, err := getIPFamily(address)
if err != nil {
return fmt.Errorf("failed to get ip family: %w", err)
}
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
return nil // nothing to be done
}
hosts.AddHost(address, eventsHost)
return hosts.Save()
}
var (
ErrInvalidAddress = errors.New("invalid IP address")
ErrUnknownAddressFormat = errors.New("unknown IP address format")
)
func getIPFamily(address string) (txeh.IPFamily, error) {
addressIP, err := netip.ParseAddr(address)
if err != nil {
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
}
switch {
case addressIP.Is4():
return txeh.IPFamilyV4, nil
case addressIP.Is6():
return txeh.IPFamilyV6, nil
default:
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
}
}
// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp,
// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than
// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime.
func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool {
return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture))
}

View File

@ -0,0 +1,590 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
utilsShared "git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
func TestSimpleCases(t *testing.T) {
t.Parallel()
testCases := map[string]func(string) string{
"both newlines": func(s string) string { return s },
"no newline prefix": func(s string) string { return strings.TrimPrefix(s, "\n") },
"no newline suffix": func(s string) string { return strings.TrimSuffix(s, "\n") },
"no newline prefix or suffix": strings.TrimSpace,
}
for name, preprocessor := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
value := `
# comment
127.0.0.1 one.host
127.0.0.2 two.host
`
value = preprocessor(value)
inputPath := filepath.Join(tempDir, "hosts")
err := os.WriteFile(inputPath, []byte(value), 0o644)
require.NoError(t, err)
err = rewriteHostsFile("127.0.0.3", inputPath)
require.NoError(t, err)
data, err := os.ReadFile(inputPath)
require.NoError(t, err)
assert.Equal(t, `# comment
127.0.0.1 one.host
127.0.0.2 two.host
127.0.0.3 events.wrenn.local`, strings.TrimSpace(string(data)))
})
}
}
func TestShouldSetSystemTime(t *testing.T) {
t.Parallel()
sandboxTime := time.Now()
tests := []struct {
name string
hostTime time.Time
want bool
}{
{
name: "sandbox time far ahead of host time (should set)",
hostTime: sandboxTime.Add(-10 * time.Second),
want: true,
},
{
name: "sandbox time at maxTimeInPast boundary ahead of host time (should not set)",
hostTime: sandboxTime.Add(-50 * time.Millisecond),
want: false,
},
{
name: "sandbox time just within maxTimeInPast ahead of host time (should not set)",
hostTime: sandboxTime.Add(-40 * time.Millisecond),
want: false,
},
{
name: "sandbox time slightly ahead of host time (should not set)",
hostTime: sandboxTime.Add(-10 * time.Millisecond),
want: false,
},
{
name: "sandbox time equals host time (should not set)",
hostTime: sandboxTime,
want: false,
},
{
name: "sandbox time slightly behind host time (should not set)",
hostTime: sandboxTime.Add(1 * time.Second),
want: false,
},
{
name: "sandbox time just within maxTimeInFuture behind host time (should not set)",
hostTime: sandboxTime.Add(4 * time.Second),
want: false,
},
{
name: "sandbox time at maxTimeInFuture boundary behind host time (should not set)",
hostTime: sandboxTime.Add(5 * time.Second),
want: false,
},
{
name: "sandbox time far behind host time (should set)",
hostTime: sandboxTime.Add(1 * time.Minute),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := shouldSetSystemTime(tt.hostTime, sandboxTime)
assert.Equal(t, tt.want, got)
})
}
}
func secureTokenPtr(s string) *SecureToken {
token := &SecureToken{}
_ = token.Set([]byte(s))
return token
}
type mockMMDSClient struct {
hash string
err error
}
func (m *mockMMDSClient) GetAccessTokenHash(_ context.Context) (string, error) {
return m.hash, m.err
}
func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
}
api := New(&logger, defaults, nil, false)
if accessToken != nil {
api.accessToken.TakeFrom(accessToken)
}
api.mmdsClient = mmdsClient
return api
}
func TestValidateInitAccessToken(t *testing.T) {
t.Parallel()
ctx := t.Context()
tests := []struct {
name string
accessToken *SecureToken
requestToken *SecureToken
mmdsHash string
mmdsErr error
wantErr error
}{
{
name: "fast path: token matches existing",
accessToken: secureTokenPtr("secret-token"),
requestToken: secureTokenPtr("secret-token"),
mmdsHash: "",
mmdsErr: nil,
wantErr: nil,
},
{
name: "MMDS match: token hash matches MMDS hash",
accessToken: secureTokenPtr("old-token"),
requestToken: secureTokenPtr("new-token"),
mmdsHash: keys.HashAccessToken("new-token"),
mmdsErr: nil,
wantErr: nil,
},
{
name: "first-time setup: no existing token, MMDS error",
accessToken: nil,
requestToken: secureTokenPtr("new-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
},
{
name: "first-time setup: no existing token, empty MMDS hash",
accessToken: nil,
requestToken: secureTokenPtr("new-token"),
mmdsHash: "",
mmdsErr: nil,
wantErr: nil,
},
{
name: "first-time setup: both tokens nil, no MMDS",
accessToken: nil,
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
},
{
name: "mismatch: existing token differs from request, no MMDS",
accessToken: secureTokenPtr("existing-token"),
requestToken: secureTokenPtr("wrong-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenMismatch,
},
{
name: "mismatch: existing token differs from request, MMDS hash mismatch",
accessToken: secureTokenPtr("existing-token"),
requestToken: secureTokenPtr("wrong-token"),
mmdsHash: keys.HashAccessToken("different-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenMismatch,
},
{
name: "conflict: existing token, nil request, MMDS exists",
accessToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: keys.HashAccessToken("some-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenResetNotAuthorized,
},
{
name: "conflict: existing token, nil request, no MMDS",
accessToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenResetNotAuthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
api := newTestAPI(tt.accessToken, mmdsClient)
err := api.validateInitAccessToken(ctx, tt.requestToken)
if tt.wantErr != nil {
require.Error(t, err)
assert.ErrorIs(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}
func TestCheckMMDSHash(t *testing.T) {
t.Parallel()
ctx := t.Context()
t.Run("returns match when token hash equals MMDS hash", func(t *testing.T) {
t.Parallel()
token := "my-secret-token"
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(token), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr(token))
assert.True(t, matches)
assert.True(t, exists)
})
t.Run("returns no match when token hash differs from MMDS hash", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("different-token"), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("my-token"))
assert.False(t, matches)
assert.True(t, exists)
})
t.Run("returns exists but no match when request token is nil", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("some-token"), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, nil)
assert.False(t, matches)
assert.True(t, exists)
})
t.Run("returns false, false when MMDS returns error", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
assert.False(t, matches)
assert.False(t, exists)
})
t.Run("returns false, false when MMDS returns empty hash with non-nil request", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
assert.False(t, matches)
assert.False(t, exists)
})
t.Run("returns false, false when MMDS returns empty hash with nil request", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, nil)
assert.False(t, matches)
assert.False(t, exists)
})
t.Run("returns true, true when MMDS returns hash of empty string with nil request (explicit reset)", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(""), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, nil)
assert.True(t, matches)
assert.True(t, exists)
})
}
func TestSetData(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := zerolog.Nop()
t.Run("access token updates", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existingToken *SecureToken
requestToken *SecureToken
mmdsHash string
mmdsErr error
wantErr error
wantFinalToken *SecureToken
}{
{
name: "first-time setup: sets initial token",
existingToken: nil,
requestToken: secureTokenPtr("initial-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
wantFinalToken: secureTokenPtr("initial-token"),
},
{
name: "first-time setup: nil request token leaves token unset",
existingToken: nil,
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
wantFinalToken: nil,
},
{
name: "re-init with same token: token unchanged",
existingToken: secureTokenPtr("same-token"),
requestToken: secureTokenPtr("same-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
wantFinalToken: secureTokenPtr("same-token"),
},
{
name: "resume with MMDS: updates token when hash matches",
existingToken: secureTokenPtr("old-token"),
requestToken: secureTokenPtr("new-token"),
mmdsHash: keys.HashAccessToken("new-token"),
mmdsErr: nil,
wantErr: nil,
wantFinalToken: secureTokenPtr("new-token"),
},
{
name: "resume with MMDS: fails when hash doesn't match",
existingToken: secureTokenPtr("old-token"),
requestToken: secureTokenPtr("new-token"),
mmdsHash: keys.HashAccessToken("different-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenMismatch,
wantFinalToken: secureTokenPtr("old-token"),
},
{
name: "fails when existing token and request token mismatch without MMDS",
existingToken: secureTokenPtr("existing-token"),
requestToken: secureTokenPtr("wrong-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenMismatch,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "conflict when existing token but nil request token",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenResetNotAuthorized,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "conflict when existing token but nil request with MMDS present",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: keys.HashAccessToken("some-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenResetNotAuthorized,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "conflict when MMDS returns empty hash and request is nil (prevents unauthorized reset)",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: "",
mmdsErr: nil,
wantErr: ErrAccessTokenResetNotAuthorized,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "resets token when MMDS returns hash of empty string and request is nil (explicit reset)",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: keys.HashAccessToken(""),
mmdsErr: nil,
wantErr: nil,
wantFinalToken: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
api := newTestAPI(tt.existingToken, mmdsClient)
data := PostInitJSONBody{
AccessToken: tt.requestToken,
}
err := api.SetData(ctx, logger, data)
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
if tt.wantFinalToken == nil {
assert.False(t, api.accessToken.IsSet(), "expected token to not be set")
} else {
require.True(t, api.accessToken.IsSet(), "expected token to be set")
assert.True(t, api.accessToken.EqualsSecure(tt.wantFinalToken), "expected token to match")
}
})
}
})
t.Run("sets environment variables", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
envVars := EnvVars{"FOO": "bar", "BAZ": "qux"}
data := PostInitJSONBody{
EnvVars: &envVars,
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
val, ok := api.defaults.EnvVars.Load("FOO")
assert.True(t, ok)
assert.Equal(t, "bar", val)
val, ok = api.defaults.EnvVars.Load("BAZ")
assert.True(t, ok)
assert.Equal(t, "qux", val)
})
t.Run("sets default user", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
data := PostInitJSONBody{
DefaultUser: utilsShared.ToPtr("testuser"),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
assert.Equal(t, "testuser", api.defaults.User)
})
t.Run("does not set default user when empty", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
api.defaults.User = "original"
data := PostInitJSONBody{
DefaultUser: utilsShared.ToPtr(""),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
assert.Equal(t, "original", api.defaults.User)
})
t.Run("sets default workdir", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
data := PostInitJSONBody{
DefaultWorkdir: utilsShared.ToPtr("/home/user"),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
require.NotNil(t, api.defaults.Workdir)
assert.Equal(t, "/home/user", *api.defaults.Workdir)
})
t.Run("does not set default workdir when empty", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
originalWorkdir := "/original"
api.defaults.Workdir = &originalWorkdir
data := PostInitJSONBody{
DefaultWorkdir: utilsShared.ToPtr(""),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
require.NotNil(t, api.defaults.Workdir)
assert.Equal(t, "/original", *api.defaults.Workdir)
})
t.Run("sets multiple fields at once", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
envVars := EnvVars{"KEY": "value"}
data := PostInitJSONBody{
AccessToken: secureTokenPtr("token"),
DefaultUser: utilsShared.ToPtr("user"),
DefaultWorkdir: utilsShared.ToPtr("/workdir"),
EnvVars: &envVars,
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
assert.True(t, api.accessToken.Equals("token"), "expected token to match")
assert.Equal(t, "user", api.defaults.User)
assert.Equal(t, "/workdir", *api.defaults.Workdir)
val, ok := api.defaults.EnvVars.Load("KEY")
assert.True(t, ok)
assert.Equal(t, "value", val)
})
}

View File

@ -0,0 +1,214 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"bytes"
"errors"
"sync"
"github.com/awnumar/memguard"
)
var (
ErrTokenNotSet = errors.New("access token not set")
ErrTokenEmpty = errors.New("empty token not allowed")
)
// SecureToken wraps memguard for secure token storage.
// It uses LockedBuffer which provides memory locking, guard pages,
// and secure zeroing on destroy.
type SecureToken struct {
mu sync.RWMutex
buffer *memguard.LockedBuffer
}
// Set securely replaces the token, destroying the old one first.
// The old token memory is zeroed before the new token is stored.
// The input byte slice is wiped after copying to secure memory.
// Returns ErrTokenEmpty if token is empty - use Destroy() to clear the token instead.
func (s *SecureToken) Set(token []byte) error {
if len(token) == 0 {
return ErrTokenEmpty
}
s.mu.Lock()
defer s.mu.Unlock()
// Destroy old token first (zeros memory)
if s.buffer != nil {
s.buffer.Destroy()
s.buffer = nil
}
// Create new LockedBuffer from bytes (source slice is wiped by memguard)
s.buffer = memguard.NewBufferFromBytes(token)
return nil
}
// UnmarshalJSON implements json.Unmarshaler to securely parse a JSON string
// directly into memguard, wiping the input bytes after copying.
//
// Access tokens are hex-encoded HMAC-SHA256 hashes (64 chars of [0-9a-f]),
// so they never contain JSON escape sequences.
func (s *SecureToken) UnmarshalJSON(data []byte) error {
// JSON strings are quoted, so minimum valid is `""` (2 bytes).
if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
memguard.WipeBytes(data)
return errors.New("invalid secure token JSON string")
}
content := data[1 : len(data)-1]
// Access tokens are hex strings - reject if contains backslash
if bytes.ContainsRune(content, '\\') {
memguard.WipeBytes(data)
return errors.New("invalid secure token: unexpected escape sequence")
}
if len(content) == 0 {
memguard.WipeBytes(data)
return ErrTokenEmpty
}
s.mu.Lock()
defer s.mu.Unlock()
if s.buffer != nil {
s.buffer.Destroy()
s.buffer = nil
}
// Allocate secure buffer and copy directly into it
s.buffer = memguard.NewBuffer(len(content))
copy(s.buffer.Bytes(), content)
// Wipe the input data
memguard.WipeBytes(data)
return nil
}
// TakeFrom transfers the token from src to this SecureToken, destroying any
// existing token. The source token is cleared after transfer.
// This avoids copying the underlying bytes.
func (s *SecureToken) TakeFrom(src *SecureToken) {
if src == nil || s == src {
return
}
// Extract buffer from source
src.mu.Lock()
buffer := src.buffer
src.buffer = nil
src.mu.Unlock()
// Install buffer in destination
s.mu.Lock()
if s.buffer != nil {
s.buffer.Destroy()
}
s.buffer = buffer
s.mu.Unlock()
}
// Equals checks if token matches using constant-time comparison.
// Returns false if the receiver is nil.
func (s *SecureToken) Equals(token string) bool {
if s == nil {
return false
}
s.mu.RLock()
defer s.mu.RUnlock()
if s.buffer == nil || !s.buffer.IsAlive() {
return false
}
return s.buffer.EqualTo([]byte(token))
}
// EqualsSecure compares this token with another SecureToken using constant-time comparison.
// Returns false if either receiver or other is nil.
func (s *SecureToken) EqualsSecure(other *SecureToken) bool {
if s == nil || other == nil {
return false
}
if s == other {
return s.IsSet()
}
// Get a copy of other's bytes (avoids holding two locks simultaneously)
otherBytes, err := other.Bytes()
if err != nil {
return false
}
defer memguard.WipeBytes(otherBytes)
s.mu.RLock()
defer s.mu.RUnlock()
if s.buffer == nil || !s.buffer.IsAlive() {
return false
}
return s.buffer.EqualTo(otherBytes)
}
// IsSet returns true if a token is stored.
// Returns false if the receiver is nil.
func (s *SecureToken) IsSet() bool {
if s == nil {
return false
}
s.mu.RLock()
defer s.mu.RUnlock()
return s.buffer != nil && s.buffer.IsAlive()
}
// Bytes returns a copy of the token bytes (for signature generation).
// The caller should zero the returned slice after use.
// Returns ErrTokenNotSet if the receiver is nil.
func (s *SecureToken) Bytes() ([]byte, error) {
if s == nil {
return nil, ErrTokenNotSet
}
s.mu.RLock()
defer s.mu.RUnlock()
if s.buffer == nil || !s.buffer.IsAlive() {
return nil, ErrTokenNotSet
}
// Return a copy (unavoidable for signature generation)
src := s.buffer.Bytes()
result := make([]byte, len(src))
copy(result, src)
return result, nil
}
// Destroy securely wipes the token from memory.
// No-op if the receiver is nil.
func (s *SecureToken) Destroy() {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
if s.buffer != nil {
s.buffer.Destroy()
s.buffer = nil
}
}

View File

@ -0,0 +1,463 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"sync"
"testing"
"github.com/awnumar/memguard"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSecureTokenSetAndEquals(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Initially not set
assert.False(t, st.IsSet(), "token should not be set initially")
assert.False(t, st.Equals("any-token"), "equals should return false when not set")
// Set token
err := st.Set([]byte("test-token"))
require.NoError(t, err)
assert.True(t, st.IsSet(), "token should be set after Set()")
assert.True(t, st.Equals("test-token"), "equals should return true for correct token")
assert.False(t, st.Equals("wrong-token"), "equals should return false for wrong token")
assert.False(t, st.Equals(""), "equals should return false for empty token")
}
func TestSecureTokenReplace(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Set initial token
err := st.Set([]byte("first-token"))
require.NoError(t, err)
assert.True(t, st.Equals("first-token"))
// Replace with new token (old one should be destroyed)
err = st.Set([]byte("second-token"))
require.NoError(t, err)
assert.True(t, st.Equals("second-token"), "should match new token")
assert.False(t, st.Equals("first-token"), "should not match old token")
}
func TestSecureTokenDestroy(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Set and then destroy
err := st.Set([]byte("test-token"))
require.NoError(t, err)
assert.True(t, st.IsSet())
st.Destroy()
assert.False(t, st.IsSet(), "token should not be set after Destroy()")
assert.False(t, st.Equals("test-token"), "equals should return false after Destroy()")
// Destroy on already destroyed should be safe
st.Destroy()
assert.False(t, st.IsSet())
// Nil receiver should be safe
var nilToken *SecureToken
assert.False(t, nilToken.IsSet(), "nil receiver should return false for IsSet()")
assert.False(t, nilToken.Equals("anything"), "nil receiver should return false for Equals()")
assert.False(t, nilToken.EqualsSecure(st), "nil receiver should return false for EqualsSecure()")
nilToken.Destroy() // should not panic
_, err = nilToken.Bytes()
require.ErrorIs(t, err, ErrTokenNotSet, "nil receiver should return ErrTokenNotSet for Bytes()")
}
func TestSecureTokenBytes(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Bytes should return error when not set
_, err := st.Bytes()
require.ErrorIs(t, err, ErrTokenNotSet)
// Set token and get bytes
err = st.Set([]byte("test-token"))
require.NoError(t, err)
bytes, err := st.Bytes()
require.NoError(t, err)
assert.Equal(t, []byte("test-token"), bytes)
// Zero out the bytes (as caller should do)
memguard.WipeBytes(bytes)
// Original should still be intact
assert.True(t, st.Equals("test-token"), "original token should still work after zeroing copy")
// After destroy, bytes should fail
st.Destroy()
_, err = st.Bytes()
assert.ErrorIs(t, err, ErrTokenNotSet)
}
func TestSecureTokenConcurrentAccess(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("initial-token"))
require.NoError(t, err)
var wg sync.WaitGroup
const numGoroutines = 100
// Concurrent reads
for range numGoroutines {
wg.Go(func() {
st.IsSet()
st.Equals("initial-token")
})
}
// Concurrent writes
for i := range 10 {
wg.Add(1)
go func(idx int) {
defer wg.Done()
st.Set([]byte("token-" + string(rune('a'+idx))))
}(i)
}
wg.Wait()
// Should still be in a valid state
assert.True(t, st.IsSet())
}
func TestSecureTokenEmptyToken(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Setting empty token should return an error
err := st.Set([]byte{})
require.ErrorIs(t, err, ErrTokenEmpty)
assert.False(t, st.IsSet(), "token should not be set after empty token error")
// Setting nil should also return an error
err = st.Set(nil)
require.ErrorIs(t, err, ErrTokenEmpty)
assert.False(t, st.IsSet(), "token should not be set after nil token error")
}
func TestSecureTokenEmptyTokenDoesNotClearExisting(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Set a valid token first
err := st.Set([]byte("valid-token"))
require.NoError(t, err)
assert.True(t, st.IsSet())
// Attempting to set empty token should fail and preserve existing token
err = st.Set([]byte{})
require.ErrorIs(t, err, ErrTokenEmpty)
assert.True(t, st.IsSet(), "existing token should be preserved after empty token error")
assert.True(t, st.Equals("valid-token"), "existing token value should be unchanged")
}
func TestSecureTokenUnmarshalJSON(t *testing.T) {
t.Parallel()
t.Run("unmarshals valid JSON string", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.UnmarshalJSON([]byte(`"my-secret-token"`))
require.NoError(t, err)
assert.True(t, st.IsSet())
assert.True(t, st.Equals("my-secret-token"))
})
t.Run("returns error for empty string", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.UnmarshalJSON([]byte(`""`))
require.ErrorIs(t, err, ErrTokenEmpty)
assert.False(t, st.IsSet())
})
t.Run("returns error for invalid JSON", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.UnmarshalJSON([]byte(`not-valid-json`))
require.Error(t, err)
assert.False(t, st.IsSet())
})
t.Run("replaces existing token", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("old-token"))
require.NoError(t, err)
err = st.UnmarshalJSON([]byte(`"new-token"`))
require.NoError(t, err)
assert.True(t, st.Equals("new-token"))
assert.False(t, st.Equals("old-token"))
})
t.Run("wipes input buffer after parsing", func(t *testing.T) {
t.Parallel()
// Create a buffer with a known token
input := []byte(`"secret-token-12345"`)
original := make([]byte, len(input))
copy(original, input)
st := &SecureToken{}
err := st.UnmarshalJSON(input)
require.NoError(t, err)
// Verify the token was stored correctly
assert.True(t, st.Equals("secret-token-12345"))
// Verify the input buffer was wiped (all zeros)
for i, b := range input {
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
}
})
t.Run("wipes input buffer on error", func(t *testing.T) {
t.Parallel()
// Create a buffer with an empty token (will error)
input := []byte(`""`)
st := &SecureToken{}
err := st.UnmarshalJSON(input)
require.Error(t, err)
// Verify the input buffer was still wiped
for i, b := range input {
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
}
})
t.Run("rejects escape sequences", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.UnmarshalJSON([]byte(`"token\nwith\nnewlines"`))
require.Error(t, err)
assert.Contains(t, err.Error(), "escape sequence")
assert.False(t, st.IsSet())
})
}
func TestSecureTokenSetWipesInput(t *testing.T) {
t.Parallel()
t.Run("wipes input buffer after storing", func(t *testing.T) {
t.Parallel()
// Create a buffer with a known token
input := []byte("my-secret-token")
original := make([]byte, len(input))
copy(original, input)
st := &SecureToken{}
err := st.Set(input)
require.NoError(t, err)
// Verify the token was stored correctly
assert.True(t, st.Equals("my-secret-token"))
// Verify the input buffer was wiped (all zeros)
for i, b := range input {
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
}
})
}
func TestSecureTokenTakeFrom(t *testing.T) {
t.Parallel()
t.Run("transfers token from source to destination", func(t *testing.T) {
t.Parallel()
src := &SecureToken{}
err := src.Set([]byte("source-token"))
require.NoError(t, err)
dst := &SecureToken{}
dst.TakeFrom(src)
assert.True(t, dst.IsSet())
assert.True(t, dst.Equals("source-token"))
assert.False(t, src.IsSet(), "source should be empty after transfer")
})
t.Run("replaces existing destination token", func(t *testing.T) {
t.Parallel()
src := &SecureToken{}
err := src.Set([]byte("new-token"))
require.NoError(t, err)
dst := &SecureToken{}
err = dst.Set([]byte("old-token"))
require.NoError(t, err)
dst.TakeFrom(src)
assert.True(t, dst.Equals("new-token"))
assert.False(t, dst.Equals("old-token"))
assert.False(t, src.IsSet())
})
t.Run("handles nil source", func(t *testing.T) {
t.Parallel()
dst := &SecureToken{}
err := dst.Set([]byte("existing-token"))
require.NoError(t, err)
dst.TakeFrom(nil)
assert.True(t, dst.IsSet(), "destination should be unchanged with nil source")
assert.True(t, dst.Equals("existing-token"))
})
t.Run("handles empty source", func(t *testing.T) {
t.Parallel()
src := &SecureToken{}
dst := &SecureToken{}
err := dst.Set([]byte("existing-token"))
require.NoError(t, err)
dst.TakeFrom(src)
assert.False(t, dst.IsSet(), "destination should be cleared when source is empty")
})
t.Run("self-transfer is no-op and does not deadlock", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("token"))
require.NoError(t, err)
st.TakeFrom(st)
assert.True(t, st.IsSet(), "token should remain set after self-transfer")
assert.True(t, st.Equals("token"), "token value should be unchanged")
})
}
func TestSecureTokenEqualsSecure(t *testing.T) {
t.Parallel()
t.Run("returns true for matching tokens", func(t *testing.T) {
t.Parallel()
st1 := &SecureToken{}
err := st1.Set([]byte("same-token"))
require.NoError(t, err)
st2 := &SecureToken{}
err = st2.Set([]byte("same-token"))
require.NoError(t, err)
assert.True(t, st1.EqualsSecure(st2))
assert.True(t, st2.EqualsSecure(st1))
})
t.Run("concurrent TakeFrom and EqualsSecure do not deadlock", func(t *testing.T) {
t.Parallel()
// This test verifies the fix for the lock ordering deadlock bug.
const iterations = 100
for range iterations {
a := &SecureToken{}
err := a.Set([]byte("token-a"))
require.NoError(t, err)
b := &SecureToken{}
err = b.Set([]byte("token-b"))
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(2)
// Goroutine 1: a.TakeFrom(b)
go func() {
defer wg.Done()
a.TakeFrom(b)
}()
// Goroutine 2: b.EqualsSecure(a)
go func() {
defer wg.Done()
b.EqualsSecure(a)
}()
wg.Wait()
}
})
t.Run("returns false for different tokens", func(t *testing.T) {
t.Parallel()
st1 := &SecureToken{}
err := st1.Set([]byte("token-a"))
require.NoError(t, err)
st2 := &SecureToken{}
err = st2.Set([]byte("token-b"))
require.NoError(t, err)
assert.False(t, st1.EqualsSecure(st2))
})
t.Run("returns false when comparing with nil", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("token"))
require.NoError(t, err)
assert.False(t, st.EqualsSecure(nil))
})
t.Run("returns false when other is not set", func(t *testing.T) {
t.Parallel()
st1 := &SecureToken{}
err := st1.Set([]byte("token"))
require.NoError(t, err)
st2 := &SecureToken{}
assert.False(t, st1.EqualsSecure(st2))
})
t.Run("returns false when self is not set", func(t *testing.T) {
t.Parallel()
st1 := &SecureToken{}
st2 := &SecureToken{}
err := st2.Set([]byte("token"))
require.NoError(t, err)
assert.False(t, st1.EqualsSecure(st2))
})
t.Run("self-comparison returns true when set", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("token"))
require.NoError(t, err)
assert.True(t, st.EqualsSecure(st), "self-comparison should return true and not deadlock")
})
t.Run("self-comparison returns false when not set", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
assert.False(t, st.EqualsSecure(st), "self-comparison on unset token should return false")
})
}

View File

@ -0,0 +1,95 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"context"
"encoding/json"
"net/http"
"sync"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// MMDSClient provides access to MMDS metadata.
type MMDSClient interface {
GetAccessTokenHash(ctx context.Context) (string, error)
}
// DefaultMMDSClient is the production implementation that calls the real MMDS endpoint.
type DefaultMMDSClient struct{}
func (c *DefaultMMDSClient) GetAccessTokenHash(ctx context.Context) (string, error) {
return host.GetAccessTokenHashFromMMDS(ctx)
}
type API struct {
isNotFC bool
logger *zerolog.Logger
accessToken *SecureToken
defaults *execcontext.Defaults
mmdsChan chan *host.MMDSOpts
hyperloopLock sync.Mutex
mmdsClient MMDSClient
lastSetTime *utils.AtomicMax
initLock sync.Mutex
}
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool) *API {
return &API{
logger: l,
defaults: defaults,
mmdsChan: mmdsChan,
isNotFC: isNotFC,
mmdsClient: &DefaultMMDSClient{},
lastSetTime: utils.NewAtomicMax(),
accessToken: &SecureToken{},
}
}
func (a *API) GetHealth(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
a.logger.Trace().Msg("Health check")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", "")
w.WriteHeader(http.StatusNoContent)
}
func (a *API) GetMetrics(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
a.logger.Trace().Msg("Get metrics")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", "application/json")
metrics, err := host.GetMetrics()
if err != nil {
a.logger.Error().Err(err).Msg("Failed to get metrics")
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(metrics); err != nil {
a.logger.Error().Err(err).Msg("Failed to encode metrics")
}
}
func (a *API) getLogger(err error) *zerolog.Event {
if err != nil {
return a.logger.Error().Err(err) //nolint:zerologlint // this is only prep
}
return a.logger.Info() //nolint:zerologlint // this is only prep
}

311
envd/internal/api/upload.go Normal file
View File

@ -0,0 +1,311 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"os/user"
"path/filepath"
"strings"
"syscall"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
var ErrNoDiskSpace = fmt.Errorf("not enough disk space available")
func processFile(r *http.Request, path string, part io.Reader, uid, gid int, logger zerolog.Logger) (int, error) {
logger.Debug().
Str("path", path).
Msg("File processing")
err := permissions.EnsureDirs(filepath.Dir(path), uid, gid)
if err != nil {
err := fmt.Errorf("error ensuring directories: %w", err)
return http.StatusInternalServerError, err
}
canBePreChowned := false
stat, err := os.Stat(path)
if err != nil && !os.IsNotExist(err) {
errMsg := fmt.Errorf("error getting file info: %w", err)
return http.StatusInternalServerError, errMsg
} else if err == nil {
if stat.IsDir() {
err := fmt.Errorf("path is a directory: %s", path)
return http.StatusBadRequest, err
}
canBePreChowned = true
}
hasBeenChowned := false
if canBePreChowned {
err = os.Chown(path, uid, gid)
if err != nil {
if !os.IsNotExist(err) {
err = fmt.Errorf("error changing file ownership: %w", err)
return http.StatusInternalServerError, err
}
} else {
hasBeenChowned = true
}
}
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o666)
if err != nil {
if errors.Is(err, syscall.ENOSPC) {
err = fmt.Errorf("not enough inodes available: %w", err)
return http.StatusInsufficientStorage, err
}
err := fmt.Errorf("error opening file: %w", err)
return http.StatusInternalServerError, err
}
defer file.Close()
if !hasBeenChowned {
err = os.Chown(path, uid, gid)
if err != nil {
err := fmt.Errorf("error changing file ownership: %w", err)
return http.StatusInternalServerError, err
}
}
_, err = file.ReadFrom(part)
if err != nil {
if errors.Is(err, syscall.ENOSPC) {
err = ErrNoDiskSpace
if r.ContentLength > 0 {
err = fmt.Errorf("attempted to write %d bytes: %w", r.ContentLength, err)
}
return http.StatusInsufficientStorage, err
}
err = fmt.Errorf("error writing file: %w", err)
return http.StatusInternalServerError, err
}
return http.StatusNoContent, nil
}
func resolvePath(part *multipart.Part, paths *UploadSuccess, u *user.User, defaultPath *string, params PostFilesParams) (string, error) {
var pathToResolve string
if params.Path != nil {
pathToResolve = *params.Path
} else {
var err error
customPart := utils.NewCustomPart(part)
pathToResolve, err = customPart.FileNameWithPath()
if err != nil {
return "", fmt.Errorf("error getting multipart custom part file name: %w", err)
}
}
filePath, err := permissions.ExpandAndResolve(pathToResolve, u, defaultPath)
if err != nil {
return "", fmt.Errorf("error resolving path: %w", err)
}
for _, entry := range *paths {
if entry.Path == filePath {
var alreadyUploaded []string
for _, uploadedFile := range *paths {
if uploadedFile.Path != filePath {
alreadyUploaded = append(alreadyUploaded, uploadedFile.Path)
}
}
errMsg := fmt.Errorf("you cannot upload multiple files to the same path '%s' in one upload request, only the first specified file was uploaded", filePath)
if len(alreadyUploaded) > 1 {
errMsg = fmt.Errorf("%w, also the following files were uploaded: %v", errMsg, strings.Join(alreadyUploaded, ", "))
}
return "", errMsg
}
}
return filePath, nil
}
func (a *API) handlePart(r *http.Request, part *multipart.Part, paths UploadSuccess, u *user.User, uid, gid int, operationID string, params PostFilesParams) (*EntryInfo, int, error) {
defer part.Close()
if part.FormName() != "file" {
return nil, http.StatusOK, nil
}
filePath, err := resolvePath(part, &paths, u, a.defaults.Workdir, params)
if err != nil {
return nil, http.StatusBadRequest, err
}
logger := a.logger.
With().
Str(string(logs.OperationIDKey), operationID).
Str("event_type", "file_processing").
Logger()
status, err := processFile(r, filePath, part, uid, gid, logger)
if err != nil {
return nil, status, err
}
return &EntryInfo{
Path: filePath,
Name: filepath.Base(filePath),
Type: File,
}, http.StatusOK, nil
}
func (a *API) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
// Capture original body to ensure it's always closed
originalBody := r.Body
defer originalBody.Close()
var errorCode int
var errMsg error
var path string
if params.Path != nil {
path = *params.Path
}
operationID := logs.AssignOperationID()
// signing authorization if needed
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningWriteOperation)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
jsonError(w, http.StatusUnauthorized, err)
return
}
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
jsonError(w, http.StatusBadRequest, err)
return
}
defer func() {
l := a.logger.
Err(errMsg).
Str("method", r.Method+" "+r.URL.Path).
Str(string(logs.OperationIDKey), operationID).
Str("path", path).
Str("username", username)
if errMsg != nil {
l = l.Int("error_code", errorCode)
}
l.Msg("File write")
}()
// Handle gzip-encoded request body
body, err := getDecompressedBody(r)
if err != nil {
errMsg = fmt.Errorf("error decompressing request body: %w", err)
errorCode = http.StatusBadRequest
jsonError(w, errorCode, errMsg)
return
}
defer body.Close()
r.Body = body
f, err := r.MultipartReader()
if err != nil {
errMsg = fmt.Errorf("error parsing multipart form: %w", err)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
return
}
u, err := user.Lookup(username)
if err != nil {
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
errorCode = http.StatusUnauthorized
jsonError(w, errorCode, errMsg)
return
}
uid, gid, err := permissions.GetUserIdInts(u)
if err != nil {
errMsg = fmt.Errorf("error getting user ids: %w", err)
jsonError(w, http.StatusInternalServerError, errMsg)
return
}
paths := UploadSuccess{}
for {
part, partErr := f.NextPart()
if partErr == io.EOF {
// We're done reading the parts.
break
} else if partErr != nil {
errMsg = fmt.Errorf("error reading form: %w", partErr)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
break
}
entry, status, err := a.handlePart(r, part, paths, u, uid, gid, operationID, params)
if err != nil {
errorCode = status
errMsg = err
jsonError(w, errorCode, errMsg)
return
}
if entry != nil {
paths = append(paths, *entry)
}
}
data, err := json.Marshal(paths)
if err != nil {
errMsg = fmt.Errorf("error marshaling response: %w", err)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}

View File

@ -0,0 +1,251 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestProcessFile(t *testing.T) {
t.Parallel()
uid := os.Getuid()
gid := os.Getgid()
newRequest := func(content []byte) (*http.Request, io.Reader) {
request := &http.Request{
ContentLength: int64(len(content)),
}
buffer := bytes.NewBuffer(content)
return request, buffer
}
var emptyReq http.Request
var emptyPart *bytes.Buffer
var emptyLogger zerolog.Logger
t.Run("failed to ensure directories", func(t *testing.T) {
t.Parallel()
httpStatus, err := processFile(&emptyReq, "/proc/invalid/not-real", emptyPart, uid, gid, emptyLogger)
require.Error(t, err)
assert.Equal(t, http.StatusInternalServerError, httpStatus)
assert.ErrorContains(t, err, "error ensuring directories: ")
})
t.Run("attempt to replace directory with a file", func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
httpStatus, err := processFile(&emptyReq, tempDir, emptyPart, uid, gid, emptyLogger)
require.Error(t, err)
assert.Equal(t, http.StatusBadRequest, httpStatus, err.Error())
assert.ErrorContains(t, err, "path is a directory: ")
})
t.Run("fail to create file", func(t *testing.T) {
t.Parallel()
httpStatus, err := processFile(&emptyReq, "/proc/invalid-filename", emptyPart, uid, gid, emptyLogger)
require.Error(t, err)
assert.Equal(t, http.StatusInternalServerError, httpStatus)
assert.ErrorContains(t, err, "error opening file: ")
})
t.Run("out of disk space", func(t *testing.T) {
t.Parallel()
// make a tiny tmpfs mount
mountSize := 1024
tempDir := createTmpfsMount(t, mountSize)
// create test file
firstFileSize := mountSize / 2
tempFile1 := filepath.Join(tempDir, "test-file-1")
// fill it up
cmd := exec.CommandContext(t.Context(),
"dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", firstFileSize), "count=1")
err := cmd.Run()
require.NoError(t, err)
// create a new file that would fill up the
secondFileContents := make([]byte, mountSize*2)
for index := range secondFileContents {
secondFileContents[index] = 'a'
}
// try to replace it
request, buffer := newRequest(secondFileContents)
tempFile2 := filepath.Join(tempDir, "test-file-2")
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
require.Error(t, err)
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
assert.ErrorContains(t, err, "attempted to write 2048 bytes: not enough disk space")
})
t.Run("happy path", func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "test-file")
content := []byte("test-file-contents")
request, buffer := newRequest(content)
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
require.NoError(t, err)
assert.Equal(t, http.StatusNoContent, httpStatus)
data, err := os.ReadFile(tempFile)
require.NoError(t, err)
assert.Equal(t, content, data)
})
t.Run("overwrite file on full disk", func(t *testing.T) {
t.Parallel()
// make a tiny tmpfs mount
sizeInBytes := 1024
tempDir := createTmpfsMount(t, 1024)
// create test file
tempFile := filepath.Join(tempDir, "test-file")
// fill it up
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
err := cmd.Run()
require.NoError(t, err)
// try to replace it
content := []byte("test-file-contents")
request, buffer := newRequest(content)
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
require.NoError(t, err)
assert.Equal(t, http.StatusNoContent, httpStatus)
})
t.Run("write new file on full disk", func(t *testing.T) {
t.Parallel()
// make a tiny tmpfs mount
sizeInBytes := 1024
tempDir := createTmpfsMount(t, 1024)
// create test file
tempFile1 := filepath.Join(tempDir, "test-file")
// fill it up
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
err := cmd.Run()
require.NoError(t, err)
// try to write a new file
tempFile2 := filepath.Join(tempDir, "test-file-2")
content := []byte("test-file-contents")
request, buffer := newRequest(content)
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
require.ErrorContains(t, err, "not enough disk space available")
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
})
t.Run("write new file with no inodes available", func(t *testing.T) {
t.Parallel()
// make a tiny tmpfs mount
tempDir := createTmpfsMountWithInodes(t, 1024, 2)
// create test file
tempFile1 := filepath.Join(tempDir, "test-file")
// fill it up
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", 100), "count=1")
err := cmd.Run()
require.NoError(t, err)
// try to write a new file
tempFile2 := filepath.Join(tempDir, "test-file-2")
content := []byte("test-file-contents")
request, buffer := newRequest(content)
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
require.ErrorContains(t, err, "not enough inodes available")
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
})
t.Run("update sysfs or other virtual fs", func(t *testing.T) {
t.Parallel()
if os.Geteuid() != 0 {
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
}
filePath := "/sys/fs/cgroup/user.slice/cpu.weight"
newContent := []byte("102\n")
request, buffer := newRequest(newContent)
httpStatus, err := processFile(request, filePath, buffer, uid, gid, emptyLogger)
require.NoError(t, err)
assert.Equal(t, http.StatusNoContent, httpStatus)
data, err := os.ReadFile(filePath)
require.NoError(t, err)
assert.Equal(t, newContent, data)
})
t.Run("replace file", func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "test-file")
err := os.WriteFile(tempFile, []byte("old-contents"), 0o644)
require.NoError(t, err)
newContent := []byte("new-file-contents")
request, buffer := newRequest(newContent)
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
require.NoError(t, err)
assert.Equal(t, http.StatusNoContent, httpStatus)
data, err := os.ReadFile(tempFile)
require.NoError(t, err)
assert.Equal(t, newContent, data)
})
}
func createTmpfsMount(t *testing.T, sizeInBytes int) string {
t.Helper()
return createTmpfsMountWithInodes(t, sizeInBytes, 5)
}
func createTmpfsMountWithInodes(t *testing.T, sizeInBytes, inodesCount int) string {
t.Helper()
if os.Geteuid() != 0 {
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
}
tempDir := t.TempDir()
cmd := exec.CommandContext(t.Context(),
"mount",
"tmpfs",
tempDir,
"-t", "tmpfs",
"-o", fmt.Sprintf("size=%d,nr_inodes=%d", sizeInBytes, inodesCount))
err := cmd.Run()
require.NoError(t, err)
t.Cleanup(func() {
ctx := context.WithoutCancel(t.Context())
cmd := exec.CommandContext(ctx, "umount", tempDir)
err := cmd.Run()
require.NoError(t, err)
})
return tempDir
}

View File

@ -0,0 +1,39 @@
// SPDX-License-Identifier: Apache-2.0
package execcontext
import (
"errors"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
type Defaults struct {
EnvVars *utils.Map[string, string]
User string
Workdir *string
}
func ResolveDefaultWorkdir(workdir string, defaultWorkdir *string) string {
if workdir != "" {
return workdir
}
if defaultWorkdir != nil {
return *defaultWorkdir
}
return ""
}
func ResolveDefaultUsername(username *string, defaultUsername string) (string, error) {
if username != nil {
return *username, nil
}
if defaultUsername != "" {
return defaultUsername, nil
}
return "", errors.New("username not provided")
}

View File

@ -0,0 +1,96 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package host
import (
"math"
"time"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/mem"
"golang.org/x/sys/unix"
)
type Metrics struct {
Timestamp int64 `json:"ts"` // Unix Timestamp in UTC
CPUCount uint32 `json:"cpu_count"` // Total CPU cores
CPUUsedPercent float32 `json:"cpu_used_pct"` // Percent rounded to 2 decimal places
// Deprecated: kept for backwards compatibility with older orchestrators.
MemTotalMiB uint64 `json:"mem_total_mib"` // Total virtual memory in MiB
// Deprecated: kept for backwards compatibility with older orchestrators.
MemUsedMiB uint64 `json:"mem_used_mib"` // Used virtual memory in MiB
MemTotal uint64 `json:"mem_total"` // Total virtual memory in bytes
MemUsed uint64 `json:"mem_used"` // Used virtual memory in bytes
DiskUsed uint64 `json:"disk_used"` // Used disk space in bytes
DiskTotal uint64 `json:"disk_total"` // Total disk space in bytes
}
func GetMetrics() (*Metrics, error) {
v, err := mem.VirtualMemory()
if err != nil {
return nil, err
}
memUsedMiB := v.Used / 1024 / 1024
memTotalMiB := v.Total / 1024 / 1024
cpuTotal, err := cpu.Counts(true)
if err != nil {
return nil, err
}
cpuUsedPcts, err := cpu.Percent(0, false)
if err != nil {
return nil, err
}
cpuUsedPct := cpuUsedPcts[0]
cpuUsedPctRounded := float32(cpuUsedPct)
if cpuUsedPct > 0 {
cpuUsedPctRounded = float32(math.Round(cpuUsedPct*100) / 100)
}
diskMetrics, err := diskStats("/")
if err != nil {
return nil, err
}
return &Metrics{
Timestamp: time.Now().UTC().Unix(),
CPUCount: uint32(cpuTotal),
CPUUsedPercent: cpuUsedPctRounded,
MemUsedMiB: memUsedMiB,
MemTotalMiB: memTotalMiB,
MemTotal: v.Total,
MemUsed: v.Used,
DiskUsed: diskMetrics.Total - diskMetrics.Available,
DiskTotal: diskMetrics.Total,
}, nil
}
type diskSpace struct {
Total uint64
Available uint64
}
func diskStats(path string) (diskSpace, error) {
var st unix.Statfs_t
if err := unix.Statfs(path, &st); err != nil {
return diskSpace{}, err
}
block := uint64(st.Bsize)
// all data blocks
total := st.Blocks * block
// blocks available
available := st.Bavail * block
return diskSpace{Total: total, Available: available}, nil
}

185
envd/internal/host/mmds.go Normal file
View File

@ -0,0 +1,185 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package host
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
const (
WrennRunDir = "/run/wrenn" // store sandbox metadata files here
mmdsDefaultAddress = "169.254.169.254"
mmdsTokenExpiration = 60 * time.Second
mmdsAccessTokenRequestClientTimeout = 10 * time.Second
)
var mmdsAccessTokenClient = &http.Client{
Timeout: mmdsAccessTokenRequestClientTimeout,
Transport: &http.Transport{
DisableKeepAlives: true,
},
}
type MMDSOpts struct {
SandboxID string `json:"instanceID"`
TemplateID string `json:"envID"`
LogsCollectorAddress string `json:"address"`
AccessTokenHash string `json:"accessTokenHash"`
}
func (opts *MMDSOpts) Update(sandboxID, templateID, collectorAddress string) {
opts.SandboxID = sandboxID
opts.TemplateID = templateID
opts.LogsCollectorAddress = collectorAddress
}
func (opts *MMDSOpts) AddOptsToJSON(jsonLogs []byte) ([]byte, error) {
parsed := make(map[string]any)
err := json.Unmarshal(jsonLogs, &parsed)
if err != nil {
return nil, err
}
parsed["instanceID"] = opts.SandboxID
parsed["envID"] = opts.TemplateID
data, err := json.Marshal(parsed)
return data, err
}
func getMMDSToken(ctx context.Context, client *http.Client) (string, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://"+mmdsDefaultAddress+"/latest/api/token", &bytes.Buffer{})
if err != nil {
return "", err
}
request.Header["X-metadata-token-ttl-seconds"] = []string{fmt.Sprint(mmdsTokenExpiration.Seconds())}
response, err := client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return "", err
}
token := string(body)
if len(token) == 0 {
return "", fmt.Errorf("mmds token is an empty string")
}
return token, nil
}
func getMMDSOpts(ctx context.Context, client *http.Client, token string) (*MMDSOpts, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+mmdsDefaultAddress, &bytes.Buffer{})
if err != nil {
return nil, err
}
request.Header["X-metadata-token"] = []string{token}
request.Header["Accept"] = []string{"application/json"}
response, err := client.Do(request)
if err != nil {
return nil, err
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
var opts MMDSOpts
err = json.Unmarshal(body, &opts)
if err != nil {
return nil, err
}
return &opts, nil
}
// GetAccessTokenHashFromMMDS reads the access token hash from MMDS.
// This is used to validate that /init requests come from the orchestrator.
func GetAccessTokenHashFromMMDS(ctx context.Context) (string, error) {
token, err := getMMDSToken(ctx, mmdsAccessTokenClient)
if err != nil {
return "", fmt.Errorf("failed to get MMDS token: %w", err)
}
opts, err := getMMDSOpts(ctx, mmdsAccessTokenClient, token)
if err != nil {
return "", fmt.Errorf("failed to get MMDS opts: %w", err)
}
return opts.AccessTokenHash, nil
}
func PollForMMDSOpts(ctx context.Context, mmdsChan chan<- *MMDSOpts, envVars *utils.Map[string, string]) {
httpClient := &http.Client{}
defer httpClient.CloseIdleConnections()
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "context cancelled while waiting for mmds opts")
return
case <-ticker.C:
token, err := getMMDSToken(ctx, httpClient)
if err != nil {
fmt.Fprintf(os.Stderr, "error getting mmds token: %v\n", err)
continue
}
mmdsOpts, err := getMMDSOpts(ctx, httpClient, token)
if err != nil {
fmt.Fprintf(os.Stderr, "error getting mmds opts: %v\n", err)
continue
}
envVars.Store("WRENN_SANDBOX_ID", mmdsOpts.SandboxID)
envVars.Store("WRENN_TEMPLATE_ID", mmdsOpts.TemplateID)
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_SANDBOX_ID"), []byte(mmdsOpts.SandboxID), 0o666); err != nil {
fmt.Fprintf(os.Stderr, "error writing sandbox ID file: %v\n", err)
}
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_TEMPLATE_ID"), []byte(mmdsOpts.TemplateID), 0o666); err != nil {
fmt.Fprintf(os.Stderr, "error writing template ID file: %v\n", err)
}
if mmdsOpts.LogsCollectorAddress != "" {
mmdsChan <- mmdsOpts
}
return
}
}
}

View File

@ -0,0 +1,49 @@
// SPDX-License-Identifier: Apache-2.0
package logs
import (
"time"
"github.com/rs/zerolog"
)
const (
defaultMaxBufferSize = 2 << 15
defaultTimeout = 2 * time.Second
)
func LogBufferedDataEvents(dataCh <-chan []byte, logger *zerolog.Logger, eventType string) {
timer := time.NewTicker(defaultTimeout)
defer timer.Stop()
var buffer []byte
defer func() {
if len(buffer) > 0 {
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event (flush)")
}
}()
for {
select {
case <-timer.C:
if len(buffer) > 0 {
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
buffer = nil
}
case data, ok := <-dataCh:
if !ok {
return
}
buffer = append(buffer, data...)
if len(buffer) >= defaultMaxBufferSize {
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
buffer = nil
continue
}
}
}
}

View File

@ -0,0 +1,174 @@
// SPDX-License-Identifier: Apache-2.0
package exporter
import (
"bytes"
"context"
"fmt"
"log"
"net/http"
"os"
"sync"
"time"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
)
const ExporterTimeout = 10 * time.Second
type HTTPExporter struct {
client http.Client
logs [][]byte
isNotFC bool
mmdsOpts *host.MMDSOpts
// Concurrency coordination
triggers chan struct{}
logLock sync.RWMutex
mmdsLock sync.RWMutex
startOnce sync.Once
}
func NewHTTPLogsExporter(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *HTTPExporter {
exporter := &HTTPExporter{
client: http.Client{
Timeout: ExporterTimeout,
},
triggers: make(chan struct{}, 1),
isNotFC: isNotFC,
startOnce: sync.Once{},
mmdsOpts: &host.MMDSOpts{
SandboxID: "unknown",
TemplateID: "unknown",
LogsCollectorAddress: "",
},
}
go exporter.listenForMMDSOptsAndStart(ctx, mmdsChan)
return exporter
}
func (w *HTTPExporter) sendInstanceLogs(ctx context.Context, logs []byte, address string) error {
if address == "" {
return nil
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, address, bytes.NewBuffer(logs))
if err != nil {
return err
}
request.Header.Set("Content-Type", "application/json")
response, err := w.client.Do(request)
if err != nil {
return err
}
defer response.Body.Close()
return nil
}
func printLog(logs []byte) {
fmt.Fprintf(os.Stdout, "%v", string(logs))
}
func (w *HTTPExporter) listenForMMDSOptsAndStart(ctx context.Context, mmdsChan <-chan *host.MMDSOpts) {
for {
select {
case <-ctx.Done():
return
case mmdsOpts, ok := <-mmdsChan:
if !ok {
return
}
w.mmdsLock.Lock()
w.mmdsOpts.Update(mmdsOpts.SandboxID, mmdsOpts.TemplateID, mmdsOpts.LogsCollectorAddress)
w.mmdsLock.Unlock()
w.startOnce.Do(func() {
go w.start(ctx)
})
}
}
}
func (w *HTTPExporter) start(ctx context.Context) {
for range w.triggers {
logs := w.getAllLogs()
if len(logs) == 0 {
continue
}
if w.isNotFC {
for _, log := range logs {
fmt.Fprintf(os.Stdout, "%v", string(log))
}
continue
}
for _, logLine := range logs {
w.mmdsLock.RLock()
logLineWithOpts, err := w.mmdsOpts.AddOptsToJSON(logLine)
w.mmdsLock.RUnlock()
if err != nil {
log.Printf("error adding instance logging options (%+v) to JSON (%+v) with logs : %v\n", w.mmdsOpts, logLine, err)
printLog(logLine)
continue
}
err = w.sendInstanceLogs(ctx, logLineWithOpts, w.mmdsOpts.LogsCollectorAddress)
if err != nil {
log.Printf("error sending instance logs: %+v", err)
printLog(logLine)
continue
}
}
}
}
func (w *HTTPExporter) resumeProcessing() {
select {
case w.triggers <- struct{}{}:
default:
// Exporter processing already triggered
// This is expected behavior if the exporter is already processing logs
}
}
func (w *HTTPExporter) Write(logs []byte) (int, error) {
logsCopy := make([]byte, len(logs))
copy(logsCopy, logs)
go w.addLogs(logsCopy)
return len(logs), nil
}
func (w *HTTPExporter) getAllLogs() [][]byte {
w.logLock.Lock()
defer w.logLock.Unlock()
logs := w.logs
w.logs = nil
return logs
}
func (w *HTTPExporter) addLogs(logs []byte) {
w.logLock.Lock()
defer w.logLock.Unlock()
w.logs = append(w.logs, logs)
w.resumeProcessing()
}

View File

@ -0,0 +1,174 @@
// SPDX-License-Identifier: Apache-2.0
package logs
import (
"context"
"fmt"
"strconv"
"strings"
"sync/atomic"
"connectrpc.com/connect"
"github.com/rs/zerolog"
)
type OperationID string
const (
OperationIDKey OperationID = "operation_id"
DefaultHTTPMethod string = "POST"
)
var operationID = atomic.Int32{}
func AssignOperationID() string {
id := operationID.Add(1)
return strconv.Itoa(int(id))
}
func AddRequestIDToContext(ctx context.Context) context.Context {
return context.WithValue(ctx, OperationIDKey, AssignOperationID())
}
func formatMethod(method string) string {
parts := strings.Split(method, ".")
if len(parts) < 2 {
return method
}
split := strings.Split(parts[1], "/")
if len(split) < 2 {
return method
}
servicePart := split[0]
servicePart = strings.ToUpper(servicePart[:1]) + servicePart[1:]
methodPart := split[1]
methodPart = strings.ToLower(methodPart[:1]) + methodPart[1:]
return fmt.Sprintf("%s %s", servicePart, methodPart)
}
func NewUnaryLogInterceptor(logger *zerolog.Logger) connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(
ctx context.Context,
req connect.AnyRequest,
) (connect.AnyResponse, error) {
ctx = AddRequestIDToContext(ctx)
res, err := next(ctx, req)
l := logger.
Err(err).
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
if err != nil {
l = l.Int("error_code", int(connect.CodeOf(err)))
}
if req != nil {
l = l.Interface("request", req.Any())
}
if res != nil && err == nil {
l = l.Interface("response", res.Any())
}
if res == nil && err == nil {
l = l.Interface("response", nil)
}
l.Msg(formatMethod(req.Spec().Procedure))
return res, err
})
}
return connect.UnaryInterceptorFunc(interceptor)
}
func LogServerStreamWithoutEvents[T any, R any](
ctx context.Context,
logger *zerolog.Logger,
req *connect.Request[R],
stream *connect.ServerStream[T],
handler func(ctx context.Context, req *connect.Request[R], stream *connect.ServerStream[T]) error,
) error {
ctx = AddRequestIDToContext(ctx)
l := logger.Debug().
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
if req != nil {
l = l.Interface("request", req.Any())
}
l.Msg(fmt.Sprintf("%s (server stream start)", formatMethod(req.Spec().Procedure)))
err := handler(ctx, req, stream)
logEvent := getErrDebugLogEvent(logger, err).
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
if err != nil {
logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
} else {
logEvent = logEvent.Interface("response", nil)
}
logEvent.Msg(fmt.Sprintf("%s (server stream end)", formatMethod(req.Spec().Procedure)))
return err
}
func LogClientStreamWithoutEvents[T any, R any](
ctx context.Context,
logger *zerolog.Logger,
stream *connect.ClientStream[T],
handler func(ctx context.Context, stream *connect.ClientStream[T]) (*connect.Response[R], error),
) (*connect.Response[R], error) {
ctx = AddRequestIDToContext(ctx)
logger.Debug().
Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string)).
Msg(fmt.Sprintf("%s (client stream start)", formatMethod(stream.Spec().Procedure)))
res, err := handler(ctx, stream)
logEvent := getErrDebugLogEvent(logger, err).
Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
if err != nil {
logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
}
if res != nil && err == nil {
logEvent = logEvent.Interface("response", res.Any())
}
if res == nil && err == nil {
logEvent = logEvent.Interface("response", nil)
}
logEvent.Msg(fmt.Sprintf("%s (client stream end)", formatMethod(stream.Spec().Procedure)))
return res, err
}
// Return logger with error level if err is not nil, otherwise return logger with debug level
func getErrDebugLogEvent(logger *zerolog.Logger, err error) *zerolog.Event {
if err != nil {
return logger.Error().Err(err) //nolint:zerologlint // this builds an event, it is not expected to return it
}
return logger.Debug() //nolint:zerologlint // this builds an event, it is not expected to return it
}

View File

@ -0,0 +1,37 @@
// SPDX-License-Identifier: Apache-2.0
package logs
import (
"context"
"io"
"os"
"time"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs/exporter"
)
func NewLogger(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *zerolog.Logger {
zerolog.TimestampFieldName = "timestamp"
zerolog.TimeFieldFormat = time.RFC3339Nano
exporters := []io.Writer{}
if isNotFC {
exporters = append(exporters, os.Stdout)
} else {
exporters = append(exporters, exporter.NewHTTPLogsExporter(ctx, isNotFC, mmdsChan), os.Stdout)
}
l := zerolog.
New(io.MultiWriter(exporters...)).
With().
Timestamp().
Logger().
Level(zerolog.DebugLevel)
return &l
}

View File

@ -0,0 +1,49 @@
// SPDX-License-Identifier: Apache-2.0
package permissions
import (
"context"
"fmt"
"os/user"
"connectrpc.com/authn"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
)
func AuthenticateUsername(_ context.Context, req authn.Request) (any, error) {
username, _, ok := req.BasicAuth()
if !ok {
// When no username is provided, ignore the authentication method (not all endpoints require it)
// Missing user is then handled in the GetAuthUser function
return nil, nil
}
u, err := GetUser(username)
if err != nil {
return nil, authn.Errorf("invalid username: '%s'", username)
}
return u, nil
}
func GetAuthUser(ctx context.Context, defaultUser string) (*user.User, error) {
u, ok := authn.GetInfo(ctx).(*user.User)
if !ok {
username, err := execcontext.ResolveDefaultUsername(nil, defaultUser)
if err != nil {
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("no user specified"))
}
u, err := GetUser(username)
if err != nil {
return nil, authn.Errorf("invalid default user: '%s'", username)
}
return u, nil
}
return u, nil
}

View File

@ -0,0 +1,31 @@
// SPDX-License-Identifier: Apache-2.0
package permissions
import (
"strconv"
"time"
"connectrpc.com/connect"
)
const defaultKeepAliveInterval = 90 * time.Second
func GetKeepAliveTicker[T any](req *connect.Request[T]) (*time.Ticker, func()) {
keepAliveIntervalHeader := req.Header().Get("Keepalive-Ping-Interval")
var interval time.Duration
keepAliveIntervalInt, err := strconv.Atoi(keepAliveIntervalHeader)
if err != nil {
interval = defaultKeepAliveInterval
} else {
interval = time.Duration(keepAliveIntervalInt) * time.Second
}
ticker := time.NewTicker(interval)
return ticker, func() {
ticker.Reset(interval)
}
}

View File

@ -0,0 +1,98 @@
// SPDX-License-Identifier: Apache-2.0
package permissions
import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"slices"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
)
func expand(path, homedir string) (string, error) {
if len(path) == 0 {
return path, nil
}
if path[0] != '~' {
return path, nil
}
if len(path) > 1 && path[1] != '/' && path[1] != '\\' {
return "", errors.New("cannot expand user-specific home dir")
}
return filepath.Join(homedir, path[1:]), nil
}
func ExpandAndResolve(path string, user *user.User, defaultPath *string) (string, error) {
path = execcontext.ResolveDefaultWorkdir(path, defaultPath)
path, err := expand(path, user.HomeDir)
if err != nil {
return "", fmt.Errorf("failed to expand path '%s' for user '%s': %w", path, user.Username, err)
}
if filepath.IsAbs(path) {
return path, nil
}
// The filepath.Abs can correctly resolve paths like /home/user/../file
path = filepath.Join(user.HomeDir, path)
abs, err := filepath.Abs(path)
if err != nil {
return "", fmt.Errorf("failed to resolve path '%s' for user '%s' with home dir '%s': %w", path, user.Username, user.HomeDir, err)
}
return abs, nil
}
func getSubpaths(path string) (subpaths []string) {
for {
subpaths = append(subpaths, path)
path = filepath.Dir(path)
if path == "/" {
break
}
}
slices.Reverse(subpaths)
return subpaths
}
func EnsureDirs(path string, uid, gid int) error {
subpaths := getSubpaths(path)
for _, subpath := range subpaths {
info, err := os.Stat(subpath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to stat directory: %w", err)
}
if err != nil && os.IsNotExist(err) {
err = os.Mkdir(subpath, 0o755)
if err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
err = os.Chown(subpath, uid, gid)
if err != nil {
return fmt.Errorf("failed to chown directory: %w", err)
}
continue
}
if !info.IsDir() {
return fmt.Errorf("path is a file: %s", subpath)
}
}
return nil
}

View File

@ -0,0 +1,46 @@
// SPDX-License-Identifier: Apache-2.0
package permissions
import (
"fmt"
"os/user"
"strconv"
)
func GetUserIdUints(u *user.User) (uid, gid uint32, err error) {
newUID, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
}
newGID, err := strconv.ParseUint(u.Gid, 10, 32)
if err != nil {
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
}
return uint32(newUID), uint32(newGID), nil
}
func GetUserIdInts(u *user.User) (uid, gid int, err error) {
newUID, err := strconv.ParseInt(u.Uid, 10, strconv.IntSize)
if err != nil {
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
}
newGID, err := strconv.ParseInt(u.Gid, 10, strconv.IntSize)
if err != nil {
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
}
return int(newUID), int(newGID), nil
}
func GetUser(username string) (u *user.User, err error) {
u, err = user.Lookup(username)
if err != nil {
return nil, fmt.Errorf("error looking up user '%s': %w", username, err)
}
return u, nil
}

View File

@ -0,0 +1,220 @@
// SPDX-License-Identifier: Apache-2.0
// portf (port forward) periodaically scans opened TCP ports on the 127.0.0.1 (or localhost)
// and launches `socat` process for every such port in the background.
// socat forward traffic from `sourceIP`:port to the 127.0.0.1:port.
// WARNING: portf isn't thread safe!
package port
import (
"context"
"fmt"
"net"
"os/exec"
"syscall"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
)
type PortState string
const (
PortStateForward PortState = "FORWARD"
PortStateDelete PortState = "DELETE"
)
var defaultGatewayIP = net.IPv4(169, 254, 0, 21)
type PortToForward struct {
socat *exec.Cmd
// Process ID of the process that's listening on port.
pid int32
// family version of the ip.
family uint32
state PortState
port uint32
}
type Forwarder struct {
logger *zerolog.Logger
cgroupManager cgroups.Manager
// Map of ports that are being currently forwarded.
ports map[string]*PortToForward
scannerSubscriber *ScannerSubscriber
sourceIP net.IP
}
func NewForwarder(
logger *zerolog.Logger,
scanner *Scanner,
cgroupManager cgroups.Manager,
) *Forwarder {
scannerSub := scanner.AddSubscriber(
logger,
"port-forwarder",
// We only want to forward ports that are actively listening on localhost.
&ScannerFilter{
IPs: []string{"127.0.0.1", "localhost", "::1"},
State: "LISTEN",
},
)
return &Forwarder{
logger: logger,
sourceIP: defaultGatewayIP,
ports: make(map[string]*PortToForward),
scannerSubscriber: scannerSub,
cgroupManager: cgroupManager,
}
}
func (f *Forwarder) StartForwarding(ctx context.Context) {
if f.scannerSubscriber == nil {
f.logger.Error().Msg("Cannot start forwarding because scanner subscriber is nil")
return
}
for {
// procs is an array of currently opened ports.
if procs, ok := <-f.scannerSubscriber.Messages; ok {
// Now we are going to refresh all ports that are being forwarded in the `ports` map. Maybe add new ones
// and maybe remove some.
// Go through the ports that are currently being forwarded and set all of them
// to the `DELETE` state. We don't know yet if they will be there after refresh.
for _, v := range f.ports {
v.state = PortStateDelete
}
// Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state.
// This will make sure we won't delete them later.
for _, p := range procs {
key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port)
// We check if the opened port is in our map of forwarded ports.
val, portOk := f.ports[key]
if portOk {
// Just mark the port as being forwarded so we don't delete it.
// The actual socat process that handles forwarding should be running from the last iteration.
val.state = PortStateForward
} else {
f.logger.Debug().
Str("ip", p.Laddr.IP).
Uint32("port", p.Laddr.Port).
Uint32("family", familyToIPVersion(p.Family)).
Str("state", p.Status).
Msg("Detected new opened port on localhost that is not forwarded")
// The opened port wasn't in the map so we create a new PortToForward and start forwarding.
ptf := &PortToForward{
pid: p.Pid,
port: p.Laddr.Port,
state: PortStateForward,
family: familyToIPVersion(p.Family),
}
f.ports[key] = ptf
f.startPortForwarding(ctx, ptf)
}
}
// We go through the ports map one more time and stop forwarding all ports
// that stayed marked as "DELETE".
for _, v := range f.ports {
if v.state == PortStateDelete {
f.stopPortForwarding(v)
}
}
}
}
}
func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) {
// https://unix.stackexchange.com/questions/311492/redirect-application-listening-on-localhost-to-listening-on-external-interface
// socat -d -d TCP4-LISTEN:4000,bind=169.254.0.21,fork TCP4:localhost:4000
// reuseaddr is used to fix the "Address already in use" error when restarting socat quickly.
cmd := exec.CommandContext(ctx,
"socat", "-d", "-d", "-d",
fmt.Sprintf("TCP4-LISTEN:%v,bind=%s,reuseaddr,fork", p.port, f.sourceIP.To4()),
fmt.Sprintf("TCP%d:localhost:%v", p.family, p.port),
)
cgroupFD, ok := f.cgroupManager.GetFileDescriptor(cgroups.ProcessTypeSocat)
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
CgroupFD: cgroupFD,
UseCgroupFD: ok,
}
f.logger.Debug().
Str("socatCmd", cmd.String()).
Int32("pid", p.pid).
Uint32("family", p.family).
IPAddr("sourceIP", f.sourceIP.To4()).
Uint32("port", p.port).
Msg("About to start port forwarding")
if err := cmd.Start(); err != nil {
f.logger.
Error().
Str("socatCmd", cmd.String()).
Err(err).
Msg("Failed to start port forwarding - failed to start socat")
return
}
go func() {
if err := cmd.Wait(); err != nil {
f.logger.
Debug().
Str("socatCmd", cmd.String()).
Err(err).
Msg("Port forwarding socat process exited")
}
}()
p.socat = cmd
}
func (f *Forwarder) stopPortForwarding(p *PortToForward) {
if p.socat == nil {
return
}
defer func() { p.socat = nil }()
logger := f.logger.With().
Str("socatCmd", p.socat.String()).
Int32("pid", p.pid).
Uint32("family", p.family).
IPAddr("sourceIP", f.sourceIP.To4()).
Uint32("port", p.port).
Logger()
logger.Debug().Msg("Stopping port forwarding")
if err := syscall.Kill(-p.socat.Process.Pid, syscall.SIGKILL); err != nil {
logger.Error().Err(err).Msg("Failed to kill process group")
return
}
logger.Debug().Msg("Stopped port forwarding")
}
func familyToIPVersion(family uint32) uint32 {
switch family {
case syscall.AF_INET:
return 4
case syscall.AF_INET6:
return 6
default:
return 0 // Unknown or unsupported family
}
}

View File

@ -0,0 +1,61 @@
// SPDX-License-Identifier: Apache-2.0
package port
import (
"time"
"github.com/rs/zerolog"
"github.com/shirou/gopsutil/v4/net"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap"
)
type Scanner struct {
Processes chan net.ConnectionStat
scanExit chan struct{}
subs *smap.Map[*ScannerSubscriber]
period time.Duration
}
func (s *Scanner) Destroy() {
close(s.scanExit)
}
func NewScanner(period time.Duration) *Scanner {
return &Scanner{
period: period,
subs: smap.New[*ScannerSubscriber](),
scanExit: make(chan struct{}),
Processes: make(chan net.ConnectionStat),
}
}
func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
subscriber := NewScannerSubscriber(logger, id, filter)
s.subs.Insert(id, subscriber)
return subscriber
}
func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) {
s.subs.Remove(sub.ID())
sub.Destroy()
}
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
func (s *Scanner) ScanAndBroadcast() {
for {
// tcp monitors both ipv4 and ipv6 connections.
processes, _ := net.Connections("tcp")
for _, sub := range s.subs.Items() {
sub.Signal(processes)
}
select {
case <-s.scanExit:
return
default:
time.Sleep(s.period)
}
}
}

View File

@ -0,0 +1,52 @@
// SPDX-License-Identifier: Apache-2.0
package port
import (
"github.com/rs/zerolog"
"github.com/shirou/gopsutil/v4/net"
)
// If we want to create a listener/subscriber pattern somewhere else we should move
// from a concrete implementation to combination of generics and interfaces.
type ScannerSubscriber struct {
logger *zerolog.Logger
filter *ScannerFilter
Messages chan ([]net.ConnectionStat)
id string
}
func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
return &ScannerSubscriber{
logger: logger,
id: id,
filter: filter,
Messages: make(chan []net.ConnectionStat),
}
}
func (ss *ScannerSubscriber) ID() string {
return ss.id
}
func (ss *ScannerSubscriber) Destroy() {
close(ss.Messages)
}
func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) {
// Filter isn't specified. Accept everything.
if ss.filter == nil {
ss.Messages <- proc
} else {
filtered := []net.ConnectionStat{}
for i := range proc {
// We need to access the list directly otherwise there will be implicit memory aliasing
// If the filter matched a process, we will send it to a channel.
if ss.filter.Match(&proc[i]) {
filtered = append(filtered, proc[i])
}
}
ss.Messages <- filtered
}
}

View File

@ -0,0 +1,29 @@
// SPDX-License-Identifier: Apache-2.0
package port
import (
"slices"
"github.com/shirou/gopsutil/v4/net"
)
type ScannerFilter struct {
State string
IPs []string
}
func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool {
// Filter is an empty struct.
if sf.State == "" && len(sf.IPs) == 0 {
return false
}
ipMatch := slices.Contains(sf.IPs, proc.Laddr.IP)
if ipMatch && sf.State == proc.Status {
return true
}
return false
}

View File

@ -0,0 +1,129 @@
// SPDX-License-Identifier: Apache-2.0
package cgroups
import (
"errors"
"fmt"
"os"
"path/filepath"
"golang.org/x/sys/unix"
)
type Cgroup2Manager struct {
cgroupFDs map[ProcessType]int
}
var _ Manager = (*Cgroup2Manager)(nil)
type cgroup2Config struct {
rootPath string
processTypes map[ProcessType]Cgroup2Config
}
type Cgroup2ManagerOption func(*cgroup2Config)
func WithCgroup2RootSysFSPath(path string) Cgroup2ManagerOption {
return func(config *cgroup2Config) {
config.rootPath = path
}
}
func WithCgroup2ProcessType(processType ProcessType, path string, properties map[string]string) Cgroup2ManagerOption {
return func(config *cgroup2Config) {
if config.processTypes == nil {
config.processTypes = make(map[ProcessType]Cgroup2Config)
}
config.processTypes[processType] = Cgroup2Config{Path: path, Properties: properties}
}
}
type Cgroup2Config struct {
Path string
Properties map[string]string
}
func NewCgroup2Manager(opts ...Cgroup2ManagerOption) (*Cgroup2Manager, error) {
config := cgroup2Config{
rootPath: "/sys/fs/cgroup",
}
for _, opt := range opts {
opt(&config)
}
cgroupFDs, err := createCgroups(config)
if err != nil {
return nil, fmt.Errorf("failed to create cgroups: %w", err)
}
return &Cgroup2Manager{cgroupFDs: cgroupFDs}, nil
}
func createCgroups(configs cgroup2Config) (map[ProcessType]int, error) {
var (
results = make(map[ProcessType]int)
errs []error
)
for procType, config := range configs.processTypes {
fullPath := filepath.Join(configs.rootPath, config.Path)
fd, err := createCgroup(fullPath, config.Properties)
if err != nil {
errs = append(errs, fmt.Errorf("failed to create %s cgroup: %w", procType, err))
continue
}
results[procType] = fd
}
if len(errs) > 0 {
for procType, fd := range results {
err := unix.Close(fd)
if err != nil {
errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, err))
}
}
return nil, errors.Join(errs...)
}
return results, nil
}
func createCgroup(fullPath string, properties map[string]string) (int, error) {
if err := os.MkdirAll(fullPath, 0o755); err != nil {
return -1, fmt.Errorf("failed to create cgroup root: %w", err)
}
var errs []error
for name, value := range properties {
if err := os.WriteFile(filepath.Join(fullPath, name), []byte(value), 0o644); err != nil {
errs = append(errs, fmt.Errorf("failed to write cgroup property: %w", err))
}
}
if len(errs) > 0 {
return -1, errors.Join(errs...)
}
return unix.Open(fullPath, unix.O_RDONLY, 0)
}
func (c Cgroup2Manager) GetFileDescriptor(procType ProcessType) (int, bool) {
fd, ok := c.cgroupFDs[procType]
return fd, ok
}
func (c Cgroup2Manager) Close() error {
var errs []error
for procType, fd := range c.cgroupFDs {
if err := unix.Close(fd); err != nil {
errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, err))
}
delete(c.cgroupFDs, procType)
}
return errors.Join(errs...)
}

View File

@ -0,0 +1,187 @@
// SPDX-License-Identifier: Apache-2.0
package cgroups
import (
"context"
"fmt"
"math/rand"
"os"
"os/exec"
"strconv"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
oneByte = 1
kilobyte = 1024 * oneByte
megabyte = 1024 * kilobyte
)
func TestCgroupRoundTrip(t *testing.T) {
t.Parallel()
if os.Geteuid() != 0 {
t.Skip("must run as root")
return
}
maxTimeout := time.Second * 5
t.Run("process does not die without cgroups", func(t *testing.T) {
t.Parallel()
// create manager
m, err := NewCgroup2Manager()
require.NoError(t, err)
// create new child process
cmd := startProcess(t, m, "not-a-real-one")
// wait for child process to die
err = waitForProcess(t, cmd, maxTimeout)
require.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("process dies with cgroups", func(t *testing.T) {
t.Parallel()
cgroupPath := createCgroupPath(t, "real-one")
// create manager
m, err := NewCgroup2Manager(
WithCgroup2ProcessType(ProcessTypePTY, cgroupPath, map[string]string{
"memory.max": strconv.Itoa(1 * megabyte),
}),
)
require.NoError(t, err)
t.Cleanup(func() {
err := m.Close()
assert.NoError(t, err)
})
// create new child process
cmd := startProcess(t, m, ProcessTypePTY)
// wait for child process to die
err = waitForProcess(t, cmd, maxTimeout)
// verify process exited correctly
var exitErr *exec.ExitError
require.ErrorAs(t, err, &exitErr)
assert.Equal(t, "signal: killed", exitErr.Error())
assert.False(t, exitErr.Exited())
assert.False(t, exitErr.Success())
assert.Equal(t, -1, exitErr.ExitCode())
// dig a little deeper
ws, ok := exitErr.Sys().(syscall.WaitStatus)
require.True(t, ok)
assert.Equal(t, syscall.SIGKILL, ws.Signal())
assert.True(t, ws.Signaled())
assert.False(t, ws.Stopped())
assert.False(t, ws.Continued())
assert.False(t, ws.CoreDump())
assert.False(t, ws.Exited())
assert.Equal(t, -1, ws.ExitStatus())
})
t.Run("process cannot be spawned because memory limit is too low", func(t *testing.T) {
t.Parallel()
cgroupPath := createCgroupPath(t, "real-one")
// create manager
m, err := NewCgroup2Manager(
WithCgroup2ProcessType(ProcessTypeSocat, cgroupPath, map[string]string{
"memory.max": strconv.Itoa(1 * kilobyte),
}),
)
require.NoError(t, err)
t.Cleanup(func() {
err := m.Close()
assert.NoError(t, err)
})
// create new child process
cmd := startProcess(t, m, ProcessTypeSocat)
// wait for child process to die
err = waitForProcess(t, cmd, maxTimeout)
// verify process exited correctly
var exitErr *exec.ExitError
require.ErrorAs(t, err, &exitErr)
assert.Equal(t, "exit status 253", exitErr.Error())
assert.True(t, exitErr.Exited())
assert.False(t, exitErr.Success())
assert.Equal(t, 253, exitErr.ExitCode())
// dig a little deeper
ws, ok := exitErr.Sys().(syscall.WaitStatus)
require.True(t, ok)
assert.Equal(t, syscall.Signal(-1), ws.Signal())
assert.False(t, ws.Signaled())
assert.False(t, ws.Stopped())
assert.False(t, ws.Continued())
assert.False(t, ws.CoreDump())
assert.True(t, ws.Exited())
assert.Equal(t, 253, ws.ExitStatus())
})
}
func createCgroupPath(t *testing.T, s string) string {
t.Helper()
randPart := rand.Int()
return fmt.Sprintf("envd-test-%s-%d", s, randPart)
}
func startProcess(t *testing.T, m *Cgroup2Manager, pt ProcessType) *exec.Cmd {
t.Helper()
cmdName, args := "bash", []string{"-c", `sleep 1 && tail /dev/zero`}
cmd := exec.CommandContext(t.Context(), cmdName, args...)
fd, ok := m.GetFileDescriptor(pt)
cmd.SysProcAttr = &syscall.SysProcAttr{
UseCgroupFD: ok,
CgroupFD: fd,
}
err := cmd.Start()
require.NoError(t, err)
return cmd
}
func waitForProcess(t *testing.T, cmd *exec.Cmd, timeout time.Duration) error {
t.Helper()
done := make(chan error, 1)
go func() {
defer close(done)
done <- cmd.Wait()
}()
ctx, cancel := context.WithTimeout(t.Context(), timeout)
t.Cleanup(cancel)
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
}
}

View File

@ -0,0 +1,16 @@
// SPDX-License-Identifier: Apache-2.0
package cgroups
type ProcessType string
const (
ProcessTypePTY ProcessType = "pty"
ProcessTypeUser ProcessType = "user"
ProcessTypeSocat ProcessType = "socat"
)
type Manager interface {
GetFileDescriptor(procType ProcessType) (int, bool)
Close() error
}

View File

@ -0,0 +1,19 @@
// SPDX-License-Identifier: Apache-2.0
package cgroups
type NoopManager struct{}
var _ Manager = (*NoopManager)(nil)
func NewNoopManager() *NoopManager {
return &NoopManager{}
}
func (n NoopManager) GetFileDescriptor(ProcessType) (int, bool) {
return 0, false
}
func (n NoopManager) Close() error {
return nil
}

View File

@ -0,0 +1,186 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
func (s Service) ListDir(ctx context.Context, req *connect.Request[rpc.ListDirRequest]) (*connect.Response[rpc.ListDirResponse], error) {
depth := req.Msg.GetDepth()
if depth == 0 {
depth = 1 // default depth to current directory
}
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
if err != nil {
return nil, err
}
requestedPath := req.Msg.GetPath()
// Expand the path so we can return absolute paths in the response.
requestedPath, err = permissions.ExpandAndResolve(requestedPath, u, s.defaults.Workdir)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
resolvedPath, err := followSymlink(requestedPath)
if err != nil {
return nil, err
}
err = checkIfDirectory(resolvedPath)
if err != nil {
return nil, err
}
entries, err := walkDir(requestedPath, resolvedPath, int(depth))
if err != nil {
return nil, err
}
return connect.NewResponse(&rpc.ListDirResponse{
Entries: entries,
}), nil
}
func (s Service) MakeDir(ctx context.Context, req *connect.Request[rpc.MakeDirRequest]) (*connect.Response[rpc.MakeDirResponse], error) {
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
if err != nil {
return nil, err
}
dirPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
stat, err := os.Stat(dirPath)
if err != nil && !os.IsNotExist(err) {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
}
if err == nil {
if stat.IsDir() {
return nil, connect.NewError(connect.CodeAlreadyExists, fmt.Errorf("directory already exists: %s", dirPath))
}
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path already exists but it is not a directory: %s", dirPath))
}
uid, gid, userErr := permissions.GetUserIdInts(u)
if userErr != nil {
return nil, connect.NewError(connect.CodeInternal, userErr)
}
userErr = permissions.EnsureDirs(dirPath, uid, gid)
if userErr != nil {
return nil, connect.NewError(connect.CodeInternal, userErr)
}
entry, err := entryInfo(dirPath)
if err != nil {
return nil, err
}
return connect.NewResponse(&rpc.MakeDirResponse{
Entry: entry,
}), nil
}
// followSymlink resolves a symbolic link to its target path.
func followSymlink(path string) (string, error) {
// Resolve symlinks
resolvedPath, err := filepath.EvalSymlinks(path)
if err != nil {
if os.IsNotExist(err) {
return "", connect.NewError(connect.CodeNotFound, fmt.Errorf("path not found: %w", err))
}
if strings.Contains(err.Error(), "too many links") {
return "", connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("cyclic symlink or chain >255 links at %q", path))
}
return "", connect.NewError(connect.CodeInternal, fmt.Errorf("error resolving symlink: %w", err))
}
return resolvedPath, nil
}
// checkIfDirectory checks if the given path is a directory.
func checkIfDirectory(path string) error {
stat, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return connect.NewError(connect.CodeNotFound, fmt.Errorf("directory not found: %w", err))
}
return connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
}
if !stat.IsDir() {
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path is not a directory: %s", path))
}
return nil
}
// walkDir walks the directory tree starting from dirPath up to the specified depth (doesn't follow symlinks).
func walkDir(requestedPath string, dirPath string, depth int) (entries []*rpc.EntryInfo, err error) {
err = filepath.WalkDir(dirPath, func(path string, _ os.DirEntry, err error) error {
if err != nil {
return err
}
// Skip the root directory itself
if path == dirPath {
return nil
}
// Calculate current depth
relPath, err := filepath.Rel(dirPath, path)
if err != nil {
return err
}
currentDepth := len(strings.Split(relPath, string(os.PathSeparator)))
if currentDepth > depth {
return filepath.SkipDir
}
entryInfo, err := entryInfo(path)
if err != nil {
var connectErr *connect.Error
if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeNotFound {
// Skip entries that don't exist anymore
return nil
}
return err
}
// Return the requested path as the base path instead of the symlink-resolved path
path = filepath.Join(requestedPath, relPath)
entryInfo.Path = path
entries = append(entries, entryInfo)
return nil
})
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error reading directory %s: %w", dirPath, err))
}
return entries, nil
}

View File

@ -0,0 +1,407 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"context"
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"testing"
"connectrpc.com/authn"
"connectrpc.com/connect"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
func TestListDir(t *testing.T) {
t.Parallel()
// Setup temp root and user
root := t.TempDir()
u, err := user.Current()
require.NoError(t, err)
// Setup directory structure
testFolder := filepath.Join(root, "test")
require.NoError(t, os.MkdirAll(filepath.Join(testFolder, "test-dir", "sub-dir-1"), 0o755))
require.NoError(t, os.MkdirAll(filepath.Join(testFolder, "test-dir", "sub-dir-2"), 0o755))
filePath := filepath.Join(testFolder, "test-dir", "sub-dir-1", "file.txt")
require.NoError(t, os.WriteFile(filePath, []byte("Hello, World!"), 0o644))
// Service instance
svc := mockService()
// Helper to inject user into context
injectUser := func(ctx context.Context, u *user.User) context.Context {
return authn.SetInfo(ctx, u)
}
tests := []struct {
name string
depth uint32
expectedPaths []string
}{
{
name: "depth 0 lists only root directory",
depth: 0,
expectedPaths: []string{
filepath.Join(testFolder, "test-dir"),
},
},
{
name: "depth 1 lists root directory",
depth: 1,
expectedPaths: []string{
filepath.Join(testFolder, "test-dir"),
},
},
{
name: "depth 2 lists first level of subdirectories (in this case the root directory)",
depth: 2,
expectedPaths: []string{
filepath.Join(testFolder, "test-dir"),
filepath.Join(testFolder, "test-dir", "sub-dir-1"),
filepath.Join(testFolder, "test-dir", "sub-dir-2"),
},
},
{
name: "depth 3 lists all directories and files",
depth: 3,
expectedPaths: []string{
filepath.Join(testFolder, "test-dir"),
filepath.Join(testFolder, "test-dir", "sub-dir-1"),
filepath.Join(testFolder, "test-dir", "sub-dir-2"),
filepath.Join(testFolder, "test-dir", "sub-dir-1", "file.txt"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := injectUser(t.Context(), u)
req := connect.NewRequest(&filesystem.ListDirRequest{
Path: testFolder,
Depth: tt.depth,
})
resp, err := svc.ListDir(ctx, req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Msg)
assert.Len(t, resp.Msg.GetEntries(), len(tt.expectedPaths))
actualPaths := make([]string, len(resp.Msg.GetEntries()))
for i, entry := range resp.Msg.GetEntries() {
actualPaths[i] = entry.GetPath()
}
assert.ElementsMatch(t, tt.expectedPaths, actualPaths)
})
}
}
func TestListDirNonExistingPath(t *testing.T) {
t.Parallel()
svc := mockService()
u, err := user.Current()
require.NoError(t, err)
ctx := authn.SetInfo(t.Context(), u)
req := connect.NewRequest(&filesystem.ListDirRequest{
Path: "/non-existing-path",
Depth: 1,
})
_, err = svc.ListDir(ctx, req)
require.Error(t, err)
var connectErr *connect.Error
ok := errors.As(err, &connectErr)
assert.True(t, ok, "expected error to be of type *connect.Error")
assert.Equal(t, connect.CodeNotFound, connectErr.Code())
}
func TestListDirRelativePath(t *testing.T) {
t.Parallel()
// Setup temp root and user
u, err := user.Current()
require.NoError(t, err)
// Setup directory structure
testRelativePath := fmt.Sprintf("test-%s", uuid.New())
testFolderPath := filepath.Join(u.HomeDir, testRelativePath)
filePath := filepath.Join(testFolderPath, "file.txt")
require.NoError(t, os.MkdirAll(testFolderPath, 0o755))
require.NoError(t, os.WriteFile(filePath, []byte("Hello, World!"), 0o644))
// Service instance
svc := mockService()
ctx := authn.SetInfo(t.Context(), u)
req := connect.NewRequest(&filesystem.ListDirRequest{
Path: testRelativePath,
Depth: 1,
})
resp, err := svc.ListDir(ctx, req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Msg)
expectedPaths := []string{
filepath.Join(testFolderPath, "file.txt"),
}
assert.Len(t, resp.Msg.GetEntries(), len(expectedPaths))
actualPaths := make([]string, len(resp.Msg.GetEntries()))
for i, entry := range resp.Msg.GetEntries() {
actualPaths[i] = entry.GetPath()
}
assert.ElementsMatch(t, expectedPaths, actualPaths)
}
func TestListDir_Symlinks(t *testing.T) {
t.Parallel()
root := t.TempDir()
u, err := user.Current()
require.NoError(t, err)
ctx := authn.SetInfo(t.Context(), u)
symlinkRoot := filepath.Join(root, "test-symlinks")
require.NoError(t, os.MkdirAll(symlinkRoot, 0o755))
// 1. Prepare a real directory + file that a symlink will point to
realDir := filepath.Join(symlinkRoot, "real-dir")
require.NoError(t, os.MkdirAll(realDir, 0o755))
filePath := filepath.Join(realDir, "file.txt")
require.NoError(t, os.WriteFile(filePath, []byte("hello via symlink"), 0o644))
// 2. Prepare a standalone real file (points-to-file scenario)
realFile := filepath.Join(symlinkRoot, "real-file.txt")
require.NoError(t, os.WriteFile(realFile, []byte("i am a plain file"), 0o644))
// 3. Create the three symlinks
linkToDir := filepath.Join(symlinkRoot, "link-dir") // → directory
linkToFile := filepath.Join(symlinkRoot, "link-file") // → file
cyclicLink := filepath.Join(symlinkRoot, "cyclic") // → itself
require.NoError(t, os.Symlink(realDir, linkToDir))
require.NoError(t, os.Symlink(realFile, linkToFile))
require.NoError(t, os.Symlink(cyclicLink, cyclicLink))
svc := mockService()
t.Run("symlink to directory behaves like directory and the content looks like inside the directory", func(t *testing.T) {
t.Parallel()
req := connect.NewRequest(&filesystem.ListDirRequest{
Path: linkToDir,
Depth: 1,
})
resp, err := svc.ListDir(ctx, req)
require.NoError(t, err)
expected := []string{
filepath.Join(linkToDir, "file.txt"),
}
actual := make([]string, len(resp.Msg.GetEntries()))
for i, e := range resp.Msg.GetEntries() {
actual[i] = e.GetPath()
}
assert.ElementsMatch(t, expected, actual)
})
t.Run("link to file", func(t *testing.T) {
t.Parallel()
req := connect.NewRequest(&filesystem.ListDirRequest{
Path: linkToFile,
Depth: 1,
})
_, err := svc.ListDir(ctx, req)
require.Error(t, err)
assert.Contains(t, err.Error(), "not a directory")
})
t.Run("cyclic symlink surfaces 'too many links' → invalid-argument", func(t *testing.T) {
t.Parallel()
req := connect.NewRequest(&filesystem.ListDirRequest{
Path: cyclicLink,
})
_, err := svc.ListDir(ctx, req)
require.Error(t, err)
var connectErr *connect.Error
ok := errors.As(err, &connectErr)
assert.True(t, ok, "expected error to be of type *connect.Error")
assert.Equal(t, connect.CodeFailedPrecondition, connectErr.Code())
assert.Contains(t, connectErr.Error(), "cyclic symlink")
})
t.Run("symlink not resolved if not root", func(t *testing.T) {
t.Parallel()
req := connect.NewRequest(&filesystem.ListDirRequest{
Path: symlinkRoot,
Depth: 3,
})
res, err := svc.ListDir(ctx, req)
require.NoError(t, err)
expected := []string{
filepath.Join(symlinkRoot, "cyclic"),
filepath.Join(symlinkRoot, "link-dir"),
filepath.Join(symlinkRoot, "link-file"),
filepath.Join(symlinkRoot, "real-dir"),
filepath.Join(symlinkRoot, "real-dir", "file.txt"),
filepath.Join(symlinkRoot, "real-file.txt"),
}
actual := make([]string, len(res.Msg.GetEntries()))
for i, e := range res.Msg.GetEntries() {
actual[i] = e.GetPath()
}
assert.ElementsMatch(t, expected, actual, "symlinks should not be resolved when listing the symlink root directory")
})
}
// TestFollowSymlink_Success makes sure that followSymlink resolves symlinks,
// while also being robust to the /var → /private/var indirection that exists on macOS.
func TestFollowSymlink_Success(t *testing.T) {
t.Parallel()
// Base temporary directory. On macOS this lives under /var/folders/…
// which itself is a symlink to /private/var/folders/….
base := t.TempDir()
// Create a real directory that we ultimately want to resolve to.
target := filepath.Join(base, "target")
require.NoError(t, os.MkdirAll(target, 0o755))
// Create a symlink pointing at the real directory so we can verify that
// followSymlink follows it.
link := filepath.Join(base, "link")
require.NoError(t, os.Symlink(target, link))
got, err := followSymlink(link)
require.NoError(t, err)
// Canonicalise the expected path too, so that /var → /private/var (macOS)
// or any other benign symlink indirections dont cause flaky tests.
want, err := filepath.EvalSymlinks(link)
require.NoError(t, err)
require.Equal(t, want, got, "followSymlink should resolve and canonicalise symlinks")
}
// TestFollowSymlink_MultiSymlinkChain verifies that followSymlink follows a chain
// of several symlinks (noncyclic) correctly.
func TestFollowSymlink_MultiSymlinkChain(t *testing.T) {
t.Parallel()
base := t.TempDir()
// Final destination directory.
target := filepath.Join(base, "target")
require.NoError(t, os.MkdirAll(target, 0o755))
// Build a 3link chain: link1 → link2 → link3 → target.
link3 := filepath.Join(base, "link3")
require.NoError(t, os.Symlink(target, link3))
link2 := filepath.Join(base, "link2")
require.NoError(t, os.Symlink(link3, link2))
link1 := filepath.Join(base, "link1")
require.NoError(t, os.Symlink(link2, link1))
got, err := followSymlink(link1)
require.NoError(t, err)
want, err := filepath.EvalSymlinks(link1)
require.NoError(t, err)
require.Equal(t, want, got, "followSymlink should resolve an arbitrary symlink chain")
}
func TestFollowSymlink_NotFound(t *testing.T) {
t.Parallel()
_, err := followSymlink("/definitely/does/not/exist")
require.Error(t, err)
var cerr *connect.Error
require.ErrorAs(t, err, &cerr)
require.Equal(t, connect.CodeNotFound, cerr.Code())
}
func TestFollowSymlink_CyclicSymlink(t *testing.T) {
t.Parallel()
dir := t.TempDir()
a := filepath.Join(dir, "a")
b := filepath.Join(dir, "b")
require.NoError(t, os.MkdirAll(a, 0o755))
require.NoError(t, os.MkdirAll(b, 0o755))
// Create a twonode loop: a/loop → b/loop, b/loop → a/loop.
require.NoError(t, os.Symlink(filepath.Join(b, "loop"), filepath.Join(a, "loop")))
require.NoError(t, os.Symlink(filepath.Join(a, "loop"), filepath.Join(b, "loop")))
_, err := followSymlink(filepath.Join(a, "loop"))
require.Error(t, err)
var cerr *connect.Error
require.ErrorAs(t, err, &cerr)
require.Equal(t, connect.CodeFailedPrecondition, cerr.Code())
require.Contains(t, cerr.Message(), "cyclic")
}
func TestCheckIfDirectory(t *testing.T) {
t.Parallel()
dir := t.TempDir()
require.NoError(t, checkIfDirectory(dir))
file := filepath.Join(dir, "file.txt")
require.NoError(t, os.WriteFile(file, []byte("hello"), 0o644))
err := checkIfDirectory(file)
require.Error(t, err)
var cerr *connect.Error
require.ErrorAs(t, err, &cerr)
require.Equal(t, connect.CodeInvalidArgument, cerr.Code())
}
func TestWalkDir_Depth(t *testing.T) {
t.Parallel()
root := t.TempDir()
sub := filepath.Join(root, "sub")
subsub := filepath.Join(sub, "subsub")
require.NoError(t, os.MkdirAll(subsub, 0o755))
entries, err := walkDir(root, root, 1)
require.NoError(t, err)
// Collect the names for easier assertions.
names := make([]string, 0, len(entries))
for _, e := range entries {
names = append(names, e.GetName())
}
require.Contains(t, names, "sub")
require.NotContains(t, names, "subsub", "entries beyond depth should be excluded")
}
func TestWalkDir_Error(t *testing.T) {
t.Parallel()
_, err := walkDir("/does/not/exist", "/does/not/exist", 1)
require.Error(t, err)
var cerr *connect.Error
require.ErrorAs(t, err, &cerr)
require.Equal(t, connect.CodeInternal, cerr.Code())
}

View File

@ -0,0 +1,60 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"context"
"fmt"
"os"
"path/filepath"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
func (s Service) Move(ctx context.Context, req *connect.Request[rpc.MoveRequest]) (*connect.Response[rpc.MoveResponse], error) {
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
if err != nil {
return nil, err
}
source, err := permissions.ExpandAndResolve(req.Msg.GetSource(), u, s.defaults.Workdir)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
destination, err := permissions.ExpandAndResolve(req.Msg.GetDestination(), u, s.defaults.Workdir)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
uid, gid, userErr := permissions.GetUserIdInts(u)
if userErr != nil {
return nil, connect.NewError(connect.CodeInternal, userErr)
}
userErr = permissions.EnsureDirs(filepath.Dir(destination), uid, gid)
if userErr != nil {
return nil, connect.NewError(connect.CodeInternal, userErr)
}
err = os.Rename(source, destination)
if err != nil {
if os.IsNotExist(err) {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("source file not found: %w", err))
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error renaming: %w", err))
}
entry, err := entryInfo(destination)
if err != nil {
return nil, err
}
return connect.NewResponse(&rpc.MoveResponse{
Entry: entry,
}), nil
}

View File

@ -0,0 +1,366 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"testing"
"connectrpc.com/authn"
"connectrpc.com/connect"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
func TestMove(t *testing.T) {
t.Parallel()
// Setup temp root and user
root := t.TempDir()
u, err := user.Current()
require.NoError(t, err)
// Setup source and destination directories
sourceDir := filepath.Join(root, "source")
destDir := filepath.Join(root, "destination")
require.NoError(t, os.MkdirAll(sourceDir, 0o755))
require.NoError(t, os.MkdirAll(destDir, 0o755))
// Create a test file to move
sourceFile := filepath.Join(sourceDir, "test-file.txt")
testContent := []byte("Hello, World!")
require.NoError(t, os.WriteFile(sourceFile, testContent, 0o644))
// Destination file path
destFile := filepath.Join(destDir, "test-file.txt")
// Service instance
svc := mockService()
// Call the Move function
ctx := authn.SetInfo(t.Context(), u)
req := connect.NewRequest(&filesystem.MoveRequest{
Source: sourceFile,
Destination: destFile,
})
resp, err := svc.Move(ctx, req)
// Verify the move was successful
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, destFile, resp.Msg.GetEntry().GetPath())
// Verify the file exists at the destination
_, err = os.Stat(destFile)
require.NoError(t, err)
// Verify the file no longer exists at the source
_, err = os.Stat(sourceFile)
assert.True(t, os.IsNotExist(err))
// Verify the content of the moved file
content, err := os.ReadFile(destFile)
require.NoError(t, err)
assert.Equal(t, testContent, content)
}
func TestMoveDirectory(t *testing.T) {
t.Parallel()
// Setup temp root and user
root := t.TempDir()
u, err := user.Current()
require.NoError(t, err)
// Setup source and destination directories
sourceParent := filepath.Join(root, "source-parent")
destParent := filepath.Join(root, "dest-parent")
require.NoError(t, os.MkdirAll(sourceParent, 0o755))
require.NoError(t, os.MkdirAll(destParent, 0o755))
// Create a test directory with files to move
sourceDir := filepath.Join(sourceParent, "test-dir")
require.NoError(t, os.MkdirAll(filepath.Join(sourceDir, "subdir"), 0o755))
// Create some files in the directory
file1 := filepath.Join(sourceDir, "file1.txt")
file2 := filepath.Join(sourceDir, "subdir", "file2.txt")
require.NoError(t, os.WriteFile(file1, []byte("File 1 content"), 0o644))
require.NoError(t, os.WriteFile(file2, []byte("File 2 content"), 0o644))
// Destination directory path
destDir := filepath.Join(destParent, "test-dir")
// Service instance
svc := mockService()
// Call the Move function
ctx := authn.SetInfo(t.Context(), u)
req := connect.NewRequest(&filesystem.MoveRequest{
Source: sourceDir,
Destination: destDir,
})
resp, err := svc.Move(ctx, req)
// Verify the move was successful
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, destDir, resp.Msg.GetEntry().GetPath())
// Verify the directory exists at the destination
_, err = os.Stat(destDir)
require.NoError(t, err)
// Verify the files exist at the destination
destFile1 := filepath.Join(destDir, "file1.txt")
destFile2 := filepath.Join(destDir, "subdir", "file2.txt")
_, err = os.Stat(destFile1)
require.NoError(t, err)
_, err = os.Stat(destFile2)
require.NoError(t, err)
// Verify the directory no longer exists at the source
_, err = os.Stat(sourceDir)
assert.True(t, os.IsNotExist(err))
// Verify the content of the moved files
content1, err := os.ReadFile(destFile1)
require.NoError(t, err)
assert.Equal(t, []byte("File 1 content"), content1)
content2, err := os.ReadFile(destFile2)
require.NoError(t, err)
assert.Equal(t, []byte("File 2 content"), content2)
}
func TestMoveNonExistingFile(t *testing.T) {
t.Parallel()
// Setup temp root and user
root := t.TempDir()
u, err := user.Current()
require.NoError(t, err)
// Setup destination directory
destDir := filepath.Join(root, "destination")
require.NoError(t, os.MkdirAll(destDir, 0o755))
// Non-existing source file
sourceFile := filepath.Join(root, "non-existing-file.txt")
// Destination file path
destFile := filepath.Join(destDir, "moved-file.txt")
// Service instance
svc := mockService()
// Call the Move function
ctx := authn.SetInfo(t.Context(), u)
req := connect.NewRequest(&filesystem.MoveRequest{
Source: sourceFile,
Destination: destFile,
})
_, err = svc.Move(ctx, req)
// Verify the correct error is returned
require.Error(t, err)
var connectErr *connect.Error
ok := errors.As(err, &connectErr)
assert.True(t, ok, "expected error to be of type *connect.Error")
assert.Equal(t, connect.CodeNotFound, connectErr.Code())
assert.Contains(t, connectErr.Message(), "source file not found")
}
func TestMoveRelativePath(t *testing.T) {
t.Parallel()
// Setup user
u, err := user.Current()
require.NoError(t, err)
// Setup directory structure with unique name to avoid conflicts
testRelativePath := fmt.Sprintf("test-move-%s", uuid.New())
testFolderPath := filepath.Join(u.HomeDir, testRelativePath)
require.NoError(t, os.MkdirAll(testFolderPath, 0o755))
// Create a test file to move
sourceFile := filepath.Join(testFolderPath, "source-file.txt")
testContent := []byte("Hello from relative path!")
require.NoError(t, os.WriteFile(sourceFile, testContent, 0o644))
// Destination file path (also relative)
destRelativePath := fmt.Sprintf("test-move-dest-%s", uuid.New())
destFolderPath := filepath.Join(u.HomeDir, destRelativePath)
require.NoError(t, os.MkdirAll(destFolderPath, 0o755))
destFile := filepath.Join(destFolderPath, "moved-file.txt")
// Service instance
svc := mockService()
// Call the Move function with relative paths
ctx := authn.SetInfo(t.Context(), u)
req := connect.NewRequest(&filesystem.MoveRequest{
Source: filepath.Join(testRelativePath, "source-file.txt"), // Relative path
Destination: filepath.Join(destRelativePath, "moved-file.txt"), // Relative path
})
resp, err := svc.Move(ctx, req)
// Verify the move was successful
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, destFile, resp.Msg.GetEntry().GetPath())
// Verify the file exists at the destination
_, err = os.Stat(destFile)
require.NoError(t, err)
// Verify the file no longer exists at the source
_, err = os.Stat(sourceFile)
assert.True(t, os.IsNotExist(err))
// Verify the content of the moved file
content, err := os.ReadFile(destFile)
require.NoError(t, err)
assert.Equal(t, testContent, content)
// Clean up
os.RemoveAll(testFolderPath)
os.RemoveAll(destFolderPath)
}
func TestMove_Symlinks(t *testing.T) { //nolint:tparallel // this test cannot be executed in parallel
root := t.TempDir()
u, err := user.Current()
require.NoError(t, err)
ctx := authn.SetInfo(t.Context(), u)
// Setup source and destination directories
sourceRoot := filepath.Join(root, "source")
destRoot := filepath.Join(root, "destination")
require.NoError(t, os.MkdirAll(sourceRoot, 0o755))
require.NoError(t, os.MkdirAll(destRoot, 0o755))
// 1. Prepare a real directory + file that a symlink will point to
realDir := filepath.Join(sourceRoot, "real-dir")
require.NoError(t, os.MkdirAll(realDir, 0o755))
filePath := filepath.Join(realDir, "file.txt")
require.NoError(t, os.WriteFile(filePath, []byte("hello via symlink"), 0o644))
// 2. Prepare a standalone real file (points-to-file scenario)
realFile := filepath.Join(sourceRoot, "real-file.txt")
require.NoError(t, os.WriteFile(realFile, []byte("i am a plain file"), 0o644))
// 3. Create symlinks
linkToDir := filepath.Join(sourceRoot, "link-dir") // → directory
linkToFile := filepath.Join(sourceRoot, "link-file") // → file
require.NoError(t, os.Symlink(realDir, linkToDir))
require.NoError(t, os.Symlink(realFile, linkToFile))
svc := mockService()
t.Run("move symlink to directory", func(t *testing.T) {
t.Parallel()
destPath := filepath.Join(destRoot, "moved-link-dir")
req := connect.NewRequest(&filesystem.MoveRequest{
Source: linkToDir,
Destination: destPath,
})
resp, err := svc.Move(ctx, req)
require.NoError(t, err)
assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())
// Verify the symlink was moved
_, err = os.Stat(destPath)
require.NoError(t, err)
// Verify it's still a symlink
info, err := os.Lstat(destPath)
require.NoError(t, err)
assert.NotEqual(t, 0, info.Mode()&os.ModeSymlink, "expected a symlink")
// Verify the symlink target is still correct
target, err := os.Readlink(destPath)
require.NoError(t, err)
assert.Equal(t, realDir, target)
// Verify the original symlink is gone
_, err = os.Stat(linkToDir)
assert.True(t, os.IsNotExist(err))
// Verify the real directory still exists
_, err = os.Stat(realDir)
assert.NoError(t, err)
})
t.Run("move symlink to file", func(t *testing.T) { //nolint:paralleltest
destPath := filepath.Join(destRoot, "moved-link-file")
req := connect.NewRequest(&filesystem.MoveRequest{
Source: linkToFile,
Destination: destPath,
})
resp, err := svc.Move(ctx, req)
require.NoError(t, err)
assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())
// Verify the symlink was moved
_, err = os.Stat(destPath)
require.NoError(t, err)
// Verify it's still a symlink
info, err := os.Lstat(destPath)
require.NoError(t, err)
assert.NotEqual(t, 0, info.Mode()&os.ModeSymlink, "expected a symlink")
// Verify the symlink target is still correct
target, err := os.Readlink(destPath)
require.NoError(t, err)
assert.Equal(t, realFile, target)
// Verify the original symlink is gone
_, err = os.Stat(linkToFile)
assert.True(t, os.IsNotExist(err))
// Verify the real file still exists
_, err = os.Stat(realFile)
assert.NoError(t, err)
})
t.Run("move real file that is target of symlink", func(t *testing.T) {
t.Parallel()
// Create a new symlink to the real file
newLinkToFile := filepath.Join(sourceRoot, "new-link-file")
require.NoError(t, os.Symlink(realFile, newLinkToFile))
destPath := filepath.Join(destRoot, "moved-real-file.txt")
req := connect.NewRequest(&filesystem.MoveRequest{
Source: realFile,
Destination: destPath,
})
resp, err := svc.Move(ctx, req)
require.NoError(t, err)
assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())
// Verify the real file was moved
_, err = os.Stat(destPath)
require.NoError(t, err)
// Verify the original file is gone
_, err = os.Stat(realFile)
assert.True(t, os.IsNotExist(err))
// Verify the symlink still exists but now points to a non-existent file
_, err = os.Stat(newLinkToFile)
require.Error(t, err, "symlink should point to non-existent file")
})
}

View File

@ -0,0 +1,33 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"context"
"fmt"
"os"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
func (s Service) Remove(ctx context.Context, req *connect.Request[rpc.RemoveRequest]) (*connect.Response[rpc.RemoveResponse], error) {
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
if err != nil {
return nil, err
}
path, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
err = os.RemoveAll(path)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error removing file or directory: %w", err))
}
return connect.NewResponse(&rpc.RemoveResponse{}), nil
}

View File

@ -0,0 +1,37 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package filesystem
import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem/filesystemconnect"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
type Service struct {
logger *zerolog.Logger
watchers *utils.Map[string, *FileWatcher]
defaults *execcontext.Defaults
}
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults) {
service := Service{
logger: l,
watchers: utils.NewMap[string, *FileWatcher](),
defaults: defaults,
}
interceptors := connect.WithInterceptors(
logs.NewUnaryLogInterceptor(l),
)
path, handler := spec.NewFilesystemHandler(service, interceptors)
server.Mount(path, handler)
}

View File

@ -0,0 +1,16 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
func mockService() Service {
return Service{
defaults: &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
},
}
}

View File

@ -0,0 +1,31 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"context"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
func (s Service) Stat(ctx context.Context, req *connect.Request[rpc.StatRequest]) (*connect.Response[rpc.StatResponse], error) {
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
if err != nil {
return nil, err
}
path, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
entry, err := entryInfo(path)
if err != nil {
return nil, err
}
return connect.NewResponse(&rpc.StatResponse{Entry: entry}), nil
}

View File

@ -0,0 +1,116 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"context"
"os"
"os/user"
"path/filepath"
"testing"
"connectrpc.com/authn"
"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
func TestStat(t *testing.T) {
t.Parallel()
// Setup temp root and user
root := t.TempDir()
// Get the actual path to the temp directory (symlinks can cause issues)
root, err := filepath.EvalSymlinks(root)
require.NoError(t, err)
u, err := user.Current()
require.NoError(t, err)
group, err := user.LookupGroupId(u.Gid)
require.NoError(t, err)
// Setup directory structure
testFolder := filepath.Join(root, "test")
err = os.MkdirAll(testFolder, 0o755)
require.NoError(t, err)
testFile := filepath.Join(testFolder, "file.txt")
err = os.WriteFile(testFile, []byte("Hello, World!"), 0o644)
require.NoError(t, err)
linkedFile := filepath.Join(testFolder, "linked-file.txt")
err = os.Symlink(testFile, linkedFile)
require.NoError(t, err)
// Service instance
svc := mockService()
// Helper to inject user into context
injectUser := func(ctx context.Context, u *user.User) context.Context {
return authn.SetInfo(ctx, u)
}
tests := []struct {
name string
path string
}{
{
name: "Stat file directory",
path: testFile,
},
{
name: "Stat symlink to file",
path: linkedFile,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := injectUser(t.Context(), u)
req := connect.NewRequest(&filesystem.StatRequest{
Path: tt.path,
})
resp, err := svc.Stat(ctx, req)
require.NoError(t, err)
require.NotEmpty(t, resp.Msg)
require.NotNil(t, resp.Msg.GetEntry())
assert.Equal(t, tt.path, resp.Msg.GetEntry().GetPath())
assert.Equal(t, filesystem.FileType_FILE_TYPE_FILE, resp.Msg.GetEntry().GetType())
assert.Equal(t, u.Username, resp.Msg.GetEntry().GetOwner())
assert.Equal(t, group.Name, resp.Msg.GetEntry().GetGroup())
assert.Equal(t, uint32(0o644), resp.Msg.GetEntry().GetMode())
if tt.path == linkedFile {
require.NotNil(t, resp.Msg.GetEntry().GetSymlinkTarget())
assert.Equal(t, testFile, resp.Msg.GetEntry().GetSymlinkTarget())
} else {
assert.Empty(t, resp.Msg.GetEntry().GetSymlinkTarget())
}
})
}
}
func TestStatMissingPathReturnsNotFound(t *testing.T) {
t.Parallel()
u, err := user.Current()
require.NoError(t, err)
svc := mockService()
ctx := authn.SetInfo(t.Context(), u)
req := connect.NewRequest(&filesystem.StatRequest{
Path: filepath.Join(t.TempDir(), "missing.txt"),
})
_, err = svc.Stat(ctx, req)
require.Error(t, err)
var connectErr *connect.Error
require.ErrorAs(t, err, &connectErr)
assert.Equal(t, connect.CodeNotFound, connectErr.Code())
}

View File

@ -0,0 +1,109 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"fmt"
"os"
"os/user"
"syscall"
"time"
"connectrpc.com/connect"
"google.golang.org/protobuf/types/known/timestamppb"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/filesystem"
)
// Filesystem magic numbers from Linux kernel (include/uapi/linux/magic.h)
const (
nfsSuperMagic = 0x6969
cifsMagic = 0xFF534D42
smbSuperMagic = 0x517B
smb2MagicNumber = 0xFE534D42
fuseSuperMagic = 0x65735546
)
// IsPathOnNetworkMount checks if the given path is on a network filesystem mount.
// Returns true if the path is on NFS, CIFS, SMB, or FUSE filesystem.
func IsPathOnNetworkMount(path string) (bool, error) {
var statfs syscall.Statfs_t
if err := syscall.Statfs(path, &statfs); err != nil {
return false, fmt.Errorf("failed to statfs %s: %w", path, err)
}
switch statfs.Type {
case nfsSuperMagic, cifsMagic, smbSuperMagic, smb2MagicNumber, fuseSuperMagic:
return true, nil
default:
return false, nil
}
}
func entryInfo(path string) (*rpc.EntryInfo, error) {
info, err := filesystem.GetEntryFromPath(path)
if err != nil {
if os.IsNotExist(err) {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("file not found: %w", err))
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
}
owner, group := getFileOwnership(info)
return &rpc.EntryInfo{
Name: info.Name,
Type: getEntryType(info.Type),
Path: info.Path,
Size: info.Size,
Mode: uint32(info.Mode),
Permissions: info.Permissions,
Owner: owner,
Group: group,
ModifiedTime: toTimestamp(info.ModifiedTime),
SymlinkTarget: info.SymlinkTarget,
}, nil
}
func toTimestamp(time time.Time) *timestamppb.Timestamp {
if time.IsZero() {
return nil
}
return timestamppb.New(time)
}
// getFileOwnership returns the owner and group names for a file.
// If the lookup fails, it returns the numeric UID and GID as strings.
func getFileOwnership(fileInfo filesystem.EntryInfo) (owner, group string) {
// Look up username
owner = fmt.Sprintf("%d", fileInfo.UID)
if u, err := user.LookupId(owner); err == nil {
owner = u.Username
}
// Look up group name
group = fmt.Sprintf("%d", fileInfo.GID)
if g, err := user.LookupGroupId(group); err == nil {
group = g.Name
}
return owner, group
}
// getEntryType determines the type of file entry based on its mode and path.
// If the file is a symlink, it follows the symlink to determine the actual type.
func getEntryType(fileType filesystem.FileType) rpc.FileType {
switch fileType {
case filesystem.FileFileType:
return rpc.FileType_FILE_TYPE_FILE
case filesystem.DirectoryFileType:
return rpc.FileType_FILE_TYPE_DIRECTORY
case filesystem.SymlinkFileType:
return rpc.FileType_FILE_TYPE_SYMLINK
default:
return rpc.FileType_FILE_TYPE_UNSPECIFIED
}
}

View File

@ -0,0 +1,151 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"context"
"os/exec"
osuser "os/user"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fsmodel "git.omukk.dev/wrenn/sandbox/envd/internal/shared/filesystem"
)
func TestIsPathOnNetworkMount(t *testing.T) {
t.Parallel()
// Test with a regular directory (should not be on network mount)
tempDir := t.TempDir()
isNetwork, err := IsPathOnNetworkMount(tempDir)
require.NoError(t, err)
assert.False(t, isNetwork, "temp directory should not be on a network mount")
}
func TestIsPathOnNetworkMount_FuseMount(t *testing.T) {
t.Parallel()
// Require bindfs to be available
_, err := exec.LookPath("bindfs")
require.NoError(t, err, "bindfs must be installed for this test")
// Require fusermount to be available (needed for unmounting)
_, err = exec.LookPath("fusermount")
require.NoError(t, err, "fusermount must be installed for this test")
// Create source and mount directories
sourceDir := t.TempDir()
mountDir := t.TempDir()
// Mount sourceDir onto mountDir using bindfs (FUSE)
ctx := context.Background()
cmd := exec.CommandContext(ctx, "bindfs", sourceDir, mountDir)
require.NoError(t, cmd.Run(), "failed to mount bindfs")
// Ensure we unmount on cleanup
t.Cleanup(func() {
_ = exec.CommandContext(context.Background(), "fusermount", "-u", mountDir).Run()
})
// Test that the FUSE mount is detected
isNetwork, err := IsPathOnNetworkMount(mountDir)
require.NoError(t, err)
assert.True(t, isNetwork, "FUSE mount should be detected as network filesystem")
// Test that the source directory is NOT detected as network mount
isNetworkSource, err := IsPathOnNetworkMount(sourceDir)
require.NoError(t, err)
assert.False(t, isNetworkSource, "source directory should not be detected as network filesystem")
}
func TestGetFileOwnership_CurrentUser(t *testing.T) {
t.Parallel()
t.Run("current user", func(t *testing.T) {
t.Parallel()
// Get current user running the tests
cur, err := osuser.Current()
if err != nil {
t.Skipf("unable to determine current user: %v", err)
}
// Determine expected owner/group using the same lookup logic
expectedOwner := cur.Uid
if u, err := osuser.LookupId(cur.Uid); err == nil {
expectedOwner = u.Username
}
expectedGroup := cur.Gid
if g, err := osuser.LookupGroupId(cur.Gid); err == nil {
expectedGroup = g.Name
}
// Parse UID/GID strings to uint32 for EntryInfo
uid64, err := strconv.ParseUint(cur.Uid, 10, 32)
require.NoError(t, err)
gid64, err := strconv.ParseUint(cur.Gid, 10, 32)
require.NoError(t, err)
// Build a minimal EntryInfo with current UID/GID
info := fsmodel.EntryInfo{ // from shared pkg
UID: uint32(uid64),
GID: uint32(gid64),
}
owner, group := getFileOwnership(info)
assert.Equal(t, expectedOwner, owner)
assert.Equal(t, expectedGroup, group)
})
t.Run("no user", func(t *testing.T) {
t.Parallel()
// Find a UID that does not exist on this system
var unknownUIDStr string
for i := 60001; i < 70000; i++ { // search a high range typically unused
idStr := strconv.Itoa(i)
if _, err := osuser.LookupId(idStr); err != nil {
unknownUIDStr = idStr
break
}
}
if unknownUIDStr == "" {
t.Skip("could not find a non-existent UID in the probed range")
}
// Find a GID that does not exist on this system
var unknownGIDStr string
for i := 60001; i < 70000; i++ { // search a high range typically unused
idStr := strconv.Itoa(i)
if _, err := osuser.LookupGroupId(idStr); err != nil {
unknownGIDStr = idStr
break
}
}
if unknownGIDStr == "" {
t.Skip("could not find a non-existent GID in the probed range")
}
// Parse to uint32 for EntryInfo construction
uid64, err := strconv.ParseUint(unknownUIDStr, 10, 32)
require.NoError(t, err)
gid64, err := strconv.ParseUint(unknownGIDStr, 10, 32)
require.NoError(t, err)
info := fsmodel.EntryInfo{
UID: uint32(uid64),
GID: uint32(gid64),
}
owner, group := getFileOwnership(info)
// Expect numeric fallbacks because lookups should fail for unknown IDs
assert.Equal(t, unknownUIDStr, owner)
assert.Equal(t, unknownGIDStr, group)
})
}

View File

@ -0,0 +1,161 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"context"
"fmt"
"os"
"path/filepath"
"connectrpc.com/connect"
"github.com/e2b-dev/fsnotify"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
func (s Service) WatchDir(ctx context.Context, req *connect.Request[rpc.WatchDirRequest], stream *connect.ServerStream[rpc.WatchDirResponse]) error {
return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.watchHandler)
}
func (s Service) watchHandler(ctx context.Context, req *connect.Request[rpc.WatchDirRequest], stream *connect.ServerStream[rpc.WatchDirResponse]) error {
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
if err != nil {
return err
}
watchPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
if err != nil {
return connect.NewError(connect.CodeInvalidArgument, err)
}
info, err := os.Stat(watchPath)
if err != nil {
if os.IsNotExist(err) {
return connect.NewError(connect.CodeNotFound, fmt.Errorf("path %s not found: %w", watchPath, err))
}
return connect.NewError(connect.CodeInternal, fmt.Errorf("error statting path %s: %w", watchPath, err))
}
if !info.IsDir() {
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path %s not a directory: %w", watchPath, err))
}
// Check if path is on a network filesystem mount
isNetworkMount, err := IsPathOnNetworkMount(watchPath)
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("error checking mount status: %w", err))
}
if isNetworkMount {
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cannot watch path on network filesystem: %s", watchPath))
}
w, err := fsnotify.NewWatcher()
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("error creating watcher: %w", err))
}
defer w.Close()
err = w.Add(utils.FsnotifyPath(watchPath, req.Msg.GetRecursive()))
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("error adding path %s to watcher: %w", watchPath, err))
}
err = stream.Send(&rpc.WatchDirResponse{
Event: &rpc.WatchDirResponse_Start{
Start: &rpc.WatchDirResponse_StartEvent{},
},
})
if err != nil {
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", err))
}
keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
defer keepaliveTicker.Stop()
for {
select {
case <-keepaliveTicker.C:
streamErr := stream.Send(&rpc.WatchDirResponse{
Event: &rpc.WatchDirResponse_Keepalive{
Keepalive: &rpc.WatchDirResponse_KeepAlive{},
},
})
if streamErr != nil {
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr))
}
case <-ctx.Done():
return ctx.Err()
case chErr, ok := <-w.Errors:
if !ok {
return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error channel closed"))
}
return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error: %w", chErr))
case e, ok := <-w.Events:
if !ok {
return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher event channel closed"))
}
// One event can have multiple operations.
ops := []rpc.EventType{}
if fsnotify.Create.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_CREATE)
}
if fsnotify.Rename.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_RENAME)
}
if fsnotify.Chmod.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_CHMOD)
}
if fsnotify.Write.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_WRITE)
}
if fsnotify.Remove.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_REMOVE)
}
for _, op := range ops {
name, nameErr := filepath.Rel(watchPath, e.Name)
if nameErr != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("error getting relative path: %w", nameErr))
}
filesystemEvent := &rpc.WatchDirResponse_Filesystem{
Filesystem: &rpc.FilesystemEvent{
Name: name,
Type: op,
},
}
event := &rpc.WatchDirResponse{
Event: filesystemEvent,
}
streamErr := stream.Send(event)
s.logger.
Debug().
Str("event_type", "filesystem_event").
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
Interface("filesystem_event", event).
Msg("Streaming filesystem event")
if streamErr != nil {
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending filesystem event: %w", streamErr))
}
resetKeepalive()
}
}
}
}

View File

@ -0,0 +1,226 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"connectrpc.com/connect"
"github.com/e2b-dev/fsnotify"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/id"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
type FileWatcher struct {
watcher *fsnotify.Watcher
Events []*rpc.FilesystemEvent
cancel func()
Error error
Lock sync.Mutex
}
func CreateFileWatcher(ctx context.Context, watchPath string, recursive bool, operationID string, logger *zerolog.Logger) (*FileWatcher, error) {
w, err := fsnotify.NewWatcher()
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating watcher: %w", err))
}
// We don't want to cancel the context when the request is finished
ctx, cancel := context.WithCancel(context.WithoutCancel(ctx))
err = w.Add(utils.FsnotifyPath(watchPath, recursive))
if err != nil {
_ = w.Close()
cancel()
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error adding path %s to watcher: %w", watchPath, err))
}
fw := &FileWatcher{
watcher: w,
cancel: cancel,
Events: []*rpc.FilesystemEvent{},
Error: nil,
}
go func() {
for {
select {
case <-ctx.Done():
return
case chErr, ok := <-w.Errors:
if !ok {
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error channel closed"))
return
}
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error: %w", chErr))
return
case e, ok := <-w.Events:
if !ok {
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher event channel closed"))
return
}
// One event can have multiple operations.
ops := []rpc.EventType{}
if fsnotify.Create.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_CREATE)
}
if fsnotify.Rename.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_RENAME)
}
if fsnotify.Chmod.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_CHMOD)
}
if fsnotify.Write.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_WRITE)
}
if fsnotify.Remove.Has(e.Op) {
ops = append(ops, rpc.EventType_EVENT_TYPE_REMOVE)
}
for _, op := range ops {
name, nameErr := filepath.Rel(watchPath, e.Name)
if nameErr != nil {
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("error getting relative path: %w", nameErr))
return
}
fw.Lock.Lock()
fw.Events = append(fw.Events, &rpc.FilesystemEvent{
Name: name,
Type: op,
})
fw.Lock.Unlock()
// these are only used for logging
filesystemEvent := &rpc.WatchDirResponse_Filesystem{
Filesystem: &rpc.FilesystemEvent{
Name: name,
Type: op,
},
}
event := &rpc.WatchDirResponse{
Event: filesystemEvent,
}
logger.
Debug().
Str("event_type", "filesystem_event").
Str(string(logs.OperationIDKey), operationID).
Interface("filesystem_event", event).
Msg("Streaming filesystem event")
}
}
}
}()
return fw, nil
}
func (fw *FileWatcher) Close() {
_ = fw.watcher.Close()
fw.cancel()
}
func (s Service) CreateWatcher(ctx context.Context, req *connect.Request[rpc.CreateWatcherRequest]) (*connect.Response[rpc.CreateWatcherResponse], error) {
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
if err != nil {
return nil, err
}
watchPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
info, err := os.Stat(watchPath)
if err != nil {
if os.IsNotExist(err) {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("path %s not found: %w", watchPath, err))
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error statting path %s: %w", watchPath, err))
}
if !info.IsDir() {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path %s not a directory: %w", watchPath, err))
}
// Check if path is on a network filesystem mount
isNetworkMount, err := IsPathOnNetworkMount(watchPath)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error checking mount status: %w", err))
}
if isNetworkMount {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cannot watch path on network filesystem: %s", watchPath))
}
watcherId := "w" + id.Generate()
w, err := CreateFileWatcher(ctx, watchPath, req.Msg.GetRecursive(), watcherId, s.logger)
if err != nil {
return nil, err
}
s.watchers.Store(watcherId, w)
return connect.NewResponse(&rpc.CreateWatcherResponse{
WatcherId: watcherId,
}), nil
}
func (s Service) GetWatcherEvents(_ context.Context, req *connect.Request[rpc.GetWatcherEventsRequest]) (*connect.Response[rpc.GetWatcherEventsResponse], error) {
watcherId := req.Msg.GetWatcherId()
w, ok := s.watchers.Load(watcherId)
if !ok {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("watcher with id %s not found", watcherId))
}
if w.Error != nil {
return nil, w.Error
}
w.Lock.Lock()
defer w.Lock.Unlock()
events := w.Events
w.Events = []*rpc.FilesystemEvent{}
return connect.NewResponse(&rpc.GetWatcherEventsResponse{
Events: events,
}), nil
}
func (s Service) RemoveWatcher(_ context.Context, req *connect.Request[rpc.RemoveWatcherRequest]) (*connect.Response[rpc.RemoveWatcherResponse], error) {
watcherId := req.Msg.GetWatcherId()
w, ok := s.watchers.Load(watcherId)
if !ok {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("watcher with id %s not found", watcherId))
}
w.Close()
s.watchers.Delete(watcherId)
return connect.NewResponse(&rpc.RemoveWatcherResponse{}), nil
}

View File

@ -0,0 +1,128 @@
// SPDX-License-Identifier: Apache-2.0
package process
import (
"context"
"errors"
"fmt"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) Connect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleConnect)
}
func (s *Service) handleConnect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
proc, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return err
}
exitChan := make(chan struct{})
data, dataCancel := proc.DataEvent.Fork()
defer dataCancel()
end, endCancel := proc.EndEvent.Fork()
defer endCancel()
streamErr := stream.Send(&rpc.ConnectResponse{
Event: &rpc.ProcessEvent{
Event: &rpc.ProcessEvent_Start{
Start: &rpc.ProcessEvent_StartEvent{
Pid: proc.Pid(),
},
},
},
})
if streamErr != nil {
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr))
}
go func() {
defer close(exitChan)
keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
defer keepaliveTicker.Stop()
dataLoop:
for {
select {
case <-keepaliveTicker.C:
streamErr := stream.Send(&rpc.ConnectResponse{
Event: &rpc.ProcessEvent{
Event: &rpc.ProcessEvent_Keepalive{
Keepalive: &rpc.ProcessEvent_KeepAlive{},
},
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))
return
}
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-data:
if !ok {
break dataLoop
}
streamErr := stream.Send(&rpc.ConnectResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))
return
}
resetKeepalive()
}
}
select {
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-end:
if !ok {
cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))
return
}
streamErr := stream.Send(&rpc.ConnectResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))
return
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-exitChan:
return nil
}
}

View File

@ -0,0 +1,480 @@
// SPDX-License-Identifier: Apache-2.0
package handler
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"strconv"
"strings"
"sync"
"syscall"
"connectrpc.com/connect"
"github.com/creack/pty"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
const (
defaultNice = 0
defaultOomScore = 100
outputBufferSize = 64
stdChunkSize = 2 << 14
ptyChunkSize = 2 << 13
)
type ProcessExit struct {
Error *string
Status string
Exited bool
Code int32
}
type Handler struct {
Config *rpc.ProcessConfig
logger *zerolog.Logger
Tag *string
cmd *exec.Cmd
tty *os.File
cancel context.CancelFunc
outCtx context.Context //nolint:containedctx // todo: refactor so this can be removed
outCancel context.CancelFunc
stdinMu sync.Mutex
stdin io.WriteCloser
DataEvent *MultiplexedChannel[rpc.ProcessEvent_Data]
EndEvent *MultiplexedChannel[rpc.ProcessEvent_End]
}
// This method must be called only after the process has been started
func (p *Handler) Pid() uint32 {
return uint32(p.cmd.Process.Pid)
}
// userCommand returns a human-readable representation of the user's original command,
// without the internal OOM/nice wrapper that is prepended to the actual exec.
func (p *Handler) userCommand() string {
return strings.Join(append([]string{p.Config.GetCmd()}, p.Config.GetArgs()...), " ")
}
// currentNice returns the nice value of the current process.
func currentNice() int {
prio, err := syscall.Getpriority(syscall.PRIO_PROCESS, 0)
if err != nil {
return 0
}
// Getpriority returns 20 - nice on Linux.
return 20 - prio
}
func New(
ctx context.Context,
user *user.User,
req *rpc.StartRequest,
logger *zerolog.Logger,
defaults *execcontext.Defaults,
cgroupManager cgroups.Manager,
cancel context.CancelFunc,
) (*Handler, error) {
// User command string for logging (without the internal wrapper details).
userCmd := strings.Join(append([]string{req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...), " ")
// Wrap the command in a shell that sets the OOM score and nice value before exec-ing the actual command.
// This eliminates the race window where grandchildren could inherit the parent's protected OOM score (-1000)
// or high CPU priority (nice -20) before the post-start calls had a chance to correct them.
// nice(1) applies a relative adjustment, so we compute the delta from the current (inherited) nice to the target.
niceDelta := defaultNice - currentNice()
oomWrapperScript := fmt.Sprintf(`echo %d > /proc/$$/oom_score_adj && exec /usr/bin/nice -n %d "${@}"`, defaultOomScore, niceDelta)
wrapperArgs := append([]string{"-c", oomWrapperScript, "--", req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...)
cmd := exec.CommandContext(ctx, "/bin/sh", wrapperArgs...)
uid, gid, err := permissions.GetUserIdUints(user)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}
groups := []uint32{gid}
if gids, err := user.GroupIds(); err != nil {
logger.Warn().Err(err).Str("user", user.Username).Msg("failed to get supplementary groups")
} else {
for _, g := range gids {
if parsed, err := strconv.ParseUint(g, 10, 32); err == nil {
groups = append(groups, uint32(parsed))
}
}
}
cgroupFD, ok := cgroupManager.GetFileDescriptor(getProcType(req))
cmd.SysProcAttr = &syscall.SysProcAttr{
UseCgroupFD: ok,
CgroupFD: cgroupFD,
Credential: &syscall.Credential{
Uid: uid,
Gid: gid,
Groups: groups,
},
}
resolvedPath, err := permissions.ExpandAndResolve(req.GetProcess().GetCwd(), user, defaults.Workdir)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
// Check if the cwd resolved path exists
if _, err := os.Stat(resolvedPath); errors.Is(err, os.ErrNotExist) {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cwd '%s' does not exist", resolvedPath))
}
cmd.Dir = resolvedPath
var formattedVars []string
// Take only 'PATH' variable from the current environment
// The 'PATH' should ideally be set in the environment
formattedVars = append(formattedVars, "PATH="+os.Getenv("PATH"))
formattedVars = append(formattedVars, "HOME="+user.HomeDir)
formattedVars = append(formattedVars, "USER="+user.Username)
formattedVars = append(formattedVars, "LOGNAME="+user.Username)
// Add the environment variables from the global environment
if defaults.EnvVars != nil {
defaults.EnvVars.Range(func(key string, value string) bool {
formattedVars = append(formattedVars, key+"="+value)
return true
})
}
// Only the last values of the env vars are used - this allows for overwriting defaults
for key, value := range req.GetProcess().GetEnvs() {
formattedVars = append(formattedVars, key+"="+value)
}
cmd.Env = formattedVars
outMultiplex := NewMultiplexedChannel[rpc.ProcessEvent_Data](outputBufferSize)
var outWg sync.WaitGroup
// Create a context for waiting for and cancelling output pipes.
// Cancellation of the process via timeout will propagate and cancel this context too.
outCtx, outCancel := context.WithCancel(ctx)
h := &Handler{
Config: req.GetProcess(),
cmd: cmd,
Tag: req.Tag,
DataEvent: outMultiplex,
cancel: cancel,
outCtx: outCtx,
outCancel: outCancel,
EndEvent: NewMultiplexedChannel[rpc.ProcessEvent_End](0),
logger: logger,
}
if req.GetPty() != nil {
// The pty should ideally start only in the Start method, but the package does not support that and we would have to code it manually.
// The output of the pty should correctly be passed though.
tty, err := pty.StartWithSize(cmd, &pty.Winsize{
Cols: uint16(req.GetPty().GetSize().GetCols()),
Rows: uint16(req.GetPty().GetSize().GetRows()),
})
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error starting pty with command '%s' in dir '%s' with '%d' cols and '%d' rows: %w", userCmd, cmd.Dir, req.GetPty().GetSize().GetCols(), req.GetPty().GetSize().GetRows(), err))
}
outWg.Go(func() {
for {
buf := make([]byte, ptyChunkSize)
n, readErr := tty.Read(buf)
if n > 0 {
outMultiplex.Source <- rpc.ProcessEvent_Data{
Data: &rpc.ProcessEvent_DataEvent{
Output: &rpc.ProcessEvent_DataEvent_Pty{
Pty: buf[:n],
},
},
}
}
if errors.Is(readErr, io.EOF) {
break
}
if readErr != nil {
fmt.Fprintf(os.Stderr, "error reading from pty: %s\n", readErr)
break
}
}
})
h.tty = tty
} else {
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdout pipe for command '%s': %w", userCmd, err))
}
outWg.Go(func() {
stdoutLogs := make(chan []byte, outputBufferSize)
defer close(stdoutLogs)
stdoutLogger := logger.With().Str("event_type", "stdout").Logger()
go logs.LogBufferedDataEvents(stdoutLogs, &stdoutLogger, "data")
for {
buf := make([]byte, stdChunkSize)
n, readErr := stdout.Read(buf)
if n > 0 {
outMultiplex.Source <- rpc.ProcessEvent_Data{
Data: &rpc.ProcessEvent_DataEvent{
Output: &rpc.ProcessEvent_DataEvent_Stdout{
Stdout: buf[:n],
},
},
}
stdoutLogs <- buf[:n]
}
if errors.Is(readErr, io.EOF) {
break
}
if readErr != nil {
fmt.Fprintf(os.Stderr, "error reading from stdout: %s\n", readErr)
break
}
}
})
stderr, err := cmd.StderrPipe()
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stderr pipe for command '%s': %w", userCmd, err))
}
outWg.Go(func() {
stderrLogs := make(chan []byte, outputBufferSize)
defer close(stderrLogs)
stderrLogger := logger.With().Str("event_type", "stderr").Logger()
go logs.LogBufferedDataEvents(stderrLogs, &stderrLogger, "data")
for {
buf := make([]byte, stdChunkSize)
n, readErr := stderr.Read(buf)
if n > 0 {
outMultiplex.Source <- rpc.ProcessEvent_Data{
Data: &rpc.ProcessEvent_DataEvent{
Output: &rpc.ProcessEvent_DataEvent_Stderr{
Stderr: buf[:n],
},
},
}
stderrLogs <- buf[:n]
}
if errors.Is(readErr, io.EOF) {
break
}
if readErr != nil {
fmt.Fprintf(os.Stderr, "error reading from stderr: %s\n", readErr)
break
}
}
})
// For backwards compatibility we still set the stdin if not explicitly disabled
// If stdin is disabled, the process will use /dev/null as stdin
if req.Stdin == nil || req.GetStdin() == true {
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdin pipe for command '%s': %w", userCmd, err))
}
h.stdin = stdin
}
}
go func() {
outWg.Wait()
close(outMultiplex.Source)
outCancel()
}()
return h, nil
}
func getProcType(req *rpc.StartRequest) cgroups.ProcessType {
if req != nil && req.GetPty() != nil {
return cgroups.ProcessTypePTY
}
return cgroups.ProcessTypeUser
}
func (p *Handler) SendSignal(signal syscall.Signal) error {
if p.cmd.Process == nil {
return fmt.Errorf("process not started")
}
if signal == syscall.SIGKILL || signal == syscall.SIGTERM {
p.outCancel()
}
return p.cmd.Process.Signal(signal)
}
func (p *Handler) ResizeTty(size *pty.Winsize) error {
if p.tty == nil {
return fmt.Errorf("tty not assigned to process")
}
return pty.Setsize(p.tty, size)
}
func (p *Handler) WriteStdin(data []byte) error {
if p.tty != nil {
return fmt.Errorf("tty assigned to process — input should be written to the pty, not the stdin")
}
p.stdinMu.Lock()
defer p.stdinMu.Unlock()
if p.stdin == nil {
return fmt.Errorf("stdin not enabled or closed")
}
_, err := p.stdin.Write(data)
if err != nil {
return fmt.Errorf("error writing to stdin of process '%d': %w", p.cmd.Process.Pid, err)
}
return nil
}
// CloseStdin closes the stdin pipe to signal EOF to the process.
// Only works for non-PTY processes.
func (p *Handler) CloseStdin() error {
if p.tty != nil {
return fmt.Errorf("cannot close stdin for PTY process — send Ctrl+D (0x04) instead")
}
p.stdinMu.Lock()
defer p.stdinMu.Unlock()
if p.stdin == nil {
return nil
}
err := p.stdin.Close()
// We still set the stdin to nil even on error as there are no errors,
// for which it is really safe to retry close across all distributions.
p.stdin = nil
return err
}
func (p *Handler) WriteTty(data []byte) error {
if p.tty == nil {
return fmt.Errorf("tty not assigned to process — input should be written to the stdin, not the tty")
}
_, err := p.tty.Write(data)
if err != nil {
return fmt.Errorf("error writing to tty of process '%d': %w", p.cmd.Process.Pid, err)
}
return nil
}
func (p *Handler) Start() (uint32, error) {
// Pty is already started in the New method
if p.tty == nil {
err := p.cmd.Start()
if err != nil {
return 0, fmt.Errorf("error starting process '%s': %w", p.userCommand(), err)
}
}
p.logger.
Info().
Str("event_type", "process_start").
Int("pid", p.cmd.Process.Pid).
Str("command", p.userCommand()).
Msg(fmt.Sprintf("Process with pid %d started", p.cmd.Process.Pid))
return uint32(p.cmd.Process.Pid), nil
}
func (p *Handler) Wait() {
// Wait for the output pipes to be closed or cancelled.
<-p.outCtx.Done()
err := p.cmd.Wait()
p.tty.Close()
var errMsg *string
if err != nil {
msg := err.Error()
errMsg = &msg
}
endEvent := &rpc.ProcessEvent_EndEvent{
Error: errMsg,
ExitCode: int32(p.cmd.ProcessState.ExitCode()),
Exited: p.cmd.ProcessState.Exited(),
Status: p.cmd.ProcessState.String(),
}
event := rpc.ProcessEvent_End{
End: endEvent,
}
p.EndEvent.Source <- event
p.logger.
Info().
Str("event_type", "process_end").
Interface("process_result", endEvent).
Msg(fmt.Sprintf("Process with pid %d ended", p.cmd.Process.Pid))
// Ensure the process cancel is called to cleanup resources.
// As it is called after end event and Wait, it should not affect command execution or returned events.
p.cancel()
}

View File

@ -0,0 +1,75 @@
// SPDX-License-Identifier: Apache-2.0
package handler
import (
"sync"
"sync/atomic"
)
type MultiplexedChannel[T any] struct {
Source chan T
channels []chan T
mu sync.RWMutex
exited atomic.Bool
}
func NewMultiplexedChannel[T any](buffer int) *MultiplexedChannel[T] {
c := &MultiplexedChannel[T]{
channels: nil,
Source: make(chan T, buffer),
}
go func() {
for v := range c.Source {
c.mu.RLock()
for _, cons := range c.channels {
cons <- v
}
c.mu.RUnlock()
}
c.exited.Store(true)
for _, cons := range c.channels {
close(cons)
}
}()
return c
}
func (m *MultiplexedChannel[T]) Fork() (chan T, func()) {
if m.exited.Load() {
ch := make(chan T)
close(ch)
return ch, func() {}
}
m.mu.Lock()
defer m.mu.Unlock()
consumer := make(chan T)
m.channels = append(m.channels, consumer)
return consumer, func() {
m.remove(consumer)
}
}
func (m *MultiplexedChannel[T]) remove(consumer chan T) {
m.mu.Lock()
defer m.mu.Unlock()
for i, ch := range m.channels {
if ch == consumer {
m.channels = append(m.channels[:i], m.channels[i+1:]...)
return
}
}
}

View File

@ -0,0 +1,109 @@
// SPDX-License-Identifier: Apache-2.0
package process
import (
"context"
"fmt"
"connectrpc.com/connect"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func handleInput(ctx context.Context, process *handler.Handler, in *rpc.ProcessInput, logger *zerolog.Logger) error {
switch in.GetInput().(type) {
case *rpc.ProcessInput_Pty:
err := process.WriteTty(in.GetPty())
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to tty: %w", err))
}
case *rpc.ProcessInput_Stdin:
err := process.WriteStdin(in.GetStdin())
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to stdin: %w", err))
}
logger.Debug().
Str("event_type", "stdin").
Interface("stdin", in.GetStdin()).
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
Msg("Streaming input to process")
default:
return connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", in.GetInput()))
}
return nil
}
func (s *Service) SendInput(ctx context.Context, req *connect.Request[rpc.SendInputRequest]) (*connect.Response[rpc.SendInputResponse], error) {
proc, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return nil, err
}
err = handleInput(ctx, proc, req.Msg.GetInput(), s.logger)
if err != nil {
return nil, err
}
return connect.NewResponse(&rpc.SendInputResponse{}), nil
}
func (s *Service) StreamInput(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
return logs.LogClientStreamWithoutEvents(ctx, s.logger, stream, s.streamInputHandler)
}
func (s *Service) streamInputHandler(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
var proc *handler.Handler
for stream.Receive() {
req := stream.Msg()
switch req.GetEvent().(type) {
case *rpc.StreamInputRequest_Start:
p, err := s.getProcess(req.GetStart().GetProcess())
if err != nil {
return nil, err
}
proc = p
case *rpc.StreamInputRequest_Data:
err := handleInput(ctx, proc, req.GetData().GetInput(), s.logger)
if err != nil {
return nil, err
}
case *rpc.StreamInputRequest_Keepalive:
default:
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid event type %T", req.GetEvent()))
}
}
err := stream.Err()
if err != nil {
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error streaming input: %w", err))
}
return connect.NewResponse(&rpc.StreamInputResponse{}), nil
}
func (s *Service) CloseStdin(
_ context.Context,
req *connect.Request[rpc.CloseStdinRequest],
) (*connect.Response[rpc.CloseStdinResponse], error) {
handler, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return nil, err
}
if err := handler.CloseStdin(); err != nil {
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error closing stdin: %w", err))
}
return connect.NewResponse(&rpc.CloseStdinResponse{}), nil
}

View File

@ -0,0 +1,30 @@
// SPDX-License-Identifier: Apache-2.0
package process
import (
"context"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) List(context.Context, *connect.Request[rpc.ListRequest]) (*connect.Response[rpc.ListResponse], error) {
processes := make([]*rpc.ProcessInfo, 0)
s.processes.Range(func(pid uint32, value *handler.Handler) bool {
processes = append(processes, &rpc.ProcessInfo{
Pid: pid,
Tag: value.Tag,
Config: value.Config,
})
return true
})
return connect.NewResponse(&rpc.ListResponse{
Processes: processes,
}), nil
}

View File

@ -0,0 +1,86 @@
// SPDX-License-Identifier: Apache-2.0
package process
import (
"fmt"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process/processconnect"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
type Service struct {
processes *utils.Map[uint32, *handler.Handler]
logger *zerolog.Logger
defaults *execcontext.Defaults
cgroupManager cgroups.Manager
}
func newService(l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
return &Service{
logger: l,
processes: utils.NewMap[uint32, *handler.Handler](),
defaults: defaults,
cgroupManager: cgroupManager,
}
}
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
service := newService(l, defaults, cgroupManager)
interceptors := connect.WithInterceptors(logs.NewUnaryLogInterceptor(l))
path, h := spec.NewProcessHandler(service, interceptors)
server.Mount(path, h)
return service
}
func (s *Service) getProcess(selector *rpc.ProcessSelector) (*handler.Handler, error) {
var proc *handler.Handler
switch selector.GetSelector().(type) {
case *rpc.ProcessSelector_Pid:
p, ok := s.processes.Load(selector.GetPid())
if !ok {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with pid %d not found", selector.GetPid()))
}
proc = p
case *rpc.ProcessSelector_Tag:
tag := selector.GetTag()
s.processes.Range(func(_ uint32, value *handler.Handler) bool {
if value.Tag == nil {
return true
}
if *value.Tag == tag {
proc = value
return true
}
return false
})
if proc == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with tag %s not found", tag))
}
default:
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", selector))
}
return proc, nil
}

View File

@ -0,0 +1,40 @@
// SPDX-License-Identifier: Apache-2.0
package process
import (
"context"
"fmt"
"syscall"
"connectrpc.com/connect"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) SendSignal(
_ context.Context,
req *connect.Request[rpc.SendSignalRequest],
) (*connect.Response[rpc.SendSignalResponse], error) {
handler, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return nil, err
}
var signal syscall.Signal
switch req.Msg.GetSignal() {
case rpc.Signal_SIGNAL_SIGKILL:
signal = syscall.SIGKILL
case rpc.Signal_SIGNAL_SIGTERM:
signal = syscall.SIGTERM
default:
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid signal: %s", req.Msg.GetSignal()))
}
err = handler.SendSignal(signal)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error sending signal: %w", err))
}
return connect.NewResponse(&rpc.SendSignalResponse{}), nil
}

View File

@ -0,0 +1,249 @@
// SPDX-License-Identifier: Apache-2.0
package process
import (
"context"
"errors"
"fmt"
"net/http"
"os/user"
"strconv"
"time"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) InitializeStartProcess(ctx context.Context, user *user.User, req *rpc.StartRequest) error {
var err error
ctx = logs.AddRequestIDToContext(ctx)
defer s.logger.
Err(err).
Interface("request", req).
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
Msg("Initialized startCmd")
handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
startProcCtx, startProcCancel := context.WithCancel(ctx)
proc, err := handler.New(startProcCtx, user, req, &handlerL, s.defaults, s.cgroupManager, startProcCancel)
if err != nil {
return err
}
pid, err := proc.Start()
if err != nil {
return err
}
s.processes.Store(pid, proc)
go func() {
defer s.processes.Delete(pid)
proc.Wait()
}()
return nil
}
func (s *Service) Start(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleStart)
}
func (s *Service) handleStart(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
if err != nil {
return err
}
timeout, err := determineTimeoutFromHeader(stream.Conn().RequestHeader())
if err != nil {
return connect.NewError(connect.CodeInvalidArgument, err)
}
// Create a new context with a timeout if provided.
// We do not want the command to be killed if the request context is cancelled
procCtx, cancelProc := context.Background(), func() {}
if timeout > 0 { // zero timeout means no timeout
procCtx, cancelProc = context.WithTimeout(procCtx, timeout)
}
proc, err := handler.New( //nolint:contextcheck // TODO: fix this later
procCtx,
u,
req.Msg,
&handlerL,
s.defaults,
s.cgroupManager,
cancelProc,
)
if err != nil {
// Ensure the process cancel is called to cleanup resources.
cancelProc()
return err
}
exitChan := make(chan struct{})
startMultiplexer := handler.NewMultiplexedChannel[rpc.ProcessEvent_Start](0)
defer close(startMultiplexer.Source)
start, startCancel := startMultiplexer.Fork()
defer startCancel()
data, dataCancel := proc.DataEvent.Fork()
defer dataCancel()
end, endCancel := proc.EndEvent.Fork()
defer endCancel()
go func() {
defer close(exitChan)
select {
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-start:
if !ok {
cancel(connect.NewError(connect.CodeUnknown, errors.New("start event channel closed before sending start event")))
return
}
streamErr := stream.Send(&rpc.StartResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr)))
return
}
}
keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
defer keepaliveTicker.Stop()
dataLoop:
for {
select {
case <-keepaliveTicker.C:
streamErr := stream.Send(&rpc.StartResponse{
Event: &rpc.ProcessEvent{
Event: &rpc.ProcessEvent_Keepalive{
Keepalive: &rpc.ProcessEvent_KeepAlive{},
},
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))
return
}
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-data:
if !ok {
break dataLoop
}
streamErr := stream.Send(&rpc.StartResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))
return
}
resetKeepalive()
}
}
select {
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-end:
if !ok {
cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))
return
}
streamErr := stream.Send(&rpc.StartResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))
return
}
}
}()
pid, err := proc.Start()
if err != nil {
return connect.NewError(connect.CodeInvalidArgument, err)
}
s.processes.Store(pid, proc)
start <- rpc.ProcessEvent_Start{
Start: &rpc.ProcessEvent_StartEvent{
Pid: pid,
},
}
go func() {
defer s.processes.Delete(pid)
proc.Wait()
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-exitChan:
return nil
}
}
func determineTimeoutFromHeader(header http.Header) (time.Duration, error) {
timeoutHeader := header.Get("Connect-Timeout-Ms")
if timeoutHeader == "" {
return 0, nil
}
timeout, err := strconv.Atoi(timeoutHeader)
if err != nil {
return 0, err
}
return time.Duration(timeout) * time.Millisecond, nil
}

View File

@ -0,0 +1,32 @@
// SPDX-License-Identifier: Apache-2.0
package process
import (
"context"
"fmt"
"connectrpc.com/connect"
"github.com/creack/pty"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) Update(_ context.Context, req *connect.Request[rpc.UpdateRequest]) (*connect.Response[rpc.UpdateResponse], error) {
proc, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return nil, err
}
if req.Msg.GetPty() != nil {
err := proc.ResizeTty(&pty.Winsize{
Rows: uint16(req.Msg.GetPty().GetSize().GetRows()),
Cols: uint16(req.Msg.GetPty().GetSize().GetCols()),
})
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error resizing tty: %w", err))
}
}
return connect.NewResponse(&rpc.UpdateResponse{}), nil
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,337 @@
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
//
// Source: filesystem/filesystem.proto
package filesystemconnect
import (
connect "connectrpc.com/connect"
context "context"
errors "errors"
filesystem "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
http "net/http"
strings "strings"
)
// This is a compile-time assertion to ensure that this generated file and the connect package are
// compatible. If you get a compiler error that this constant is not defined, this code was
// generated with a version of connect newer than the one compiled into your binary. You can fix the
// problem by either regenerating this code with an older version of connect or updating the connect
// version compiled into your binary.
const _ = connect.IsAtLeastVersion1_13_0
const (
// FilesystemName is the fully-qualified name of the Filesystem service.
FilesystemName = "filesystem.Filesystem"
)
// These constants are the fully-qualified names of the RPCs defined in this package. They're
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
//
// Note that these are different from the fully-qualified method names used by
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
// period.
const (
// FilesystemStatProcedure is the fully-qualified name of the Filesystem's Stat RPC.
FilesystemStatProcedure = "/filesystem.Filesystem/Stat"
// FilesystemMakeDirProcedure is the fully-qualified name of the Filesystem's MakeDir RPC.
FilesystemMakeDirProcedure = "/filesystem.Filesystem/MakeDir"
// FilesystemMoveProcedure is the fully-qualified name of the Filesystem's Move RPC.
FilesystemMoveProcedure = "/filesystem.Filesystem/Move"
// FilesystemListDirProcedure is the fully-qualified name of the Filesystem's ListDir RPC.
FilesystemListDirProcedure = "/filesystem.Filesystem/ListDir"
// FilesystemRemoveProcedure is the fully-qualified name of the Filesystem's Remove RPC.
FilesystemRemoveProcedure = "/filesystem.Filesystem/Remove"
// FilesystemWatchDirProcedure is the fully-qualified name of the Filesystem's WatchDir RPC.
FilesystemWatchDirProcedure = "/filesystem.Filesystem/WatchDir"
// FilesystemCreateWatcherProcedure is the fully-qualified name of the Filesystem's CreateWatcher
// RPC.
FilesystemCreateWatcherProcedure = "/filesystem.Filesystem/CreateWatcher"
// FilesystemGetWatcherEventsProcedure is the fully-qualified name of the Filesystem's
// GetWatcherEvents RPC.
FilesystemGetWatcherEventsProcedure = "/filesystem.Filesystem/GetWatcherEvents"
// FilesystemRemoveWatcherProcedure is the fully-qualified name of the Filesystem's RemoveWatcher
// RPC.
FilesystemRemoveWatcherProcedure = "/filesystem.Filesystem/RemoveWatcher"
)
// FilesystemClient is a client for the filesystem.Filesystem service.
type FilesystemClient interface {
Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error)
MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error)
Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error)
ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error)
Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error)
WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest]) (*connect.ServerStreamForClient[filesystem.WatchDirResponse], error)
// Non-streaming versions of WatchDir
CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error)
GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error)
RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error)
}
// NewFilesystemClient constructs a client for the filesystem.Filesystem service. By default, it
// uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
// connect.WithGRPCWeb() options.
//
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
// http://api.acme.com or https://acme.com/grpc).
func NewFilesystemClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) FilesystemClient {
baseURL = strings.TrimRight(baseURL, "/")
filesystemMethods := filesystem.File_filesystem_filesystem_proto.Services().ByName("Filesystem").Methods()
return &filesystemClient{
stat: connect.NewClient[filesystem.StatRequest, filesystem.StatResponse](
httpClient,
baseURL+FilesystemStatProcedure,
connect.WithSchema(filesystemMethods.ByName("Stat")),
connect.WithClientOptions(opts...),
),
makeDir: connect.NewClient[filesystem.MakeDirRequest, filesystem.MakeDirResponse](
httpClient,
baseURL+FilesystemMakeDirProcedure,
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
connect.WithClientOptions(opts...),
),
move: connect.NewClient[filesystem.MoveRequest, filesystem.MoveResponse](
httpClient,
baseURL+FilesystemMoveProcedure,
connect.WithSchema(filesystemMethods.ByName("Move")),
connect.WithClientOptions(opts...),
),
listDir: connect.NewClient[filesystem.ListDirRequest, filesystem.ListDirResponse](
httpClient,
baseURL+FilesystemListDirProcedure,
connect.WithSchema(filesystemMethods.ByName("ListDir")),
connect.WithClientOptions(opts...),
),
remove: connect.NewClient[filesystem.RemoveRequest, filesystem.RemoveResponse](
httpClient,
baseURL+FilesystemRemoveProcedure,
connect.WithSchema(filesystemMethods.ByName("Remove")),
connect.WithClientOptions(opts...),
),
watchDir: connect.NewClient[filesystem.WatchDirRequest, filesystem.WatchDirResponse](
httpClient,
baseURL+FilesystemWatchDirProcedure,
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
connect.WithClientOptions(opts...),
),
createWatcher: connect.NewClient[filesystem.CreateWatcherRequest, filesystem.CreateWatcherResponse](
httpClient,
baseURL+FilesystemCreateWatcherProcedure,
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
connect.WithClientOptions(opts...),
),
getWatcherEvents: connect.NewClient[filesystem.GetWatcherEventsRequest, filesystem.GetWatcherEventsResponse](
httpClient,
baseURL+FilesystemGetWatcherEventsProcedure,
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
connect.WithClientOptions(opts...),
),
removeWatcher: connect.NewClient[filesystem.RemoveWatcherRequest, filesystem.RemoveWatcherResponse](
httpClient,
baseURL+FilesystemRemoveWatcherProcedure,
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
connect.WithClientOptions(opts...),
),
}
}
// filesystemClient implements FilesystemClient.
type filesystemClient struct {
stat *connect.Client[filesystem.StatRequest, filesystem.StatResponse]
makeDir *connect.Client[filesystem.MakeDirRequest, filesystem.MakeDirResponse]
move *connect.Client[filesystem.MoveRequest, filesystem.MoveResponse]
listDir *connect.Client[filesystem.ListDirRequest, filesystem.ListDirResponse]
remove *connect.Client[filesystem.RemoveRequest, filesystem.RemoveResponse]
watchDir *connect.Client[filesystem.WatchDirRequest, filesystem.WatchDirResponse]
createWatcher *connect.Client[filesystem.CreateWatcherRequest, filesystem.CreateWatcherResponse]
getWatcherEvents *connect.Client[filesystem.GetWatcherEventsRequest, filesystem.GetWatcherEventsResponse]
removeWatcher *connect.Client[filesystem.RemoveWatcherRequest, filesystem.RemoveWatcherResponse]
}
// Stat calls filesystem.Filesystem.Stat.
func (c *filesystemClient) Stat(ctx context.Context, req *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error) {
return c.stat.CallUnary(ctx, req)
}
// MakeDir calls filesystem.Filesystem.MakeDir.
func (c *filesystemClient) MakeDir(ctx context.Context, req *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error) {
return c.makeDir.CallUnary(ctx, req)
}
// Move calls filesystem.Filesystem.Move.
func (c *filesystemClient) Move(ctx context.Context, req *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error) {
return c.move.CallUnary(ctx, req)
}
// ListDir calls filesystem.Filesystem.ListDir.
func (c *filesystemClient) ListDir(ctx context.Context, req *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error) {
return c.listDir.CallUnary(ctx, req)
}
// Remove calls filesystem.Filesystem.Remove.
func (c *filesystemClient) Remove(ctx context.Context, req *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error) {
return c.remove.CallUnary(ctx, req)
}
// WatchDir calls filesystem.Filesystem.WatchDir.
func (c *filesystemClient) WatchDir(ctx context.Context, req *connect.Request[filesystem.WatchDirRequest]) (*connect.ServerStreamForClient[filesystem.WatchDirResponse], error) {
return c.watchDir.CallServerStream(ctx, req)
}
// CreateWatcher calls filesystem.Filesystem.CreateWatcher.
func (c *filesystemClient) CreateWatcher(ctx context.Context, req *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error) {
return c.createWatcher.CallUnary(ctx, req)
}
// GetWatcherEvents calls filesystem.Filesystem.GetWatcherEvents.
func (c *filesystemClient) GetWatcherEvents(ctx context.Context, req *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error) {
return c.getWatcherEvents.CallUnary(ctx, req)
}
// RemoveWatcher calls filesystem.Filesystem.RemoveWatcher.
func (c *filesystemClient) RemoveWatcher(ctx context.Context, req *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error) {
return c.removeWatcher.CallUnary(ctx, req)
}
// FilesystemHandler is an implementation of the filesystem.Filesystem service.
type FilesystemHandler interface {
Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error)
MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error)
Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error)
ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error)
Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error)
WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest], *connect.ServerStream[filesystem.WatchDirResponse]) error
// Non-streaming versions of WatchDir
CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error)
GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error)
RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error)
}
// NewFilesystemHandler builds an HTTP handler from the service implementation. It returns the path
// on which to mount the handler and the handler itself.
//
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
// and JSON codecs. They also support gzip compression.
func NewFilesystemHandler(svc FilesystemHandler, opts ...connect.HandlerOption) (string, http.Handler) {
filesystemMethods := filesystem.File_filesystem_filesystem_proto.Services().ByName("Filesystem").Methods()
filesystemStatHandler := connect.NewUnaryHandler(
FilesystemStatProcedure,
svc.Stat,
connect.WithSchema(filesystemMethods.ByName("Stat")),
connect.WithHandlerOptions(opts...),
)
filesystemMakeDirHandler := connect.NewUnaryHandler(
FilesystemMakeDirProcedure,
svc.MakeDir,
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
connect.WithHandlerOptions(opts...),
)
filesystemMoveHandler := connect.NewUnaryHandler(
FilesystemMoveProcedure,
svc.Move,
connect.WithSchema(filesystemMethods.ByName("Move")),
connect.WithHandlerOptions(opts...),
)
filesystemListDirHandler := connect.NewUnaryHandler(
FilesystemListDirProcedure,
svc.ListDir,
connect.WithSchema(filesystemMethods.ByName("ListDir")),
connect.WithHandlerOptions(opts...),
)
filesystemRemoveHandler := connect.NewUnaryHandler(
FilesystemRemoveProcedure,
svc.Remove,
connect.WithSchema(filesystemMethods.ByName("Remove")),
connect.WithHandlerOptions(opts...),
)
filesystemWatchDirHandler := connect.NewServerStreamHandler(
FilesystemWatchDirProcedure,
svc.WatchDir,
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
connect.WithHandlerOptions(opts...),
)
filesystemCreateWatcherHandler := connect.NewUnaryHandler(
FilesystemCreateWatcherProcedure,
svc.CreateWatcher,
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
connect.WithHandlerOptions(opts...),
)
filesystemGetWatcherEventsHandler := connect.NewUnaryHandler(
FilesystemGetWatcherEventsProcedure,
svc.GetWatcherEvents,
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
connect.WithHandlerOptions(opts...),
)
filesystemRemoveWatcherHandler := connect.NewUnaryHandler(
FilesystemRemoveWatcherProcedure,
svc.RemoveWatcher,
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
connect.WithHandlerOptions(opts...),
)
return "/filesystem.Filesystem/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case FilesystemStatProcedure:
filesystemStatHandler.ServeHTTP(w, r)
case FilesystemMakeDirProcedure:
filesystemMakeDirHandler.ServeHTTP(w, r)
case FilesystemMoveProcedure:
filesystemMoveHandler.ServeHTTP(w, r)
case FilesystemListDirProcedure:
filesystemListDirHandler.ServeHTTP(w, r)
case FilesystemRemoveProcedure:
filesystemRemoveHandler.ServeHTTP(w, r)
case FilesystemWatchDirProcedure:
filesystemWatchDirHandler.ServeHTTP(w, r)
case FilesystemCreateWatcherProcedure:
filesystemCreateWatcherHandler.ServeHTTP(w, r)
case FilesystemGetWatcherEventsProcedure:
filesystemGetWatcherEventsHandler.ServeHTTP(w, r)
case FilesystemRemoveWatcherProcedure:
filesystemRemoveWatcherHandler.ServeHTTP(w, r)
default:
http.NotFound(w, r)
}
})
}
// UnimplementedFilesystemHandler returns CodeUnimplemented from all methods.
type UnimplementedFilesystemHandler struct{}
func (UnimplementedFilesystemHandler) Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Stat is not implemented"))
}
func (UnimplementedFilesystemHandler) MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.MakeDir is not implemented"))
}
func (UnimplementedFilesystemHandler) Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Move is not implemented"))
}
func (UnimplementedFilesystemHandler) ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.ListDir is not implemented"))
}
func (UnimplementedFilesystemHandler) Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Remove is not implemented"))
}
func (UnimplementedFilesystemHandler) WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest], *connect.ServerStream[filesystem.WatchDirResponse]) error {
return connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.WatchDir is not implemented"))
}
func (UnimplementedFilesystemHandler) CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.CreateWatcher is not implemented"))
}
func (UnimplementedFilesystemHandler) GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.GetWatcherEvents is not implemented"))
}
func (UnimplementedFilesystemHandler) RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.RemoveWatcher is not implemented"))
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,310 @@
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
//
// Source: process/process.proto
package processconnect
import (
connect "connectrpc.com/connect"
context "context"
errors "errors"
process "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
http "net/http"
strings "strings"
)
// This is a compile-time assertion to ensure that this generated file and the connect package are
// compatible. If you get a compiler error that this constant is not defined, this code was
// generated with a version of connect newer than the one compiled into your binary. You can fix the
// problem by either regenerating this code with an older version of connect or updating the connect
// version compiled into your binary.
const _ = connect.IsAtLeastVersion1_13_0
const (
// ProcessName is the fully-qualified name of the Process service.
ProcessName = "process.Process"
)
// These constants are the fully-qualified names of the RPCs defined in this package. They're
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
//
// Note that these are different from the fully-qualified method names used by
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
// period.
const (
// ProcessListProcedure is the fully-qualified name of the Process's List RPC.
ProcessListProcedure = "/process.Process/List"
// ProcessConnectProcedure is the fully-qualified name of the Process's Connect RPC.
ProcessConnectProcedure = "/process.Process/Connect"
// ProcessStartProcedure is the fully-qualified name of the Process's Start RPC.
ProcessStartProcedure = "/process.Process/Start"
// ProcessUpdateProcedure is the fully-qualified name of the Process's Update RPC.
ProcessUpdateProcedure = "/process.Process/Update"
// ProcessStreamInputProcedure is the fully-qualified name of the Process's StreamInput RPC.
ProcessStreamInputProcedure = "/process.Process/StreamInput"
// ProcessSendInputProcedure is the fully-qualified name of the Process's SendInput RPC.
ProcessSendInputProcedure = "/process.Process/SendInput"
// ProcessSendSignalProcedure is the fully-qualified name of the Process's SendSignal RPC.
ProcessSendSignalProcedure = "/process.Process/SendSignal"
// ProcessCloseStdinProcedure is the fully-qualified name of the Process's CloseStdin RPC.
ProcessCloseStdinProcedure = "/process.Process/CloseStdin"
)
// ProcessClient is a client for the process.Process service.
type ProcessClient interface {
List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error)
Connect(context.Context, *connect.Request[process.ConnectRequest]) (*connect.ServerStreamForClient[process.ConnectResponse], error)
Start(context.Context, *connect.Request[process.StartRequest]) (*connect.ServerStreamForClient[process.StartResponse], error)
Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error)
// Client input stream ensures ordering of messages
StreamInput(context.Context) *connect.ClientStreamForClient[process.StreamInputRequest, process.StreamInputResponse]
SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error)
SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error)
// Close stdin to signal EOF to the process.
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error)
}
// NewProcessClient constructs a client for the process.Process service. By default, it uses the
// Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
// connect.WithGRPCWeb() options.
//
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
// http://api.acme.com or https://acme.com/grpc).
func NewProcessClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) ProcessClient {
baseURL = strings.TrimRight(baseURL, "/")
processMethods := process.File_process_process_proto.Services().ByName("Process").Methods()
return &processClient{
list: connect.NewClient[process.ListRequest, process.ListResponse](
httpClient,
baseURL+ProcessListProcedure,
connect.WithSchema(processMethods.ByName("List")),
connect.WithClientOptions(opts...),
),
connect: connect.NewClient[process.ConnectRequest, process.ConnectResponse](
httpClient,
baseURL+ProcessConnectProcedure,
connect.WithSchema(processMethods.ByName("Connect")),
connect.WithClientOptions(opts...),
),
start: connect.NewClient[process.StartRequest, process.StartResponse](
httpClient,
baseURL+ProcessStartProcedure,
connect.WithSchema(processMethods.ByName("Start")),
connect.WithClientOptions(opts...),
),
update: connect.NewClient[process.UpdateRequest, process.UpdateResponse](
httpClient,
baseURL+ProcessUpdateProcedure,
connect.WithSchema(processMethods.ByName("Update")),
connect.WithClientOptions(opts...),
),
streamInput: connect.NewClient[process.StreamInputRequest, process.StreamInputResponse](
httpClient,
baseURL+ProcessStreamInputProcedure,
connect.WithSchema(processMethods.ByName("StreamInput")),
connect.WithClientOptions(opts...),
),
sendInput: connect.NewClient[process.SendInputRequest, process.SendInputResponse](
httpClient,
baseURL+ProcessSendInputProcedure,
connect.WithSchema(processMethods.ByName("SendInput")),
connect.WithClientOptions(opts...),
),
sendSignal: connect.NewClient[process.SendSignalRequest, process.SendSignalResponse](
httpClient,
baseURL+ProcessSendSignalProcedure,
connect.WithSchema(processMethods.ByName("SendSignal")),
connect.WithClientOptions(opts...),
),
closeStdin: connect.NewClient[process.CloseStdinRequest, process.CloseStdinResponse](
httpClient,
baseURL+ProcessCloseStdinProcedure,
connect.WithSchema(processMethods.ByName("CloseStdin")),
connect.WithClientOptions(opts...),
),
}
}
// processClient implements ProcessClient.
type processClient struct {
list *connect.Client[process.ListRequest, process.ListResponse]
connect *connect.Client[process.ConnectRequest, process.ConnectResponse]
start *connect.Client[process.StartRequest, process.StartResponse]
update *connect.Client[process.UpdateRequest, process.UpdateResponse]
streamInput *connect.Client[process.StreamInputRequest, process.StreamInputResponse]
sendInput *connect.Client[process.SendInputRequest, process.SendInputResponse]
sendSignal *connect.Client[process.SendSignalRequest, process.SendSignalResponse]
closeStdin *connect.Client[process.CloseStdinRequest, process.CloseStdinResponse]
}
// List calls process.Process.List.
func (c *processClient) List(ctx context.Context, req *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error) {
return c.list.CallUnary(ctx, req)
}
// Connect calls process.Process.Connect.
func (c *processClient) Connect(ctx context.Context, req *connect.Request[process.ConnectRequest]) (*connect.ServerStreamForClient[process.ConnectResponse], error) {
return c.connect.CallServerStream(ctx, req)
}
// Start calls process.Process.Start.
func (c *processClient) Start(ctx context.Context, req *connect.Request[process.StartRequest]) (*connect.ServerStreamForClient[process.StartResponse], error) {
return c.start.CallServerStream(ctx, req)
}
// Update calls process.Process.Update.
func (c *processClient) Update(ctx context.Context, req *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error) {
return c.update.CallUnary(ctx, req)
}
// StreamInput calls process.Process.StreamInput.
func (c *processClient) StreamInput(ctx context.Context) *connect.ClientStreamForClient[process.StreamInputRequest, process.StreamInputResponse] {
return c.streamInput.CallClientStream(ctx)
}
// SendInput calls process.Process.SendInput.
func (c *processClient) SendInput(ctx context.Context, req *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error) {
return c.sendInput.CallUnary(ctx, req)
}
// SendSignal calls process.Process.SendSignal.
func (c *processClient) SendSignal(ctx context.Context, req *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error) {
return c.sendSignal.CallUnary(ctx, req)
}
// CloseStdin calls process.Process.CloseStdin.
func (c *processClient) CloseStdin(ctx context.Context, req *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error) {
return c.closeStdin.CallUnary(ctx, req)
}
// ProcessHandler is an implementation of the process.Process service.
type ProcessHandler interface {
List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error)
Connect(context.Context, *connect.Request[process.ConnectRequest], *connect.ServerStream[process.ConnectResponse]) error
Start(context.Context, *connect.Request[process.StartRequest], *connect.ServerStream[process.StartResponse]) error
Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error)
// Client input stream ensures ordering of messages
StreamInput(context.Context, *connect.ClientStream[process.StreamInputRequest]) (*connect.Response[process.StreamInputResponse], error)
SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error)
SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error)
// Close stdin to signal EOF to the process.
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error)
}
// NewProcessHandler builds an HTTP handler from the service implementation. It returns the path on
// which to mount the handler and the handler itself.
//
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
// and JSON codecs. They also support gzip compression.
func NewProcessHandler(svc ProcessHandler, opts ...connect.HandlerOption) (string, http.Handler) {
processMethods := process.File_process_process_proto.Services().ByName("Process").Methods()
processListHandler := connect.NewUnaryHandler(
ProcessListProcedure,
svc.List,
connect.WithSchema(processMethods.ByName("List")),
connect.WithHandlerOptions(opts...),
)
processConnectHandler := connect.NewServerStreamHandler(
ProcessConnectProcedure,
svc.Connect,
connect.WithSchema(processMethods.ByName("Connect")),
connect.WithHandlerOptions(opts...),
)
processStartHandler := connect.NewServerStreamHandler(
ProcessStartProcedure,
svc.Start,
connect.WithSchema(processMethods.ByName("Start")),
connect.WithHandlerOptions(opts...),
)
processUpdateHandler := connect.NewUnaryHandler(
ProcessUpdateProcedure,
svc.Update,
connect.WithSchema(processMethods.ByName("Update")),
connect.WithHandlerOptions(opts...),
)
processStreamInputHandler := connect.NewClientStreamHandler(
ProcessStreamInputProcedure,
svc.StreamInput,
connect.WithSchema(processMethods.ByName("StreamInput")),
connect.WithHandlerOptions(opts...),
)
processSendInputHandler := connect.NewUnaryHandler(
ProcessSendInputProcedure,
svc.SendInput,
connect.WithSchema(processMethods.ByName("SendInput")),
connect.WithHandlerOptions(opts...),
)
processSendSignalHandler := connect.NewUnaryHandler(
ProcessSendSignalProcedure,
svc.SendSignal,
connect.WithSchema(processMethods.ByName("SendSignal")),
connect.WithHandlerOptions(opts...),
)
processCloseStdinHandler := connect.NewUnaryHandler(
ProcessCloseStdinProcedure,
svc.CloseStdin,
connect.WithSchema(processMethods.ByName("CloseStdin")),
connect.WithHandlerOptions(opts...),
)
return "/process.Process/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case ProcessListProcedure:
processListHandler.ServeHTTP(w, r)
case ProcessConnectProcedure:
processConnectHandler.ServeHTTP(w, r)
case ProcessStartProcedure:
processStartHandler.ServeHTTP(w, r)
case ProcessUpdateProcedure:
processUpdateHandler.ServeHTTP(w, r)
case ProcessStreamInputProcedure:
processStreamInputHandler.ServeHTTP(w, r)
case ProcessSendInputProcedure:
processSendInputHandler.ServeHTTP(w, r)
case ProcessSendSignalProcedure:
processSendSignalHandler.ServeHTTP(w, r)
case ProcessCloseStdinProcedure:
processCloseStdinHandler.ServeHTTP(w, r)
default:
http.NotFound(w, r)
}
})
}
// UnimplementedProcessHandler returns CodeUnimplemented from all methods.
type UnimplementedProcessHandler struct{}
func (UnimplementedProcessHandler) List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.List is not implemented"))
}
func (UnimplementedProcessHandler) Connect(context.Context, *connect.Request[process.ConnectRequest], *connect.ServerStream[process.ConnectResponse]) error {
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Connect is not implemented"))
}
func (UnimplementedProcessHandler) Start(context.Context, *connect.Request[process.StartRequest], *connect.ServerStream[process.StartResponse]) error {
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Start is not implemented"))
}
func (UnimplementedProcessHandler) Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Update is not implemented"))
}
func (UnimplementedProcessHandler) StreamInput(context.Context, *connect.ClientStream[process.StreamInputRequest]) (*connect.Response[process.StreamInputResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.StreamInput is not implemented"))
}
func (UnimplementedProcessHandler) SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendInput is not implemented"))
}
func (UnimplementedProcessHandler) SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendSignal is not implemented"))
}
func (UnimplementedProcessHandler) CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.CloseStdin is not implemented"))
}

View File

@ -0,0 +1,339 @@
// SPDX-License-Identifier: Apache-2.0
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
//
// Source: filesystem.proto
package specconnect
import (
connect "connectrpc.com/connect"
context "context"
errors "errors"
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec"
http "net/http"
strings "strings"
)
// This is a compile-time assertion to ensure that this generated file and the connect package are
// compatible. If you get a compiler error that this constant is not defined, this code was
// generated with a version of connect newer than the one compiled into your binary. You can fix the
// problem by either regenerating this code with an older version of connect or updating the connect
// version compiled into your binary.
const _ = connect.IsAtLeastVersion1_13_0
const (
// FilesystemName is the fully-qualified name of the Filesystem service.
FilesystemName = "filesystem.Filesystem"
)
// These constants are the fully-qualified names of the RPCs defined in this package. They're
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
//
// Note that these are different from the fully-qualified method names used by
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
// period.
const (
// FilesystemStatProcedure is the fully-qualified name of the Filesystem's Stat RPC.
FilesystemStatProcedure = "/filesystem.Filesystem/Stat"
// FilesystemMakeDirProcedure is the fully-qualified name of the Filesystem's MakeDir RPC.
FilesystemMakeDirProcedure = "/filesystem.Filesystem/MakeDir"
// FilesystemMoveProcedure is the fully-qualified name of the Filesystem's Move RPC.
FilesystemMoveProcedure = "/filesystem.Filesystem/Move"
// FilesystemListDirProcedure is the fully-qualified name of the Filesystem's ListDir RPC.
FilesystemListDirProcedure = "/filesystem.Filesystem/ListDir"
// FilesystemRemoveProcedure is the fully-qualified name of the Filesystem's Remove RPC.
FilesystemRemoveProcedure = "/filesystem.Filesystem/Remove"
// FilesystemWatchDirProcedure is the fully-qualified name of the Filesystem's WatchDir RPC.
FilesystemWatchDirProcedure = "/filesystem.Filesystem/WatchDir"
// FilesystemCreateWatcherProcedure is the fully-qualified name of the Filesystem's CreateWatcher
// RPC.
FilesystemCreateWatcherProcedure = "/filesystem.Filesystem/CreateWatcher"
// FilesystemGetWatcherEventsProcedure is the fully-qualified name of the Filesystem's
// GetWatcherEvents RPC.
FilesystemGetWatcherEventsProcedure = "/filesystem.Filesystem/GetWatcherEvents"
// FilesystemRemoveWatcherProcedure is the fully-qualified name of the Filesystem's RemoveWatcher
// RPC.
FilesystemRemoveWatcherProcedure = "/filesystem.Filesystem/RemoveWatcher"
)
// FilesystemClient is a client for the filesystem.Filesystem service.
type FilesystemClient interface {
Stat(context.Context, *connect.Request[spec.StatRequest]) (*connect.Response[spec.StatResponse], error)
MakeDir(context.Context, *connect.Request[spec.MakeDirRequest]) (*connect.Response[spec.MakeDirResponse], error)
Move(context.Context, *connect.Request[spec.MoveRequest]) (*connect.Response[spec.MoveResponse], error)
ListDir(context.Context, *connect.Request[spec.ListDirRequest]) (*connect.Response[spec.ListDirResponse], error)
Remove(context.Context, *connect.Request[spec.RemoveRequest]) (*connect.Response[spec.RemoveResponse], error)
WatchDir(context.Context, *connect.Request[spec.WatchDirRequest]) (*connect.ServerStreamForClient[spec.WatchDirResponse], error)
// Non-streaming versions of WatchDir
CreateWatcher(context.Context, *connect.Request[spec.CreateWatcherRequest]) (*connect.Response[spec.CreateWatcherResponse], error)
GetWatcherEvents(context.Context, *connect.Request[spec.GetWatcherEventsRequest]) (*connect.Response[spec.GetWatcherEventsResponse], error)
RemoveWatcher(context.Context, *connect.Request[spec.RemoveWatcherRequest]) (*connect.Response[spec.RemoveWatcherResponse], error)
}
// NewFilesystemClient constructs a client for the filesystem.Filesystem service. By default, it
// uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
// connect.WithGRPCWeb() options.
//
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
// http://api.acme.com or https://acme.com/grpc).
func NewFilesystemClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) FilesystemClient {
baseURL = strings.TrimRight(baseURL, "/")
filesystemMethods := spec.File_filesystem_proto.Services().ByName("Filesystem").Methods()
return &filesystemClient{
stat: connect.NewClient[spec.StatRequest, spec.StatResponse](
httpClient,
baseURL+FilesystemStatProcedure,
connect.WithSchema(filesystemMethods.ByName("Stat")),
connect.WithClientOptions(opts...),
),
makeDir: connect.NewClient[spec.MakeDirRequest, spec.MakeDirResponse](
httpClient,
baseURL+FilesystemMakeDirProcedure,
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
connect.WithClientOptions(opts...),
),
move: connect.NewClient[spec.MoveRequest, spec.MoveResponse](
httpClient,
baseURL+FilesystemMoveProcedure,
connect.WithSchema(filesystemMethods.ByName("Move")),
connect.WithClientOptions(opts...),
),
listDir: connect.NewClient[spec.ListDirRequest, spec.ListDirResponse](
httpClient,
baseURL+FilesystemListDirProcedure,
connect.WithSchema(filesystemMethods.ByName("ListDir")),
connect.WithClientOptions(opts...),
),
remove: connect.NewClient[spec.RemoveRequest, spec.RemoveResponse](
httpClient,
baseURL+FilesystemRemoveProcedure,
connect.WithSchema(filesystemMethods.ByName("Remove")),
connect.WithClientOptions(opts...),
),
watchDir: connect.NewClient[spec.WatchDirRequest, spec.WatchDirResponse](
httpClient,
baseURL+FilesystemWatchDirProcedure,
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
connect.WithClientOptions(opts...),
),
createWatcher: connect.NewClient[spec.CreateWatcherRequest, spec.CreateWatcherResponse](
httpClient,
baseURL+FilesystemCreateWatcherProcedure,
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
connect.WithClientOptions(opts...),
),
getWatcherEvents: connect.NewClient[spec.GetWatcherEventsRequest, spec.GetWatcherEventsResponse](
httpClient,
baseURL+FilesystemGetWatcherEventsProcedure,
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
connect.WithClientOptions(opts...),
),
removeWatcher: connect.NewClient[spec.RemoveWatcherRequest, spec.RemoveWatcherResponse](
httpClient,
baseURL+FilesystemRemoveWatcherProcedure,
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
connect.WithClientOptions(opts...),
),
}
}
// filesystemClient implements FilesystemClient.
type filesystemClient struct {
stat *connect.Client[spec.StatRequest, spec.StatResponse]
makeDir *connect.Client[spec.MakeDirRequest, spec.MakeDirResponse]
move *connect.Client[spec.MoveRequest, spec.MoveResponse]
listDir *connect.Client[spec.ListDirRequest, spec.ListDirResponse]
remove *connect.Client[spec.RemoveRequest, spec.RemoveResponse]
watchDir *connect.Client[spec.WatchDirRequest, spec.WatchDirResponse]
createWatcher *connect.Client[spec.CreateWatcherRequest, spec.CreateWatcherResponse]
getWatcherEvents *connect.Client[spec.GetWatcherEventsRequest, spec.GetWatcherEventsResponse]
removeWatcher *connect.Client[spec.RemoveWatcherRequest, spec.RemoveWatcherResponse]
}
// Stat calls filesystem.Filesystem.Stat.
func (c *filesystemClient) Stat(ctx context.Context, req *connect.Request[spec.StatRequest]) (*connect.Response[spec.StatResponse], error) {
return c.stat.CallUnary(ctx, req)
}
// MakeDir calls filesystem.Filesystem.MakeDir.
func (c *filesystemClient) MakeDir(ctx context.Context, req *connect.Request[spec.MakeDirRequest]) (*connect.Response[spec.MakeDirResponse], error) {
return c.makeDir.CallUnary(ctx, req)
}
// Move calls filesystem.Filesystem.Move.
func (c *filesystemClient) Move(ctx context.Context, req *connect.Request[spec.MoveRequest]) (*connect.Response[spec.MoveResponse], error) {
return c.move.CallUnary(ctx, req)
}
// ListDir calls filesystem.Filesystem.ListDir.
func (c *filesystemClient) ListDir(ctx context.Context, req *connect.Request[spec.ListDirRequest]) (*connect.Response[spec.ListDirResponse], error) {
return c.listDir.CallUnary(ctx, req)
}
// Remove calls filesystem.Filesystem.Remove.
func (c *filesystemClient) Remove(ctx context.Context, req *connect.Request[spec.RemoveRequest]) (*connect.Response[spec.RemoveResponse], error) {
return c.remove.CallUnary(ctx, req)
}
// WatchDir calls filesystem.Filesystem.WatchDir.
func (c *filesystemClient) WatchDir(ctx context.Context, req *connect.Request[spec.WatchDirRequest]) (*connect.ServerStreamForClient[spec.WatchDirResponse], error) {
return c.watchDir.CallServerStream(ctx, req)
}
// CreateWatcher calls filesystem.Filesystem.CreateWatcher.
func (c *filesystemClient) CreateWatcher(ctx context.Context, req *connect.Request[spec.CreateWatcherRequest]) (*connect.Response[spec.CreateWatcherResponse], error) {
return c.createWatcher.CallUnary(ctx, req)
}
// GetWatcherEvents calls filesystem.Filesystem.GetWatcherEvents.
func (c *filesystemClient) GetWatcherEvents(ctx context.Context, req *connect.Request[spec.GetWatcherEventsRequest]) (*connect.Response[spec.GetWatcherEventsResponse], error) {
return c.getWatcherEvents.CallUnary(ctx, req)
}
// RemoveWatcher calls filesystem.Filesystem.RemoveWatcher.
func (c *filesystemClient) RemoveWatcher(ctx context.Context, req *connect.Request[spec.RemoveWatcherRequest]) (*connect.Response[spec.RemoveWatcherResponse], error) {
return c.removeWatcher.CallUnary(ctx, req)
}
// FilesystemHandler is an implementation of the filesystem.Filesystem service.
type FilesystemHandler interface {
Stat(context.Context, *connect.Request[spec.StatRequest]) (*connect.Response[spec.StatResponse], error)
MakeDir(context.Context, *connect.Request[spec.MakeDirRequest]) (*connect.Response[spec.MakeDirResponse], error)
Move(context.Context, *connect.Request[spec.MoveRequest]) (*connect.Response[spec.MoveResponse], error)
ListDir(context.Context, *connect.Request[spec.ListDirRequest]) (*connect.Response[spec.ListDirResponse], error)
Remove(context.Context, *connect.Request[spec.RemoveRequest]) (*connect.Response[spec.RemoveResponse], error)
WatchDir(context.Context, *connect.Request[spec.WatchDirRequest], *connect.ServerStream[spec.WatchDirResponse]) error
// Non-streaming versions of WatchDir
CreateWatcher(context.Context, *connect.Request[spec.CreateWatcherRequest]) (*connect.Response[spec.CreateWatcherResponse], error)
GetWatcherEvents(context.Context, *connect.Request[spec.GetWatcherEventsRequest]) (*connect.Response[spec.GetWatcherEventsResponse], error)
RemoveWatcher(context.Context, *connect.Request[spec.RemoveWatcherRequest]) (*connect.Response[spec.RemoveWatcherResponse], error)
}
// NewFilesystemHandler builds an HTTP handler from the service implementation. It returns the path
// on which to mount the handler and the handler itself.
//
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
// and JSON codecs. They also support gzip compression.
func NewFilesystemHandler(svc FilesystemHandler, opts ...connect.HandlerOption) (string, http.Handler) {
filesystemMethods := spec.File_filesystem_proto.Services().ByName("Filesystem").Methods()
filesystemStatHandler := connect.NewUnaryHandler(
FilesystemStatProcedure,
svc.Stat,
connect.WithSchema(filesystemMethods.ByName("Stat")),
connect.WithHandlerOptions(opts...),
)
filesystemMakeDirHandler := connect.NewUnaryHandler(
FilesystemMakeDirProcedure,
svc.MakeDir,
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
connect.WithHandlerOptions(opts...),
)
filesystemMoveHandler := connect.NewUnaryHandler(
FilesystemMoveProcedure,
svc.Move,
connect.WithSchema(filesystemMethods.ByName("Move")),
connect.WithHandlerOptions(opts...),
)
filesystemListDirHandler := connect.NewUnaryHandler(
FilesystemListDirProcedure,
svc.ListDir,
connect.WithSchema(filesystemMethods.ByName("ListDir")),
connect.WithHandlerOptions(opts...),
)
filesystemRemoveHandler := connect.NewUnaryHandler(
FilesystemRemoveProcedure,
svc.Remove,
connect.WithSchema(filesystemMethods.ByName("Remove")),
connect.WithHandlerOptions(opts...),
)
filesystemWatchDirHandler := connect.NewServerStreamHandler(
FilesystemWatchDirProcedure,
svc.WatchDir,
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
connect.WithHandlerOptions(opts...),
)
filesystemCreateWatcherHandler := connect.NewUnaryHandler(
FilesystemCreateWatcherProcedure,
svc.CreateWatcher,
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
connect.WithHandlerOptions(opts...),
)
filesystemGetWatcherEventsHandler := connect.NewUnaryHandler(
FilesystemGetWatcherEventsProcedure,
svc.GetWatcherEvents,
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
connect.WithHandlerOptions(opts...),
)
filesystemRemoveWatcherHandler := connect.NewUnaryHandler(
FilesystemRemoveWatcherProcedure,
svc.RemoveWatcher,
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
connect.WithHandlerOptions(opts...),
)
return "/filesystem.Filesystem/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case FilesystemStatProcedure:
filesystemStatHandler.ServeHTTP(w, r)
case FilesystemMakeDirProcedure:
filesystemMakeDirHandler.ServeHTTP(w, r)
case FilesystemMoveProcedure:
filesystemMoveHandler.ServeHTTP(w, r)
case FilesystemListDirProcedure:
filesystemListDirHandler.ServeHTTP(w, r)
case FilesystemRemoveProcedure:
filesystemRemoveHandler.ServeHTTP(w, r)
case FilesystemWatchDirProcedure:
filesystemWatchDirHandler.ServeHTTP(w, r)
case FilesystemCreateWatcherProcedure:
filesystemCreateWatcherHandler.ServeHTTP(w, r)
case FilesystemGetWatcherEventsProcedure:
filesystemGetWatcherEventsHandler.ServeHTTP(w, r)
case FilesystemRemoveWatcherProcedure:
filesystemRemoveWatcherHandler.ServeHTTP(w, r)
default:
http.NotFound(w, r)
}
})
}
// UnimplementedFilesystemHandler returns CodeUnimplemented from all methods.
type UnimplementedFilesystemHandler struct{}
func (UnimplementedFilesystemHandler) Stat(context.Context, *connect.Request[spec.StatRequest]) (*connect.Response[spec.StatResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Stat is not implemented"))
}
func (UnimplementedFilesystemHandler) MakeDir(context.Context, *connect.Request[spec.MakeDirRequest]) (*connect.Response[spec.MakeDirResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.MakeDir is not implemented"))
}
func (UnimplementedFilesystemHandler) Move(context.Context, *connect.Request[spec.MoveRequest]) (*connect.Response[spec.MoveResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Move is not implemented"))
}
func (UnimplementedFilesystemHandler) ListDir(context.Context, *connect.Request[spec.ListDirRequest]) (*connect.Response[spec.ListDirResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.ListDir is not implemented"))
}
func (UnimplementedFilesystemHandler) Remove(context.Context, *connect.Request[spec.RemoveRequest]) (*connect.Response[spec.RemoveResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Remove is not implemented"))
}
func (UnimplementedFilesystemHandler) WatchDir(context.Context, *connect.Request[spec.WatchDirRequest], *connect.ServerStream[spec.WatchDirResponse]) error {
return connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.WatchDir is not implemented"))
}
func (UnimplementedFilesystemHandler) CreateWatcher(context.Context, *connect.Request[spec.CreateWatcherRequest]) (*connect.Response[spec.CreateWatcherResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.CreateWatcher is not implemented"))
}
func (UnimplementedFilesystemHandler) GetWatcherEvents(context.Context, *connect.Request[spec.GetWatcherEventsRequest]) (*connect.Response[spec.GetWatcherEventsResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.GetWatcherEvents is not implemented"))
}
func (UnimplementedFilesystemHandler) RemoveWatcher(context.Context, *connect.Request[spec.RemoveWatcherRequest]) (*connect.Response[spec.RemoveWatcherResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.RemoveWatcher is not implemented"))
}

View File

@ -0,0 +1,312 @@
// SPDX-License-Identifier: Apache-2.0
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
//
// Source: process.proto
package specconnect
import (
connect "connectrpc.com/connect"
context "context"
errors "errors"
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec"
http "net/http"
strings "strings"
)
// This is a compile-time assertion to ensure that this generated file and the connect package are
// compatible. If you get a compiler error that this constant is not defined, this code was
// generated with a version of connect newer than the one compiled into your binary. You can fix the
// problem by either regenerating this code with an older version of connect or updating the connect
// version compiled into your binary.
const _ = connect.IsAtLeastVersion1_13_0
const (
// ProcessName is the fully-qualified name of the Process service.
ProcessName = "process.Process"
)
// These constants are the fully-qualified names of the RPCs defined in this package. They're
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
//
// Note that these are different from the fully-qualified method names used by
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
// period.
const (
// ProcessListProcedure is the fully-qualified name of the Process's List RPC.
ProcessListProcedure = "/process.Process/List"
// ProcessConnectProcedure is the fully-qualified name of the Process's Connect RPC.
ProcessConnectProcedure = "/process.Process/Connect"
// ProcessStartProcedure is the fully-qualified name of the Process's Start RPC.
ProcessStartProcedure = "/process.Process/Start"
// ProcessUpdateProcedure is the fully-qualified name of the Process's Update RPC.
ProcessUpdateProcedure = "/process.Process/Update"
// ProcessStreamInputProcedure is the fully-qualified name of the Process's StreamInput RPC.
ProcessStreamInputProcedure = "/process.Process/StreamInput"
// ProcessSendInputProcedure is the fully-qualified name of the Process's SendInput RPC.
ProcessSendInputProcedure = "/process.Process/SendInput"
// ProcessSendSignalProcedure is the fully-qualified name of the Process's SendSignal RPC.
ProcessSendSignalProcedure = "/process.Process/SendSignal"
// ProcessCloseStdinProcedure is the fully-qualified name of the Process's CloseStdin RPC.
ProcessCloseStdinProcedure = "/process.Process/CloseStdin"
)
// ProcessClient is a client for the process.Process service.
type ProcessClient interface {
List(context.Context, *connect.Request[spec.ListRequest]) (*connect.Response[spec.ListResponse], error)
Connect(context.Context, *connect.Request[spec.ConnectRequest]) (*connect.ServerStreamForClient[spec.ConnectResponse], error)
Start(context.Context, *connect.Request[spec.StartRequest]) (*connect.ServerStreamForClient[spec.StartResponse], error)
Update(context.Context, *connect.Request[spec.UpdateRequest]) (*connect.Response[spec.UpdateResponse], error)
// Client input stream ensures ordering of messages
StreamInput(context.Context) *connect.ClientStreamForClient[spec.StreamInputRequest, spec.StreamInputResponse]
SendInput(context.Context, *connect.Request[spec.SendInputRequest]) (*connect.Response[spec.SendInputResponse], error)
SendSignal(context.Context, *connect.Request[spec.SendSignalRequest]) (*connect.Response[spec.SendSignalResponse], error)
// Close stdin to signal EOF to the process.
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
CloseStdin(context.Context, *connect.Request[spec.CloseStdinRequest]) (*connect.Response[spec.CloseStdinResponse], error)
}
// NewProcessClient constructs a client for the process.Process service. By default, it uses the
// Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
// connect.WithGRPCWeb() options.
//
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
// http://api.acme.com or https://acme.com/grpc).
func NewProcessClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) ProcessClient {
baseURL = strings.TrimRight(baseURL, "/")
processMethods := spec.File_process_proto.Services().ByName("Process").Methods()
return &processClient{
list: connect.NewClient[spec.ListRequest, spec.ListResponse](
httpClient,
baseURL+ProcessListProcedure,
connect.WithSchema(processMethods.ByName("List")),
connect.WithClientOptions(opts...),
),
connect: connect.NewClient[spec.ConnectRequest, spec.ConnectResponse](
httpClient,
baseURL+ProcessConnectProcedure,
connect.WithSchema(processMethods.ByName("Connect")),
connect.WithClientOptions(opts...),
),
start: connect.NewClient[spec.StartRequest, spec.StartResponse](
httpClient,
baseURL+ProcessStartProcedure,
connect.WithSchema(processMethods.ByName("Start")),
connect.WithClientOptions(opts...),
),
update: connect.NewClient[spec.UpdateRequest, spec.UpdateResponse](
httpClient,
baseURL+ProcessUpdateProcedure,
connect.WithSchema(processMethods.ByName("Update")),
connect.WithClientOptions(opts...),
),
streamInput: connect.NewClient[spec.StreamInputRequest, spec.StreamInputResponse](
httpClient,
baseURL+ProcessStreamInputProcedure,
connect.WithSchema(processMethods.ByName("StreamInput")),
connect.WithClientOptions(opts...),
),
sendInput: connect.NewClient[spec.SendInputRequest, spec.SendInputResponse](
httpClient,
baseURL+ProcessSendInputProcedure,
connect.WithSchema(processMethods.ByName("SendInput")),
connect.WithClientOptions(opts...),
),
sendSignal: connect.NewClient[spec.SendSignalRequest, spec.SendSignalResponse](
httpClient,
baseURL+ProcessSendSignalProcedure,
connect.WithSchema(processMethods.ByName("SendSignal")),
connect.WithClientOptions(opts...),
),
closeStdin: connect.NewClient[spec.CloseStdinRequest, spec.CloseStdinResponse](
httpClient,
baseURL+ProcessCloseStdinProcedure,
connect.WithSchema(processMethods.ByName("CloseStdin")),
connect.WithClientOptions(opts...),
),
}
}
// processClient implements ProcessClient.
type processClient struct {
list *connect.Client[spec.ListRequest, spec.ListResponse]
connect *connect.Client[spec.ConnectRequest, spec.ConnectResponse]
start *connect.Client[spec.StartRequest, spec.StartResponse]
update *connect.Client[spec.UpdateRequest, spec.UpdateResponse]
streamInput *connect.Client[spec.StreamInputRequest, spec.StreamInputResponse]
sendInput *connect.Client[spec.SendInputRequest, spec.SendInputResponse]
sendSignal *connect.Client[spec.SendSignalRequest, spec.SendSignalResponse]
closeStdin *connect.Client[spec.CloseStdinRequest, spec.CloseStdinResponse]
}
// List calls process.Process.List.
func (c *processClient) List(ctx context.Context, req *connect.Request[spec.ListRequest]) (*connect.Response[spec.ListResponse], error) {
return c.list.CallUnary(ctx, req)
}
// Connect calls process.Process.Connect.
func (c *processClient) Connect(ctx context.Context, req *connect.Request[spec.ConnectRequest]) (*connect.ServerStreamForClient[spec.ConnectResponse], error) {
return c.connect.CallServerStream(ctx, req)
}
// Start calls process.Process.Start.
func (c *processClient) Start(ctx context.Context, req *connect.Request[spec.StartRequest]) (*connect.ServerStreamForClient[spec.StartResponse], error) {
return c.start.CallServerStream(ctx, req)
}
// Update calls process.Process.Update.
func (c *processClient) Update(ctx context.Context, req *connect.Request[spec.UpdateRequest]) (*connect.Response[spec.UpdateResponse], error) {
return c.update.CallUnary(ctx, req)
}
// StreamInput calls process.Process.StreamInput.
func (c *processClient) StreamInput(ctx context.Context) *connect.ClientStreamForClient[spec.StreamInputRequest, spec.StreamInputResponse] {
return c.streamInput.CallClientStream(ctx)
}
// SendInput calls process.Process.SendInput.
func (c *processClient) SendInput(ctx context.Context, req *connect.Request[spec.SendInputRequest]) (*connect.Response[spec.SendInputResponse], error) {
return c.sendInput.CallUnary(ctx, req)
}
// SendSignal calls process.Process.SendSignal.
func (c *processClient) SendSignal(ctx context.Context, req *connect.Request[spec.SendSignalRequest]) (*connect.Response[spec.SendSignalResponse], error) {
return c.sendSignal.CallUnary(ctx, req)
}
// CloseStdin calls process.Process.CloseStdin.
func (c *processClient) CloseStdin(ctx context.Context, req *connect.Request[spec.CloseStdinRequest]) (*connect.Response[spec.CloseStdinResponse], error) {
return c.closeStdin.CallUnary(ctx, req)
}
// ProcessHandler is an implementation of the process.Process service.
type ProcessHandler interface {
List(context.Context, *connect.Request[spec.ListRequest]) (*connect.Response[spec.ListResponse], error)
Connect(context.Context, *connect.Request[spec.ConnectRequest], *connect.ServerStream[spec.ConnectResponse]) error
Start(context.Context, *connect.Request[spec.StartRequest], *connect.ServerStream[spec.StartResponse]) error
Update(context.Context, *connect.Request[spec.UpdateRequest]) (*connect.Response[spec.UpdateResponse], error)
// Client input stream ensures ordering of messages
StreamInput(context.Context, *connect.ClientStream[spec.StreamInputRequest]) (*connect.Response[spec.StreamInputResponse], error)
SendInput(context.Context, *connect.Request[spec.SendInputRequest]) (*connect.Response[spec.SendInputResponse], error)
SendSignal(context.Context, *connect.Request[spec.SendSignalRequest]) (*connect.Response[spec.SendSignalResponse], error)
// Close stdin to signal EOF to the process.
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
CloseStdin(context.Context, *connect.Request[spec.CloseStdinRequest]) (*connect.Response[spec.CloseStdinResponse], error)
}
// NewProcessHandler builds an HTTP handler from the service implementation. It returns the path on
// which to mount the handler and the handler itself.
//
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
// and JSON codecs. They also support gzip compression.
func NewProcessHandler(svc ProcessHandler, opts ...connect.HandlerOption) (string, http.Handler) {
processMethods := spec.File_process_proto.Services().ByName("Process").Methods()
processListHandler := connect.NewUnaryHandler(
ProcessListProcedure,
svc.List,
connect.WithSchema(processMethods.ByName("List")),
connect.WithHandlerOptions(opts...),
)
processConnectHandler := connect.NewServerStreamHandler(
ProcessConnectProcedure,
svc.Connect,
connect.WithSchema(processMethods.ByName("Connect")),
connect.WithHandlerOptions(opts...),
)
processStartHandler := connect.NewServerStreamHandler(
ProcessStartProcedure,
svc.Start,
connect.WithSchema(processMethods.ByName("Start")),
connect.WithHandlerOptions(opts...),
)
processUpdateHandler := connect.NewUnaryHandler(
ProcessUpdateProcedure,
svc.Update,
connect.WithSchema(processMethods.ByName("Update")),
connect.WithHandlerOptions(opts...),
)
processStreamInputHandler := connect.NewClientStreamHandler(
ProcessStreamInputProcedure,
svc.StreamInput,
connect.WithSchema(processMethods.ByName("StreamInput")),
connect.WithHandlerOptions(opts...),
)
processSendInputHandler := connect.NewUnaryHandler(
ProcessSendInputProcedure,
svc.SendInput,
connect.WithSchema(processMethods.ByName("SendInput")),
connect.WithHandlerOptions(opts...),
)
processSendSignalHandler := connect.NewUnaryHandler(
ProcessSendSignalProcedure,
svc.SendSignal,
connect.WithSchema(processMethods.ByName("SendSignal")),
connect.WithHandlerOptions(opts...),
)
processCloseStdinHandler := connect.NewUnaryHandler(
ProcessCloseStdinProcedure,
svc.CloseStdin,
connect.WithSchema(processMethods.ByName("CloseStdin")),
connect.WithHandlerOptions(opts...),
)
return "/process.Process/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case ProcessListProcedure:
processListHandler.ServeHTTP(w, r)
case ProcessConnectProcedure:
processConnectHandler.ServeHTTP(w, r)
case ProcessStartProcedure:
processStartHandler.ServeHTTP(w, r)
case ProcessUpdateProcedure:
processUpdateHandler.ServeHTTP(w, r)
case ProcessStreamInputProcedure:
processStreamInputHandler.ServeHTTP(w, r)
case ProcessSendInputProcedure:
processSendInputHandler.ServeHTTP(w, r)
case ProcessSendSignalProcedure:
processSendSignalHandler.ServeHTTP(w, r)
case ProcessCloseStdinProcedure:
processCloseStdinHandler.ServeHTTP(w, r)
default:
http.NotFound(w, r)
}
})
}
// UnimplementedProcessHandler returns CodeUnimplemented from all methods.
type UnimplementedProcessHandler struct{}
func (UnimplementedProcessHandler) List(context.Context, *connect.Request[spec.ListRequest]) (*connect.Response[spec.ListResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.List is not implemented"))
}
func (UnimplementedProcessHandler) Connect(context.Context, *connect.Request[spec.ConnectRequest], *connect.ServerStream[spec.ConnectResponse]) error {
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Connect is not implemented"))
}
func (UnimplementedProcessHandler) Start(context.Context, *connect.Request[spec.StartRequest], *connect.ServerStream[spec.StartResponse]) error {
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Start is not implemented"))
}
func (UnimplementedProcessHandler) Update(context.Context, *connect.Request[spec.UpdateRequest]) (*connect.Response[spec.UpdateResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Update is not implemented"))
}
func (UnimplementedProcessHandler) StreamInput(context.Context, *connect.ClientStream[spec.StreamInputRequest]) (*connect.Response[spec.StreamInputResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.StreamInput is not implemented"))
}
func (UnimplementedProcessHandler) SendInput(context.Context, *connect.Request[spec.SendInputRequest]) (*connect.Response[spec.SendInputResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendInput is not implemented"))
}
func (UnimplementedProcessHandler) SendSignal(context.Context, *connect.Request[spec.SendSignalRequest]) (*connect.Response[spec.SendSignalResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendSignal is not implemented"))
}
func (UnimplementedProcessHandler) CloseStdin(context.Context, *connect.Request[spec.CloseStdinRequest]) (*connect.Response[spec.CloseStdinResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.CloseStdin is not implemented"))
}

View File

@ -0,0 +1,110 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"os"
"path/filepath"
"syscall"
"time"
)
func GetEntryFromPath(path string) (EntryInfo, error) {
fileInfo, err := os.Lstat(path)
if err != nil {
return EntryInfo{}, err
}
return GetEntryInfo(path, fileInfo), nil
}
func GetEntryInfo(path string, fileInfo os.FileInfo) EntryInfo {
fileMode := fileInfo.Mode()
var symlinkTarget *string
if fileMode&os.ModeSymlink != 0 {
// If we can't resolve the symlink target, we won't set the target
target := followSymlink(path)
symlinkTarget = &target
}
var entryType FileType
var mode os.FileMode
if symlinkTarget == nil {
entryType = getEntryType(fileMode)
mode = fileMode.Perm()
} else {
// If it's a symlink, we need to determine the type of the target
targetInfo, err := os.Stat(*symlinkTarget)
if err != nil {
entryType = UnknownFileType
} else {
entryType = getEntryType(targetInfo.Mode())
mode = targetInfo.Mode().Perm()
}
}
entry := EntryInfo{
Name: fileInfo.Name(),
Path: path,
Type: entryType,
Size: fileInfo.Size(),
Mode: mode,
Permissions: fileMode.String(),
ModifiedTime: fileInfo.ModTime(),
SymlinkTarget: symlinkTarget,
}
if base := getBase(fileInfo.Sys()); base != nil {
entry.AccessedTime = toTimestamp(base.Atim)
entry.CreatedTime = toTimestamp(base.Ctim)
entry.ModifiedTime = toTimestamp(base.Mtim)
entry.UID = base.Uid
entry.GID = base.Gid
} else if !fileInfo.ModTime().IsZero() {
entry.ModifiedTime = fileInfo.ModTime()
}
return entry
}
// getEntryType determines the type of file entry based on its mode and path.
// If the file is a symlink, it follows the symlink to determine the actual type.
func getEntryType(mode os.FileMode) FileType {
switch {
case mode.IsRegular():
return FileFileType
case mode.IsDir():
return DirectoryFileType
case mode&os.ModeSymlink == os.ModeSymlink:
return SymlinkFileType
default:
return UnknownFileType
}
}
// followSymlink resolves a symbolic link to its target path.
func followSymlink(path string) string {
// Resolve symlinks
resolvedPath, err := filepath.EvalSymlinks(path)
if err != nil {
return path
}
return resolvedPath
}
func toTimestamp(spec syscall.Timespec) time.Time {
if spec.Sec == 0 && spec.Nsec == 0 {
return time.Time{}
}
return time.Unix(spec.Sec, spec.Nsec)
}
func getBase(sys any) *syscall.Stat_t {
st, _ := sys.(*syscall.Stat_t)
return st
}

View File

@ -0,0 +1,266 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"os"
"os/user"
"path/filepath"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetEntryType(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
// Create test files
regularFile := filepath.Join(tempDir, "regular.txt")
require.NoError(t, os.WriteFile(regularFile, []byte("test content"), 0o644))
testDir := filepath.Join(tempDir, "testdir")
require.NoError(t, os.MkdirAll(testDir, 0o755))
symlink := filepath.Join(tempDir, "symlink")
require.NoError(t, os.Symlink(regularFile, symlink))
tests := []struct {
name string
path string
expected FileType
}{
{
name: "regular file",
path: regularFile,
expected: FileFileType,
},
{
name: "directory",
path: testDir,
expected: DirectoryFileType,
},
{
name: "symlink to file",
path: symlink,
expected: SymlinkFileType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
info, err := os.Lstat(tt.path)
require.NoError(t, err)
result := getEntryType(info.Mode())
assert.Equal(t, tt.expected, result)
})
}
}
func TestEntryInfoFromFileInfo_SymlinkChain(t *testing.T) {
t.Parallel()
// Base temporary directory. On macOS this lives under /var/folders/…
// which itself is a symlink to /private/var/folders/….
tempDir := t.TempDir()
// Create final target
target := filepath.Join(tempDir, "target")
require.NoError(t, os.MkdirAll(target, 0o755))
// Create a chain: link1 → link2 → target
link2 := filepath.Join(tempDir, "link2")
require.NoError(t, os.Symlink(target, link2))
link1 := filepath.Join(tempDir, "link1")
require.NoError(t, os.Symlink(link2, link1))
// run the test
result, err := GetEntryFromPath(link1)
require.NoError(t, err)
// verify the results
assert.Equal(t, "link1", result.Name)
assert.Equal(t, link1, result.Path)
assert.Equal(t, DirectoryFileType, result.Type) // Should resolve to final target type
assert.Contains(t, result.Permissions, "L")
// Canonicalize the expected target path to handle macOS symlink indirections
expectedTarget, err := filepath.EvalSymlinks(link1)
require.NoError(t, err)
assert.Equal(t, expectedTarget, *result.SymlinkTarget)
}
func TestEntryInfoFromFileInfo_DifferentPermissions(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
testCases := []struct {
name string
permissions os.FileMode
expectedMode os.FileMode
expectedString string
}{
{"read-only", 0o444, 0o444, "-r--r--r--"},
{"executable", 0o755, 0o755, "-rwxr-xr-x"},
{"write-only", 0o200, 0o200, "--w-------"},
{"no permissions", 0o000, 0o000, "----------"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testFile := filepath.Join(tempDir, tc.name+".txt")
require.NoError(t, os.WriteFile(testFile, []byte("test"), tc.permissions))
result, err := GetEntryFromPath(testFile)
require.NoError(t, err)
assert.Equal(t, tc.expectedMode, result.Mode)
assert.Equal(t, tc.expectedString, result.Permissions)
})
}
}
func TestEntryInfoFromFileInfo_EmptyFile(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
emptyFile := filepath.Join(tempDir, "empty.txt")
require.NoError(t, os.WriteFile(emptyFile, []byte{}, 0o600))
result, err := GetEntryFromPath(emptyFile)
require.NoError(t, err)
assert.Equal(t, "empty.txt", result.Name)
assert.Equal(t, int64(0), result.Size)
assert.Equal(t, os.FileMode(0o600), result.Mode)
assert.Equal(t, FileFileType, result.Type)
}
func TestEntryInfoFromFileInfo_CyclicSymlink(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
// Create cyclic symlink
cyclicSymlink := filepath.Join(tempDir, "cyclic")
require.NoError(t, os.Symlink(cyclicSymlink, cyclicSymlink))
result, err := GetEntryFromPath(cyclicSymlink)
require.NoError(t, err)
assert.Equal(t, "cyclic", result.Name)
assert.Equal(t, cyclicSymlink, result.Path)
assert.Equal(t, UnknownFileType, result.Type)
assert.Contains(t, result.Permissions, "L")
}
func TestEntryInfoFromFileInfo_BrokenSymlink(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
// Create broken symlink
brokenSymlink := filepath.Join(tempDir, "broken")
require.NoError(t, os.Symlink("/nonexistent", brokenSymlink))
result, err := GetEntryFromPath(brokenSymlink)
require.NoError(t, err)
assert.Equal(t, "broken", result.Name)
assert.Equal(t, brokenSymlink, result.Path)
assert.Equal(t, UnknownFileType, result.Type)
assert.Contains(t, result.Permissions, "L")
// SymlinkTarget might be empty if followSymlink fails
}
func TestEntryInfoFromFileInfo(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
// Create a regular file with known content and permissions
testFile := filepath.Join(tempDir, "test.txt")
testContent := []byte("Hello, World!")
require.NoError(t, os.WriteFile(testFile, testContent, 0o644))
// Get current user for ownership comparison
currentUser, err := user.Current()
require.NoError(t, err)
result, err := GetEntryFromPath(testFile)
require.NoError(t, err)
// Basic assertions
assert.Equal(t, "test.txt", result.Name)
assert.Equal(t, testFile, result.Path)
assert.Equal(t, int64(len(testContent)), result.Size)
assert.Equal(t, FileFileType, result.Type)
assert.Equal(t, os.FileMode(0o644), result.Mode)
assert.Contains(t, result.Permissions, "-rw-r--r--")
assert.Equal(t, currentUser.Uid, strconv.Itoa(int(result.UID)))
assert.Equal(t, currentUser.Gid, strconv.Itoa(int(result.GID)))
assert.NotNil(t, result.ModifiedTime)
assert.Empty(t, result.SymlinkTarget)
// Check that modified time is reasonable (within last minute)
modTime := result.ModifiedTime
assert.WithinDuration(t, time.Now(), modTime, time.Minute)
}
func TestEntryInfoFromFileInfo_Directory(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
testDir := filepath.Join(tempDir, "testdir")
require.NoError(t, os.MkdirAll(testDir, 0o755))
result, err := GetEntryFromPath(testDir)
require.NoError(t, err)
assert.Equal(t, "testdir", result.Name)
assert.Equal(t, testDir, result.Path)
assert.Equal(t, DirectoryFileType, result.Type)
assert.Equal(t, os.FileMode(0o755), result.Mode)
assert.Equal(t, "drwxr-xr-x", result.Permissions)
assert.Empty(t, result.SymlinkTarget)
}
func TestEntryInfoFromFileInfo_Symlink(t *testing.T) {
t.Parallel()
// Base temporary directory. On macOS this lives under /var/folders/…
// which itself is a symlink to /private/var/folders/….
tempDir := t.TempDir()
// Create target file
targetFile := filepath.Join(tempDir, "target.txt")
require.NoError(t, os.WriteFile(targetFile, []byte("target content"), 0o644))
// Create symlink
symlinkPath := filepath.Join(tempDir, "symlink")
require.NoError(t, os.Symlink(targetFile, symlinkPath))
// Use Lstat to get symlink info (not the target)
result, err := GetEntryFromPath(symlinkPath)
require.NoError(t, err)
assert.Equal(t, "symlink", result.Name)
assert.Equal(t, symlinkPath, result.Path)
assert.Equal(t, FileFileType, result.Type) // Should resolve to target type
assert.Contains(t, result.Permissions, "L") // Should show as symlink in permissions
// Canonicalize the expected target path to handle macOS /var → /private/var symlink
expectedTarget, err := filepath.EvalSymlinks(symlinkPath)
require.NoError(t, err)
assert.Equal(t, expectedTarget, *result.SymlinkTarget)
}

View File

@ -0,0 +1,32 @@
// SPDX-License-Identifier: Apache-2.0
package filesystem
import (
"os"
"time"
)
type EntryInfo struct {
Name string
Type FileType
Path string
Size int64
Mode os.FileMode
Permissions string
UID uint32
GID uint32
AccessedTime time.Time
CreatedTime time.Time
ModifiedTime time.Time
SymlinkTarget *string
}
type FileType int32
const (
UnknownFileType FileType = 0
FileFileType FileType = 1
DirectoryFileType FileType = 2
SymlinkFileType FileType = 3
)

View File

@ -0,0 +1,166 @@
// SPDX-License-Identifier: Apache-2.0
package id
import (
"errors"
"fmt"
"maps"
"regexp"
"slices"
"strings"
"github.com/dchest/uniuri"
"github.com/google/uuid"
)
var (
caseInsensitiveAlphabet = []byte("abcdefghijklmnopqrstuvwxyz1234567890")
identifierRegex = regexp.MustCompile(`^[a-z0-9-_]+$`)
tagRegex = regexp.MustCompile(`^[a-z0-9-_.]+$`)
sandboxIDRegex = regexp.MustCompile(`^[a-z0-9]+$`)
)
const (
DefaultTag = "default"
TagSeparator = ":"
NamespaceSeparator = "/"
)
func Generate() string {
return uniuri.NewLenChars(uniuri.UUIDLen, caseInsensitiveAlphabet)
}
// ValidateSandboxID checks that a sandbox ID contains only lowercase alphanumeric characters.
func ValidateSandboxID(sandboxID string) error {
if !sandboxIDRegex.MatchString(sandboxID) {
return fmt.Errorf("invalid sandbox ID: %q", sandboxID)
}
return nil
}
func cleanAndValidate(value, name string, re *regexp.Regexp) (string, error) {
cleaned := strings.ToLower(strings.TrimSpace(value))
if !re.MatchString(cleaned) {
return "", fmt.Errorf("invalid %s: %s", name, value)
}
return cleaned, nil
}
func validateTag(tag string) (string, error) {
cleanedTag, err := cleanAndValidate(tag, "tag", tagRegex)
if err != nil {
return "", err
}
// Prevent tags from being a UUID
_, err = uuid.Parse(cleanedTag)
if err == nil {
return "", errors.New("tag cannot be a UUID")
}
return cleanedTag, nil
}
func ValidateAndDeduplicateTags(tags []string) ([]string, error) {
seen := make(map[string]struct{})
for _, tag := range tags {
cleanedTag, err := validateTag(tag)
if err != nil {
return nil, fmt.Errorf("invalid tag '%s': %w", tag, err)
}
seen[cleanedTag] = struct{}{}
}
return slices.Collect(maps.Keys(seen)), nil
}
// SplitIdentifier splits "namespace/alias" into its parts.
// Returns nil namespace for bare aliases, pointer for explicit namespace.
func SplitIdentifier(identifier string) (namespace *string, alias string) {
before, after, found := strings.Cut(identifier, NamespaceSeparator)
if !found {
return nil, before
}
return &before, after
}
// ParseName parses and validates "namespace/alias:tag" or "alias:tag".
// Returns the cleaned identifier (namespace/alias or alias) and optional tag.
// All components are validated and normalized (lowercase, trimmed).
func ParseName(input string) (identifier string, tag *string, err error) {
input = strings.TrimSpace(input)
// Extract raw parts
identifierPart, tagPart, hasTag := strings.Cut(input, TagSeparator)
namespacePart, aliasPart := SplitIdentifier(identifierPart)
// Validate tag
if hasTag {
validated, err := cleanAndValidate(tagPart, "tag", tagRegex)
if err != nil {
return "", nil, err
}
if !strings.EqualFold(validated, DefaultTag) {
tag = &validated
}
}
// Validate namespace
if namespacePart != nil {
validated, err := cleanAndValidate(*namespacePart, "namespace", identifierRegex)
if err != nil {
return "", nil, err
}
namespacePart = &validated
}
// Validate alias
aliasPart, err = cleanAndValidate(aliasPart, "template ID", identifierRegex)
if err != nil {
return "", nil, err
}
// Build identifier
if namespacePart != nil {
identifier = WithNamespace(*namespacePart, aliasPart)
} else {
identifier = aliasPart
}
return identifier, tag, nil
}
// WithTag returns the identifier with the given tag appended (e.g. "templateID:tag").
func WithTag(identifier, tag string) string {
return identifier + TagSeparator + tag
}
// WithNamespace returns identifier with the given namespace prefix.
func WithNamespace(namespace, alias string) string {
return namespace + NamespaceSeparator + alias
}
// ExtractAlias returns just the alias portion from an identifier (namespace/alias or alias).
func ExtractAlias(identifier string) string {
_, alias := SplitIdentifier(identifier)
return alias
}
// ValidateNamespaceMatchesTeam checks if an explicit namespace in the identifier matches the team's slug.
// Returns an error if the namespace doesn't match.
// If the identifier has no explicit namespace, returns nil (valid).
func ValidateNamespaceMatchesTeam(identifier, teamSlug string) error {
namespace, _ := SplitIdentifier(identifier)
if namespace != nil && *namespace != teamSlug {
return fmt.Errorf("namespace '%s' must match your team '%s'", *namespace, teamSlug)
}
return nil
}

View File

@ -0,0 +1,382 @@
// SPDX-License-Identifier: Apache-2.0
package id
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
)
func TestParseName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantIdentifier string
wantTag *string
wantErr bool
}{
{
name: "bare alias only",
input: "my-template",
wantIdentifier: "my-template",
wantTag: nil,
},
{
name: "alias with tag",
input: "my-template:v1",
wantIdentifier: "my-template",
wantTag: utils.ToPtr("v1"),
},
{
name: "namespace and alias",
input: "acme/my-template",
wantIdentifier: "acme/my-template",
wantTag: nil,
},
{
name: "namespace, alias and tag",
input: "acme/my-template:v1",
wantIdentifier: "acme/my-template",
wantTag: utils.ToPtr("v1"),
},
{
name: "namespace with hyphens",
input: "my-team/my-template:prod",
wantIdentifier: "my-team/my-template",
wantTag: utils.ToPtr("prod"),
},
{
name: "default tag normalized to nil",
input: "my-template:default",
wantIdentifier: "my-template",
wantTag: nil,
},
{
name: "uppercase converted to lowercase",
input: "MyTemplate:Prod",
wantIdentifier: "mytemplate",
wantTag: utils.ToPtr("prod"),
},
{
name: "whitespace trimmed",
input: " my-template : v1 ",
wantIdentifier: "my-template",
wantTag: utils.ToPtr("v1"),
},
{
name: "invalid - empty namespace",
input: "/my-template",
wantErr: true,
},
{
name: "invalid - empty tag after colon",
input: "my-template:",
wantErr: true,
},
{
name: "invalid - special characters in alias",
input: "my template!",
wantErr: true,
},
{
name: "invalid - special characters in namespace",
input: "my team!/my-template",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotIdentifier, gotTag, err := ParseName(tt.input)
if tt.wantErr {
require.Error(t, err, "Expected ParseName() to return error, got")
return
}
require.NoError(t, err, "Expected ParseName() not to return error, got: %v", err)
assert.Equal(t, tt.wantIdentifier, gotIdentifier, "ParseName() identifier = %v, want %v", gotIdentifier, tt.wantIdentifier)
assert.Equal(t, tt.wantTag, gotTag, "ParseName() tag = %v, want %v", utils.Sprintp(gotTag), utils.Sprintp(tt.wantTag))
})
}
}
func TestWithNamespace(t *testing.T) {
t.Parallel()
got := WithNamespace("acme", "my-template")
want := "acme/my-template"
assert.Equal(t, want, got, "WithNamespace() = %q, want %q", got, want)
}
func TestSplitIdentifier(t *testing.T) {
t.Parallel()
tests := []struct {
name string
identifier string
wantNamespace *string
wantAlias string
}{
{
name: "bare alias",
identifier: "my-template",
wantNamespace: nil,
wantAlias: "my-template",
},
{
name: "with namespace",
identifier: "acme/my-template",
wantNamespace: ptrStr("acme"),
wantAlias: "my-template",
},
{
name: "empty namespace prefix",
identifier: "/my-template",
wantNamespace: ptrStr(""),
wantAlias: "my-template",
},
{
name: "multiple slashes - only first split",
identifier: "a/b/c",
wantNamespace: ptrStr("a"),
wantAlias: "b/c",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotNamespace, gotAlias := SplitIdentifier(tt.identifier)
if tt.wantNamespace == nil {
assert.Nil(t, gotNamespace)
} else {
require.NotNil(t, gotNamespace)
assert.Equal(t, *tt.wantNamespace, *gotNamespace)
}
assert.Equal(t, tt.wantAlias, gotAlias)
})
}
}
func ptrStr(s string) *string {
return &s
}
func TestValidateAndDeduplicateTags(t *testing.T) {
t.Parallel()
tests := []struct {
name string
tags []string
want []string
wantErr bool
}{
{
name: "single valid tag",
tags: []string{"v1"},
want: []string{"v1"},
wantErr: false,
},
{
name: "multiple unique tags",
tags: []string{"v1", "prod", "latest"},
want: []string{"v1", "prod", "latest"},
wantErr: false,
},
{
name: "duplicate tags deduplicated",
tags: []string{"v1", "V1", "v1"},
want: []string{"v1"},
wantErr: false,
},
{
name: "tags with dots and underscores",
tags: []string{"v1.0", "v1_1"},
want: []string{"v1.0", "v1_1"},
wantErr: false,
},
{
name: "invalid - UUID tag rejected",
tags: []string{"550e8400-e29b-41d4-a716-446655440000"},
wantErr: true,
},
{
name: "invalid - special characters",
tags: []string{"v1!", "v2@"},
wantErr: true,
},
{
name: "empty list returns empty",
tags: []string{},
want: []string{},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := ValidateAndDeduplicateTags(tt.tags)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.ElementsMatch(t, tt.want, got)
})
}
}
func TestValidateSandboxID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
}{
{
name: "canonical sandbox ID",
input: "i1a2b3c4d5e6f7g8h9j0k",
wantErr: false,
},
{
name: "short alphanumeric",
input: "abc123",
wantErr: false,
},
{
name: "all digits",
input: "1234567890",
wantErr: false,
},
{
name: "all lowercase letters",
input: "abcdefghijklmnopqrst",
wantErr: false,
},
{
name: "invalid - empty",
input: "",
wantErr: true,
},
{
name: "invalid - contains colon (Redis separator)",
input: "abc:def",
wantErr: true,
},
{
name: "invalid - contains open brace (Redis hash slot)",
input: "abc{def",
wantErr: true,
},
{
name: "invalid - contains close brace (Redis hash slot)",
input: "abc}def",
wantErr: true,
},
{
name: "invalid - contains newline",
input: "abc\ndef",
wantErr: true,
},
{
name: "invalid - contains space",
input: "abc def",
wantErr: true,
},
{
name: "invalid - contains hyphen",
input: "abc-def",
wantErr: true,
},
{
name: "invalid - contains uppercase",
input: "abcDEF",
wantErr: true,
},
{
name: "invalid - contains slash",
input: "abc/def",
wantErr: true,
},
{
name: "invalid - contains null byte",
input: "abc\x00def",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := ValidateSandboxID(tt.input)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
func TestValidateNamespaceMatchesTeam(t *testing.T) {
t.Parallel()
tests := []struct {
name string
identifier string
teamSlug string
wantErr bool
}{
{
name: "bare alias - no namespace",
identifier: "my-template",
teamSlug: "acme",
wantErr: false,
},
{
name: "matching namespace",
identifier: "acme/my-template",
teamSlug: "acme",
wantErr: false,
},
{
name: "mismatched namespace",
identifier: "other-team/my-template",
teamSlug: "acme",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := ValidateNamespaceMatchesTeam(tt.identifier, tt.teamSlug)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

View File

@ -0,0 +1,9 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package keys
const (
ApiKeyPrefix = "wrn_"
AccessTokenPrefix = "sk_wrn_"
)

View File

@ -0,0 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
package keys
type Hasher interface {
Hash(key []byte) string
}

View File

@ -0,0 +1,27 @@
// SPDX-License-Identifier: Apache-2.0
package keys
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
)
type HMACSha256Hashing struct {
key []byte
}
func NewHMACSHA256Hashing(key []byte) *HMACSha256Hashing {
return &HMACSha256Hashing{key: key}
}
func (h *HMACSha256Hashing) Hash(content []byte) (string, error) {
mac := hmac.New(sha256.New, h.key)
_, err := mac.Write(content)
if err != nil {
return "", err
}
return hex.EncodeToString(mac.Sum(nil)), nil
}

View File

@ -0,0 +1,76 @@
// SPDX-License-Identifier: Apache-2.0
package keys
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"testing"
"github.com/stretchr/testify/require"
)
func TestHMACSha256Hashing_ValidHash(t *testing.T) {
t.Parallel()
key := []byte("test-key")
hasher := NewHMACSHA256Hashing(key)
content := []byte("hello world")
expectedHash := "18c4b268f0bbf8471eda56af3e70b1d4613d734dc538b4940b59931c412a1591"
actualHash, err := hasher.Hash(content)
require.NoError(t, err)
if actualHash != expectedHash {
t.Errorf("expected %s, got %s", expectedHash, actualHash)
}
}
func TestHMACSha256Hashing_EmptyContent(t *testing.T) {
t.Parallel()
key := []byte("test-key")
hasher := NewHMACSHA256Hashing(key)
content := []byte("")
expectedHash := "2711cc23e9ab1b8a9bc0fe991238da92671624a9ebdaf1c1abec06e7e9a14f9b"
actualHash, err := hasher.Hash(content)
require.NoError(t, err)
if actualHash != expectedHash {
t.Errorf("expected %s, got %s", expectedHash, actualHash)
}
}
func TestHMACSha256Hashing_DifferentKey(t *testing.T) {
t.Parallel()
key := []byte("test-key")
hasher := NewHMACSHA256Hashing(key)
differentKeyHasher := NewHMACSHA256Hashing([]byte("different-key"))
content := []byte("hello world")
hashWithOriginalKey, err := hasher.Hash(content)
require.NoError(t, err)
hashWithDifferentKey, err := differentKeyHasher.Hash(content)
require.NoError(t, err)
if hashWithOriginalKey == hashWithDifferentKey {
t.Errorf("hashes with different keys should not match")
}
}
func TestHMACSha256Hashing_IdenticalResult(t *testing.T) {
t.Parallel()
key := []byte("placeholder-hashing-key")
content := []byte("test content for hashing")
mac := hmac.New(sha256.New, key)
mac.Write(content)
expectedResult := hex.EncodeToString(mac.Sum(nil))
hasher := NewHMACSHA256Hashing(key)
actualResult, err := hasher.Hash(content)
require.NoError(t, err)
if actualResult != expectedResult {
t.Errorf("expected %s, got %s", expectedResult, actualResult)
}
}

View File

@ -0,0 +1,101 @@
// SPDX-License-Identifier: Apache-2.0
package keys
import (
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
)
const (
identifierValueSuffixLength = 4
identifierValuePrefixLength = 2
keyLength = 20
)
var hasher Hasher = NewSHA256Hashing()
type Key struct {
PrefixedRawValue string
HashedValue string
Masked MaskedIdentifier
}
type MaskedIdentifier struct {
Prefix string
ValueLength int
MaskedValuePrefix string
MaskedValueSuffix string
}
// MaskKey returns identifier masking properties in accordance to the OpenAPI response spec
func MaskKey(prefix, value string) (MaskedIdentifier, error) {
valueLength := len(value)
suffixOffset := valueLength - identifierValueSuffixLength
prefixOffset := identifierValuePrefixLength
if suffixOffset < 0 {
return MaskedIdentifier{}, fmt.Errorf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength)
}
if suffixOffset == 0 {
return MaskedIdentifier{}, fmt.Errorf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength)
}
// cap prefixOffset by suffixOffset to prevent overlap with the suffix.
if prefixOffset > suffixOffset {
prefixOffset = suffixOffset
}
maskPrefix := value[:prefixOffset]
maskSuffix := value[suffixOffset:]
maskedIdentifierProperties := MaskedIdentifier{
Prefix: prefix,
ValueLength: valueLength,
MaskedValuePrefix: maskPrefix,
MaskedValueSuffix: maskSuffix,
}
return maskedIdentifierProperties, nil
}
func GenerateKey(prefix string) (Key, error) {
keyBytes := make([]byte, keyLength)
_, err := rand.Read(keyBytes)
if err != nil {
return Key{}, err
}
generatedIdentifier := hex.EncodeToString(keyBytes)
mask, err := MaskKey(prefix, generatedIdentifier)
if err != nil {
return Key{}, err
}
return Key{
PrefixedRawValue: prefix + generatedIdentifier,
HashedValue: hasher.Hash(keyBytes),
Masked: mask,
}, nil
}
func VerifyKey(prefix string, key string) (string, error) {
if !strings.HasPrefix(key, prefix) {
return "", fmt.Errorf("invalid key prefix")
}
keyValue := key[len(prefix):]
keyBytes, err := hex.DecodeString(keyValue)
if err != nil {
return "", fmt.Errorf("invalid key")
}
return hasher.Hash(keyBytes), nil
}

View File

@ -0,0 +1,162 @@
// SPDX-License-Identifier: Apache-2.0
package keys
import (
"fmt"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMaskKey(t *testing.T) {
t.Parallel()
t.Run("succeeds: value longer than suffix length", func(t *testing.T) {
t.Parallel()
masked, err := MaskKey("test_", "1234567890")
require.NoError(t, err)
assert.Equal(t, "test_", masked.Prefix)
assert.Equal(t, "12", masked.MaskedValuePrefix)
assert.Equal(t, "7890", masked.MaskedValueSuffix)
})
t.Run("succeeds: empty prefix, value longer than suffix length", func(t *testing.T) {
t.Parallel()
masked, err := MaskKey("", "1234567890")
require.NoError(t, err)
assert.Empty(t, masked.Prefix)
assert.Equal(t, "12", masked.MaskedValuePrefix)
assert.Equal(t, "7890", masked.MaskedValueSuffix)
})
t.Run("error: value length less than suffix length", func(t *testing.T) {
t.Parallel()
_, err := MaskKey("test", "123")
require.Error(t, err)
assert.EqualError(t, err, fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength))
})
t.Run("error: value length equals suffix length", func(t *testing.T) {
t.Parallel()
_, err := MaskKey("test", "1234")
require.Error(t, err)
assert.EqualError(t, err, fmt.Sprintf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength))
})
}
func TestGenerateKey(t *testing.T) {
t.Parallel()
keyLength := 40
t.Run("succeeds", func(t *testing.T) {
t.Parallel()
key, err := GenerateKey("test_")
require.NoError(t, err)
assert.Regexp(t, "^test_.*", key.PrefixedRawValue)
assert.Equal(t, "test_", key.Masked.Prefix)
assert.Equal(t, keyLength, key.Masked.ValueLength)
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValuePrefixLength)+"}$", key.Masked.MaskedValuePrefix)
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValueSuffixLength)+"}$", key.Masked.MaskedValueSuffix)
assert.Regexp(t, "^\\$sha256\\$.*", key.HashedValue)
})
t.Run("no prefix", func(t *testing.T) {
t.Parallel()
key, err := GenerateKey("")
require.NoError(t, err)
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(keyLength)+"}$", key.PrefixedRawValue)
assert.Empty(t, key.Masked.Prefix)
assert.Equal(t, keyLength, key.Masked.ValueLength)
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValuePrefixLength)+"}$", key.Masked.MaskedValuePrefix)
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValueSuffixLength)+"}$", key.Masked.MaskedValueSuffix)
assert.Regexp(t, "^\\$sha256\\$.*", key.HashedValue)
})
}
func TestGetMaskedIdentifierProperties(t *testing.T) {
t.Parallel()
type testCase struct {
name string
prefix string
value string
expectedResult MaskedIdentifier
expectedErrString string
}
testCases := []testCase{
// --- ERROR CASES (value's length <= identifierValueSuffixLength) ---
{
name: "error: value length < suffix length (3 vs 4)",
prefix: "pk_",
value: "abc",
expectedResult: MaskedIdentifier{},
expectedErrString: fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength),
},
{
name: "error: value length == suffix length (4 vs 4)",
prefix: "sk_",
value: "abcd",
expectedResult: MaskedIdentifier{},
expectedErrString: fmt.Sprintf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength),
},
{
name: "error: value length < suffix length (0 vs 4, empty value)",
prefix: "err_",
value: "",
expectedResult: MaskedIdentifier{},
expectedErrString: fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength),
},
// --- SUCCESS CASES (value's length > identifierValueSuffixLength) ---
{
name: "success: value long (10), prefix val len fully used",
prefix: "pk_",
value: "abcdefghij",
expectedResult: MaskedIdentifier{
Prefix: "pk_",
ValueLength: 10,
MaskedValuePrefix: "ab",
MaskedValueSuffix: "ghij",
},
},
{
name: "success: value medium (5), prefix val len truncated by overlap",
prefix: "",
value: "abcde",
expectedResult: MaskedIdentifier{
Prefix: "",
ValueLength: 5,
MaskedValuePrefix: "a",
MaskedValueSuffix: "bcde",
},
},
{
name: "success: value medium (6), prefix val len fits exactly",
prefix: "pk_",
value: "abcdef",
expectedResult: MaskedIdentifier{
Prefix: "pk_",
ValueLength: 6,
MaskedValuePrefix: "ab",
MaskedValueSuffix: "cdef",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result, err := MaskKey(tc.prefix, tc.value)
if tc.expectedErrString != "" {
require.EqualError(t, err, tc.expectedErrString)
assert.Equal(t, tc.expectedResult, result)
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectedResult, result)
}
})
}
}

View File

@ -0,0 +1,32 @@
// SPDX-License-Identifier: Apache-2.0
package keys
import (
"crypto/sha256"
"encoding/base64"
"fmt"
)
type Sha256Hashing struct{}
func NewSHA256Hashing() *Sha256Hashing {
return &Sha256Hashing{}
}
func (h *Sha256Hashing) Hash(key []byte) string {
hashBytes := sha256.Sum256(key)
hash64 := base64.RawStdEncoding.EncodeToString(hashBytes[:])
return fmt.Sprintf(
"$sha256$%s",
hash64,
)
}
func (h *Sha256Hashing) HashWithoutPrefix(key []byte) string {
hashBytes := sha256.Sum256(key)
return base64.RawStdEncoding.EncodeToString(hashBytes[:])
}

View File

@ -0,0 +1,17 @@
// SPDX-License-Identifier: Apache-2.0
package keys
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSHA256Hashing(t *testing.T) {
t.Parallel()
hasher := NewSHA256Hashing()
hashed := hasher.Hash([]byte("test"))
assert.Regexp(t, "^\\$sha256\\$.*", hashed)
}

View File

@ -0,0 +1,22 @@
// SPDX-License-Identifier: Apache-2.0
package keys
import (
"crypto/sha512"
"encoding/hex"
)
// HashAccessToken computes the SHA-512 hash of an access token.
func HashAccessToken(token string) string {
h := sha512.Sum512([]byte(token))
return hex.EncodeToString(h[:])
}
// HashAccessTokenBytes computes the SHA-512 hash of an access token from bytes.
func HashAccessTokenBytes(token []byte) string {
h := sha512.Sum512(token)
return hex.EncodeToString(h[:])
}

View File

@ -0,0 +1,49 @@
// SPDX-License-Identifier: Apache-2.0
package smap
import (
cmap "github.com/orcaman/concurrent-map/v2"
)
type Map[V any] struct {
m cmap.ConcurrentMap[string, V]
}
func New[V any]() *Map[V] {
return &Map[V]{
m: cmap.New[V](),
}
}
func (m *Map[V]) Remove(key string) {
m.m.Remove(key)
}
func (m *Map[V]) Get(key string) (V, bool) {
return m.m.Get(key)
}
func (m *Map[V]) Insert(key string, value V) {
m.m.Set(key, value)
}
func (m *Map[V]) Upsert(key string, value V, cb cmap.UpsertCb[V]) V {
return m.m.Upsert(key, value, cb)
}
func (m *Map[V]) InsertIfAbsent(key string, value V) bool {
return m.m.SetIfAbsent(key, value)
}
func (m *Map[V]) Items() map[string]V {
return m.m.Items()
}
func (m *Map[V]) RemoveCb(key string, cb func(key string, v V, exists bool) bool) bool {
return m.m.RemoveCb(key, cb)
}
func (m *Map[V]) Count() int {
return m.m.Count()
}

View File

@ -0,0 +1,45 @@
// SPDX-License-Identifier: Apache-2.0
package utils
import "fmt"
func ToPtr[T any](v T) *T {
return &v
}
func FromPtr[T any](s *T) T {
if s == nil {
var zero T
return zero
}
return *s
}
func Sprintp[T any](s *T) string {
if s == nil {
return "<nil>"
}
return fmt.Sprintf("%v", *s)
}
func DerefOrDefault[T any](s *T, defaultValue T) T {
if s == nil {
return defaultValue
}
return *s
}
func CastPtr[S any, T any](s *S, castFunc func(S) T) *T {
if s == nil {
return nil
}
t := castFunc(*s)
return &t
}

View File

@ -0,0 +1,29 @@
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"sync"
)
type AtomicMax struct {
val int64
mu sync.Mutex
}
func NewAtomicMax() *AtomicMax {
return &AtomicMax{}
}
func (a *AtomicMax) SetToGreater(newValue int64) bool {
a.mu.Lock()
defer a.mu.Unlock()
if a.val > newValue {
return false
}
a.val = newValue
return true
}

View File

@ -0,0 +1,78 @@
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAtomicMax_NewAtomicMax(t *testing.T) {
t.Parallel()
am := NewAtomicMax()
require.NotNil(t, am)
require.Equal(t, int64(0), am.val)
}
func TestAtomicMax_SetToGreater_InitialValue(t *testing.T) {
t.Parallel()
am := NewAtomicMax()
// Should succeed when newValue > current
assert.True(t, am.SetToGreater(10))
assert.Equal(t, int64(10), am.val)
}
func TestAtomicMax_SetToGreater_EqualValue(t *testing.T) {
t.Parallel()
am := NewAtomicMax()
am.val = 10
// Should succeed when newValue > current
assert.True(t, am.SetToGreater(20))
assert.Equal(t, int64(20), am.val)
}
func TestAtomicMax_SetToGreater_GreaterValue(t *testing.T) {
t.Parallel()
am := NewAtomicMax()
am.val = 10
// Should fail when newValue < current, keeping the max value
assert.False(t, am.SetToGreater(5))
assert.Equal(t, int64(10), am.val)
}
func TestAtomicMax_SetToGreater_NegativeValues(t *testing.T) {
t.Parallel()
am := NewAtomicMax()
am.val = -5
assert.True(t, am.SetToGreater(-2))
assert.Equal(t, int64(-2), am.val)
}
func TestAtomicMax_SetToGreater_Concurrent(t *testing.T) {
t.Parallel()
am := NewAtomicMax()
var wg sync.WaitGroup
// Run 100 goroutines trying to update the value concurrently
numGoroutines := 100
wg.Add(numGoroutines)
for i := range numGoroutines {
go func(val int64) {
defer wg.Done()
am.SetToGreater(val)
}(int64(i))
}
wg.Wait()
// The final value should be 99 (the maximum value)
assert.Equal(t, int64(99), am.val)
}

View File

@ -0,0 +1,53 @@
// SPDX-License-Identifier: Apache-2.0
package utils
import "sync"
type Map[K comparable, V any] struct {
m sync.Map
}
func NewMap[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
m: sync.Map{},
}
}
func (m *Map[K, V]) Delete(key K) {
m.m.Delete(key)
}
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
v, ok := m.m.Load(key)
if !ok {
return value, ok
}
return v.(V), ok
}
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
v, loaded := m.m.LoadAndDelete(key)
if !loaded {
return value, loaded
}
return v.(V), loaded
}
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
a, loaded := m.m.LoadOrStore(key, value)
return a.(V), loaded
}
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
m.m.Range(func(key, value any) bool {
return f(key.(K), value.(V))
})
}
func (m *Map[K, V]) Store(key K, value V) {
m.m.Store(key, value)
}

View File

@ -0,0 +1,45 @@
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"errors"
"mime"
"mime/multipart"
)
// CustomPart is a wrapper around multipart.Part that overloads the FileName method
type CustomPart struct {
*multipart.Part
}
// FileNameWithPath returns the filename parameter of the Part's Content-Disposition header.
// This method borrows from the original FileName method implementation but returns the full
// filename without using `filepath.Base`.
func (p *CustomPart) FileNameWithPath() (string, error) {
dispositionParams, err := p.parseContentDisposition()
if err != nil {
return "", err
}
filename, ok := dispositionParams["filename"]
if !ok {
return "", errors.New("filename not found in Content-Disposition header")
}
return filename, nil
}
func (p *CustomPart) parseContentDisposition() (map[string]string, error) {
v := p.Header.Get("Content-Disposition")
_, dispositionParams, err := mime.ParseMediaType(v)
if err != nil {
return nil, err
}
return dispositionParams, nil
}
// NewCustomPart creates a new CustomPart from a multipart.Part
func NewCustomPart(part *multipart.Part) *CustomPart {
return &CustomPart{Part: part}
}

View File

@ -0,0 +1,14 @@
// SPDX-License-Identifier: Apache-2.0
package utils
import "path/filepath"
// FsnotifyPath creates an optionally recursive path for fsnotify/fsnotify internal implementation
func FsnotifyPath(path string, recursive bool) string {
if recursive {
return filepath.Join(path, "...")
}
return path
}

View File

@ -0,0 +1,293 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package main
import (
"context"
"flag"
"fmt"
"log"
"net/http"
"os"
"path/filepath"
"strconv"
"time"
"connectrpc.com/authn"
connectcors "connectrpc.com/cors"
"github.com/go-chi/chi/v5"
"github.com/rs/cors"
"git.omukk.dev/wrenn/sandbox/envd/internal/api"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
publicport "git.omukk.dev/wrenn/sandbox/envd/internal/port"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
filesystemRpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/filesystem"
processRpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/process"
processSpec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
const (
// Downstream timeout should be greater than upstream (in orchestrator proxy).
idleTimeout = 640 * time.Second
maxAge = 2 * time.Hour
defaultPort = 49983
portScannerInterval = 1000 * time.Millisecond
// This is the default user used in the container if not specified otherwise.
// It should be always overridden by the user in /init when building the template.
defaultUser = "root"
kilobyte = 1024
megabyte = 1024 * kilobyte
)
var (
Version = "0.5.4"
commitSHA string
isNotFC bool
port int64
versionFlag bool
commitFlag bool
startCmdFlag string
cgroupRoot string
)
func parseFlags() {
flag.BoolVar(
&isNotFC,
"isnotfc",
false,
"isNotFCmode prints all logs to stdout",
)
flag.BoolVar(
&versionFlag,
"version",
false,
"print envd version",
)
flag.BoolVar(
&commitFlag,
"commit",
false,
"print envd source commit",
)
flag.Int64Var(
&port,
"port",
defaultPort,
"a port on which the daemon should run",
)
flag.StringVar(
&startCmdFlag,
"cmd",
"",
"a command to run on the daemon start",
)
flag.StringVar(
&cgroupRoot,
"cgroup-root",
"/sys/fs/cgroup",
"cgroup root directory",
)
flag.Parse()
}
func withCORS(h http.Handler) http.Handler {
middleware := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{
http.MethodHead,
http.MethodGet,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
},
AllowedHeaders: []string{"*"},
ExposedHeaders: append(
connectcors.ExposedHeaders(),
"Location",
"Cache-Control",
"X-Content-Type-Options",
),
MaxAge: int(maxAge.Seconds()),
})
return middleware.Handler(h)
}
func main() {
parseFlags()
if versionFlag {
fmt.Printf("%s\n", Version)
return
}
if commitFlag {
fmt.Printf("%s\n", commitSHA)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := os.MkdirAll(host.WrennRunDir, 0o755); err != nil {
fmt.Fprintf(os.Stderr, "error creating wrenn run directory: %v\n", err)
}
defaults := &execcontext.Defaults{
User: defaultUser,
EnvVars: utils.NewMap[string, string](),
}
isFCBoolStr := strconv.FormatBool(!isNotFC)
defaults.EnvVars.Store("WRENN_SANDBOX", isFCBoolStr)
if err := os.WriteFile(filepath.Join(host.WrennRunDir, ".WRENN_SANDBOX"), []byte(isFCBoolStr), 0o444); err != nil {
fmt.Fprintf(os.Stderr, "error writing sandbox file: %v\n", err)
}
mmdsChan := make(chan *host.MMDSOpts, 1)
defer close(mmdsChan)
if !isNotFC {
go host.PollForMMDSOpts(ctx, mmdsChan, defaults.EnvVars)
}
l := logs.NewLogger(ctx, isNotFC, mmdsChan)
m := chi.NewRouter()
envLogger := l.With().Str("logger", "envd").Logger()
fsLogger := l.With().Str("logger", "filesystem").Logger()
filesystemRpc.Handle(m, &fsLogger, defaults)
cgroupManager := createCgroupManager()
defer func() {
err := cgroupManager.Close()
if err != nil {
fmt.Fprintf(os.Stderr, "failed to close cgroup manager: %v\n", err)
}
}()
processLogger := l.With().Str("logger", "process").Logger()
processService := processRpc.Handle(m, &processLogger, defaults, cgroupManager)
service := api.New(&envLogger, defaults, mmdsChan, isNotFC)
handler := api.HandlerFromMux(service, m)
middleware := authn.NewMiddleware(permissions.AuthenticateUsername)
s := &http.Server{
Handler: withCORS(
service.WithAuthorization(
middleware.Wrap(handler),
),
),
Addr: fmt.Sprintf("0.0.0.0:%d", port),
// We remove the timeouts as the connection is terminated by closing of the sandbox and keepalive close.
ReadTimeout: 0,
WriteTimeout: 0,
IdleTimeout: idleTimeout,
}
// TODO: Not used anymore in template build, replaced by direct envd command call.
if startCmdFlag != "" {
tag := "startCmd"
cwd := "/home/user"
user, err := permissions.GetUser("root")
if err != nil {
log.Fatalf("error getting user: %v", err) //nolint:gocritic // probably fine to bail if we're done?
}
if err = processService.InitializeStartProcess(ctx, user, &processSpec.StartRequest{
Tag: &tag,
Process: &processSpec.ProcessConfig{
Envs: make(map[string]string),
Cmd: "/bin/bash",
Args: []string{"-l", "-c", startCmdFlag},
Cwd: &cwd,
},
}); err != nil {
log.Fatalf("error starting process: %v", err)
}
}
// Bind all open ports on 127.0.0.1 and localhost to the eth0 interface
portScanner := publicport.NewScanner(portScannerInterval)
defer portScanner.Destroy()
portLogger := l.With().Str("logger", "port-forwarder").Logger()
portForwarder := publicport.NewForwarder(&portLogger, portScanner, cgroupManager)
go portForwarder.StartForwarding(ctx)
go portScanner.ScanAndBroadcast()
err := s.ListenAndServe()
if err != nil {
log.Fatalf("error starting server: %v", err)
}
}
func createCgroupManager() (m cgroups.Manager) {
defer func() {
if m == nil {
fmt.Fprintf(os.Stderr, "falling back to no-op cgroup manager\n")
m = cgroups.NewNoopManager()
}
}()
metrics, err := host.GetMetrics()
if err != nil {
fmt.Fprintf(os.Stderr, "failed to calculate host metrics: %v\n", err)
return nil
}
// try to keep 1/8 of the memory free, but no more than 128 MB
maxMemoryReserved := uint64(float64(metrics.MemTotal) * .125)
maxMemoryReserved = min(maxMemoryReserved, uint64(128)*megabyte)
opts := []cgroups.Cgroup2ManagerOption{
cgroups.WithCgroup2ProcessType(cgroups.ProcessTypePTY, "ptys", map[string]string{
"cpu.weight": "200", // gets much preferred cpu access, to help keep these real time
}),
cgroups.WithCgroup2ProcessType(cgroups.ProcessTypeSocat, "socats", map[string]string{
"cpu.weight": "150", // gets slightly preferred cpu access
"memory.min": fmt.Sprintf("%d", 5*megabyte),
"memory.low": fmt.Sprintf("%d", 8*megabyte),
}),
cgroups.WithCgroup2ProcessType(cgroups.ProcessTypeUser, "user", map[string]string{
"memory.high": fmt.Sprintf("%d", metrics.MemTotal-maxMemoryReserved),
"cpu.weight": "50", // less than envd, and less than core processes that default to 100
}),
}
if cgroupRoot != "" {
opts = append(opts, cgroups.WithCgroup2RootSysFSPath(cgroupRoot))
}
mgr, err := cgroups.NewCgroup2Manager(opts...)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to create cgroup2 manager: %v\n", err)
return nil
}
return mgr
}

15
envd/spec/buf.gen.yaml Normal file
View File

@ -0,0 +1,15 @@
version: v2
plugins:
- protoc_builtin: go
out: ../internal/services/spec
opt: paths=source_relative
- local: protoc-gen-connect-go
out: ../internal/services/spec
opt: paths=source_relative
inputs:
- directory: ../../proto/envd
managed:
enabled: true
override:
- file_option: go_package_prefix
value: git.omukk.dev/wrenn/sandbox/envd/internal/services/spec

305
envd/spec/envd.yaml Normal file
View File

@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
openapi: 3.0.0
info:
title: envd
version: 0.1.1
description: API for managing files' content and controlling envd
tags:
- name: files
paths:
/health:
get:
summary: Check the health of the service
responses:
"204":
description: The service is healthy
/metrics:
get:
summary: Get the stats of the service
security:
- AccessTokenAuth: []
- {}
responses:
"200":
description: The resource usage metrics of the service
content:
application/json:
schema:
$ref: "#/components/schemas/Metrics"
/init:
post:
summary: Set initial vars, ensure the time and metadata is synced with the host
security:
- AccessTokenAuth: []
- {}
requestBody:
content:
application/json:
schema:
type: object
properties:
volumeMounts:
type: array
items:
$ref: "#/components/schemas/VolumeMount"
hyperloopIP:
type: string
description: IP address of the hyperloop server to connect to
envVars:
$ref: "#/components/schemas/EnvVars"
accessToken:
type: string
description: Access token for secure access to envd service
x-go-type: SecureToken
timestamp:
type: string
format: date-time
description: The current timestamp in RFC3339 format
defaultUser:
type: string
description: The default user to use for operations
defaultWorkdir:
type: string
description: The default working directory to use for operations
responses:
"204":
description: Env vars set, the time and metadata is synced with the host
/envs:
get:
summary: Get the environment variables
security:
- AccessTokenAuth: []
- {}
responses:
"200":
description: Environment variables
content:
application/json:
schema:
$ref: "#/components/schemas/EnvVars"
/files:
get:
summary: Download a file
tags: [files]
security:
- AccessTokenAuth: []
- {}
parameters:
- $ref: "#/components/parameters/FilePath"
- $ref: "#/components/parameters/User"
- $ref: "#/components/parameters/Signature"
- $ref: "#/components/parameters/SignatureExpiration"
responses:
"200":
$ref: "#/components/responses/DownloadSuccess"
"401":
$ref: "#/components/responses/InvalidUser"
"400":
$ref: "#/components/responses/InvalidPath"
"404":
$ref: "#/components/responses/FileNotFound"
"500":
$ref: "#/components/responses/InternalServerError"
post:
summary: Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
tags: [files]
security:
- AccessTokenAuth: []
- {}
parameters:
- $ref: "#/components/parameters/FilePath"
- $ref: "#/components/parameters/User"
- $ref: "#/components/parameters/Signature"
- $ref: "#/components/parameters/SignatureExpiration"
requestBody:
$ref: "#/components/requestBodies/File"
responses:
"200":
$ref: "#/components/responses/UploadSuccess"
"400":
$ref: "#/components/responses/InvalidPath"
"401":
$ref: "#/components/responses/InvalidUser"
"500":
$ref: "#/components/responses/InternalServerError"
"507":
$ref: "#/components/responses/NotEnoughDiskSpace"
components:
securitySchemes:
AccessTokenAuth:
type: apiKey
in: header
name: X-Access-Token
parameters:
FilePath:
name: path
in: query
required: false
description: Path to the file, URL encoded. Can be relative to user's home directory.
schema:
type: string
User:
name: username
in: query
required: false
description: User used for setting the owner, or resolving relative paths.
schema:
type: string
Signature:
name: signature
in: query
required: false
description: Signature used for file access permission verification.
schema:
type: string
SignatureExpiration:
name: signature_expiration
in: query
required: false
description: Signature expiration used for defining the expiration time of the signature.
schema:
type: integer
requestBodies:
File:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
file:
type: string
format: binary
responses:
UploadSuccess:
description: The file was uploaded successfully.
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/EntryInfo"
DownloadSuccess:
description: Entire file downloaded successfully.
content:
application/octet-stream:
schema:
type: string
format: binary
description: The file content
InvalidPath:
description: Invalid path
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
InternalServerError:
description: Internal server error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
FileNotFound:
description: File not found
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
InvalidUser:
description: Invalid user
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
NotEnoughDiskSpace:
description: Not enough disk space
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
schemas:
Error:
required:
- message
- code
properties:
message:
type: string
description: Error message
code:
type: integer
description: Error code
EntryInfo:
required:
- path
- name
- type
properties:
path:
type: string
description: Path to the file
name:
type: string
description: Name of the file
type:
type: string
description: Type of the file
enum:
- file
EnvVars:
type: object
description: Environment variables to set
additionalProperties:
type: string
Metrics:
type: object
description: Resource usage metrics
properties:
ts:
type: integer
format: int64
description: Unix timestamp in UTC for current sandbox time
cpu_count:
type: integer
description: Number of CPU cores
cpu_used_pct:
type: number
format: float
description: CPU usage percentage
mem_total:
type: integer
description: Total virtual memory in bytes
mem_used:
type: integer
description: Used virtual memory in bytes
disk_used:
type: integer
description: Used disk space in bytes
disk_total:
type: integer
description: Total disk space in bytes
VolumeMount:
type: object
description: Volume
additionalProperties: false
properties:
nfs_target:
type: string
path:
type: string
required:
- nfs_target
- path

3
envd/spec/generate.go Normal file
View File

@ -0,0 +1,3 @@
package spec
//go:generate buf generate