Port envd from e2b with internalized shared packages and Connect RPC
- Copy envd source from e2b-dev/infra, internalize shared dependencies
into envd/internal/shared/ (keys, filesystem, id, smap, utils)
- Switch from gRPC to Connect RPC for all envd services
- Update module paths to git.omukk.dev/wrenn/{sandbox,sandbox/envd}
- Add proto specs (process, filesystem) with buf-based code generation
- Implement full envd: process exec, filesystem ops, port forwarding,
cgroup management, MMDS integration, and HTTP API
- Update main module dependencies (firecracker SDK, pgx, goose, etc.)
- Remove placeholder .gitkeep files replaced by real implementations
This commit is contained in:
568
envd/internal/api/api.gen.go
Normal file
568
envd/internal/api/api.gen.go
Normal file
@ -0,0 +1,568 @@
|
||||
// Package api provides primitives to interact with the openapi HTTP API.
|
||||
//
|
||||
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT.
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/oapi-codegen/runtime"
|
||||
openapi_types "github.com/oapi-codegen/runtime/types"
|
||||
)
|
||||
|
||||
const (
|
||||
AccessTokenAuthScopes = "AccessTokenAuth.Scopes"
|
||||
)
|
||||
|
||||
// Defines values for EntryInfoType.
|
||||
const (
|
||||
File EntryInfoType = "file"
|
||||
)
|
||||
|
||||
// EntryInfo defines model for EntryInfo.
|
||||
type EntryInfo struct {
|
||||
// Name Name of the file
|
||||
Name string `json:"name"`
|
||||
|
||||
// Path Path to the file
|
||||
Path string `json:"path"`
|
||||
|
||||
// Type Type of the file
|
||||
Type EntryInfoType `json:"type"`
|
||||
}
|
||||
|
||||
// EntryInfoType Type of the file
|
||||
type EntryInfoType string
|
||||
|
||||
// EnvVars Environment variables to set
|
||||
type EnvVars map[string]string
|
||||
|
||||
// Error defines model for Error.
|
||||
type Error struct {
|
||||
// Code Error code
|
||||
Code int `json:"code"`
|
||||
|
||||
// Message Error message
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Metrics Resource usage metrics
|
||||
type Metrics struct {
|
||||
// CpuCount Number of CPU cores
|
||||
CpuCount *int `json:"cpu_count,omitempty"`
|
||||
|
||||
// CpuUsedPct CPU usage percentage
|
||||
CpuUsedPct *float32 `json:"cpu_used_pct,omitempty"`
|
||||
|
||||
// DiskTotal Total disk space in bytes
|
||||
DiskTotal *int `json:"disk_total,omitempty"`
|
||||
|
||||
// DiskUsed Used disk space in bytes
|
||||
DiskUsed *int `json:"disk_used,omitempty"`
|
||||
|
||||
// MemTotal Total virtual memory in bytes
|
||||
MemTotal *int `json:"mem_total,omitempty"`
|
||||
|
||||
// MemUsed Used virtual memory in bytes
|
||||
MemUsed *int `json:"mem_used,omitempty"`
|
||||
|
||||
// Ts Unix timestamp in UTC for current sandbox time
|
||||
Ts *int64 `json:"ts,omitempty"`
|
||||
}
|
||||
|
||||
// VolumeMount Volume
|
||||
type VolumeMount struct {
|
||||
NfsTarget string `json:"nfs_target"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// FilePath defines model for FilePath.
|
||||
type FilePath = string
|
||||
|
||||
// Signature defines model for Signature.
|
||||
type Signature = string
|
||||
|
||||
// SignatureExpiration defines model for SignatureExpiration.
|
||||
type SignatureExpiration = int
|
||||
|
||||
// User defines model for User.
|
||||
type User = string
|
||||
|
||||
// FileNotFound defines model for FileNotFound.
|
||||
type FileNotFound = Error
|
||||
|
||||
// InternalServerError defines model for InternalServerError.
|
||||
type InternalServerError = Error
|
||||
|
||||
// InvalidPath defines model for InvalidPath.
|
||||
type InvalidPath = Error
|
||||
|
||||
// InvalidUser defines model for InvalidUser.
|
||||
type InvalidUser = Error
|
||||
|
||||
// NotEnoughDiskSpace defines model for NotEnoughDiskSpace.
|
||||
type NotEnoughDiskSpace = Error
|
||||
|
||||
// UploadSuccess defines model for UploadSuccess.
|
||||
type UploadSuccess = []EntryInfo
|
||||
|
||||
// GetFilesParams defines parameters for GetFiles.
|
||||
type GetFilesParams struct {
|
||||
// Path Path to the file, URL encoded. Can be relative to user's home directory.
|
||||
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
|
||||
|
||||
// Username User used for setting the owner, or resolving relative paths.
|
||||
Username *User `form:"username,omitempty" json:"username,omitempty"`
|
||||
|
||||
// Signature Signature used for file access permission verification.
|
||||
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
|
||||
|
||||
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
|
||||
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
|
||||
}
|
||||
|
||||
// PostFilesMultipartBody defines parameters for PostFiles.
|
||||
type PostFilesMultipartBody struct {
|
||||
File *openapi_types.File `json:"file,omitempty"`
|
||||
}
|
||||
|
||||
// PostFilesParams defines parameters for PostFiles.
|
||||
type PostFilesParams struct {
|
||||
// Path Path to the file, URL encoded. Can be relative to user's home directory.
|
||||
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
|
||||
|
||||
// Username User used for setting the owner, or resolving relative paths.
|
||||
Username *User `form:"username,omitempty" json:"username,omitempty"`
|
||||
|
||||
// Signature Signature used for file access permission verification.
|
||||
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
|
||||
|
||||
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
|
||||
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
|
||||
}
|
||||
|
||||
// PostInitJSONBody defines parameters for PostInit.
|
||||
type PostInitJSONBody struct {
|
||||
// AccessToken Access token for secure access to envd service
|
||||
AccessToken *SecureToken `json:"accessToken,omitempty"`
|
||||
|
||||
// DefaultUser The default user to use for operations
|
||||
DefaultUser *string `json:"defaultUser,omitempty"`
|
||||
|
||||
// DefaultWorkdir The default working directory to use for operations
|
||||
DefaultWorkdir *string `json:"defaultWorkdir,omitempty"`
|
||||
|
||||
// EnvVars Environment variables to set
|
||||
EnvVars *EnvVars `json:"envVars,omitempty"`
|
||||
|
||||
// HyperloopIP IP address of the hyperloop server to connect to
|
||||
HyperloopIP *string `json:"hyperloopIP,omitempty"`
|
||||
|
||||
// Timestamp The current timestamp in RFC3339 format
|
||||
Timestamp *time.Time `json:"timestamp,omitempty"`
|
||||
VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"`
|
||||
}
|
||||
|
||||
// PostFilesMultipartRequestBody defines body for PostFiles for multipart/form-data ContentType.
|
||||
type PostFilesMultipartRequestBody PostFilesMultipartBody
|
||||
|
||||
// PostInitJSONRequestBody defines body for PostInit for application/json ContentType.
|
||||
type PostInitJSONRequestBody PostInitJSONBody
|
||||
|
||||
// ServerInterface represents all server handlers.
|
||||
type ServerInterface interface {
|
||||
// Get the environment variables
|
||||
// (GET /envs)
|
||||
GetEnvs(w http.ResponseWriter, r *http.Request)
|
||||
// Download a file
|
||||
// (GET /files)
|
||||
GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams)
|
||||
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
|
||||
// (POST /files)
|
||||
PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams)
|
||||
// Check the health of the service
|
||||
// (GET /health)
|
||||
GetHealth(w http.ResponseWriter, r *http.Request)
|
||||
// Set initial vars, ensure the time and metadata is synced with the host
|
||||
// (POST /init)
|
||||
PostInit(w http.ResponseWriter, r *http.Request)
|
||||
// Get the stats of the service
|
||||
// (GET /metrics)
|
||||
GetMetrics(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
|
||||
|
||||
type Unimplemented struct{}
|
||||
|
||||
// Get the environment variables
|
||||
// (GET /envs)
|
||||
func (_ Unimplemented) GetEnvs(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Download a file
|
||||
// (GET /files)
|
||||
func (_ Unimplemented) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
|
||||
// (POST /files)
|
||||
func (_ Unimplemented) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Check the health of the service
|
||||
// (GET /health)
|
||||
func (_ Unimplemented) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Set initial vars, ensure the time and metadata is synced with the host
|
||||
// (POST /init)
|
||||
func (_ Unimplemented) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Get the stats of the service
|
||||
// (GET /metrics)
|
||||
func (_ Unimplemented) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// ServerInterfaceWrapper converts contexts to parameters.
|
||||
type ServerInterfaceWrapper struct {
|
||||
Handler ServerInterface
|
||||
HandlerMiddlewares []MiddlewareFunc
|
||||
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
|
||||
}
|
||||
|
||||
type MiddlewareFunc func(http.Handler) http.Handler
|
||||
|
||||
// GetEnvs operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetEnvs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetEnvs(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetFiles operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var err error
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Parameter object where we will unmarshal all parameters from the context
|
||||
var params GetFilesParams
|
||||
|
||||
// ------------- Optional query parameter "path" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), ¶ms.Path)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "username" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), ¶ms.Username)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), ¶ms.Signature)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature_expiration" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetFiles(w, r, params)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// PostFiles operation middleware
|
||||
func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var err error
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Parameter object where we will unmarshal all parameters from the context
|
||||
var params PostFilesParams
|
||||
|
||||
// ------------- Optional query parameter "path" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), ¶ms.Path)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "username" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), ¶ms.Username)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), ¶ms.Signature)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature_expiration" -------------
|
||||
|
||||
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration)
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.PostFiles(w, r, params)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetHealth operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetHealth(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// PostInit operation middleware
|
||||
func (siw *ServerInterfaceWrapper) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.PostInit(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetMetrics operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetMetrics(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
type UnescapedCookieParamError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *UnescapedCookieParamError) Error() string {
|
||||
return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
|
||||
}
|
||||
|
||||
func (e *UnescapedCookieParamError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type UnmarshalingParamError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *UnmarshalingParamError) Error() string {
|
||||
return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
|
||||
}
|
||||
|
||||
func (e *UnmarshalingParamError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type RequiredParamError struct {
|
||||
ParamName string
|
||||
}
|
||||
|
||||
func (e *RequiredParamError) Error() string {
|
||||
return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
|
||||
}
|
||||
|
||||
type RequiredHeaderError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *RequiredHeaderError) Error() string {
|
||||
return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
|
||||
}
|
||||
|
||||
func (e *RequiredHeaderError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type InvalidParamFormatError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *InvalidParamFormatError) Error() string {
|
||||
return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
|
||||
}
|
||||
|
||||
func (e *InvalidParamFormatError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type TooManyValuesForParamError struct {
|
||||
ParamName string
|
||||
Count int
|
||||
}
|
||||
|
||||
func (e *TooManyValuesForParamError) Error() string {
|
||||
return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
|
||||
}
|
||||
|
||||
// Handler creates http.Handler with routing matching OpenAPI spec.
|
||||
func Handler(si ServerInterface) http.Handler {
|
||||
return HandlerWithOptions(si, ChiServerOptions{})
|
||||
}
|
||||
|
||||
type ChiServerOptions struct {
|
||||
BaseURL string
|
||||
BaseRouter chi.Router
|
||||
Middlewares []MiddlewareFunc
|
||||
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
|
||||
}
|
||||
|
||||
// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
|
||||
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
|
||||
return HandlerWithOptions(si, ChiServerOptions{
|
||||
BaseRouter: r,
|
||||
})
|
||||
}
|
||||
|
||||
func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
|
||||
return HandlerWithOptions(si, ChiServerOptions{
|
||||
BaseURL: baseURL,
|
||||
BaseRouter: r,
|
||||
})
|
||||
}
|
||||
|
||||
// HandlerWithOptions creates http.Handler with additional options
|
||||
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
|
||||
r := options.BaseRouter
|
||||
|
||||
if r == nil {
|
||||
r = chi.NewRouter()
|
||||
}
|
||||
if options.ErrorHandlerFunc == nil {
|
||||
options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
wrapper := ServerInterfaceWrapper{
|
||||
Handler: si,
|
||||
HandlerMiddlewares: options.Middlewares,
|
||||
ErrorHandlerFunc: options.ErrorHandlerFunc,
|
||||
}
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/envs", wrapper.GetEnvs)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/files", wrapper.GetFiles)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Post(options.BaseURL+"/files", wrapper.PostFiles)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/health", wrapper.GetHealth)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Post(options.BaseURL+"/init", wrapper.PostInit)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
129
envd/internal/api/auth.go
Normal file
129
envd/internal/api/auth.go
Normal file
@ -0,0 +1,129 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
const (
|
||||
SigningReadOperation = "read"
|
||||
SigningWriteOperation = "write"
|
||||
|
||||
accessTokenHeader = "X-Access-Token"
|
||||
)
|
||||
|
||||
// paths that are always allowed without general authentication
|
||||
// POST/init is secured via MMDS hash validation instead
|
||||
var authExcludedPaths = []string{
|
||||
"GET/health",
|
||||
"GET/files",
|
||||
"POST/files",
|
||||
"POST/init",
|
||||
}
|
||||
|
||||
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if a.accessToken.IsSet() {
|
||||
authHeader := req.Header.Get(accessTokenHeader)
|
||||
|
||||
// check if this path is allowed without authentication (e.g., health check, endpoints supporting signing)
|
||||
allowedPath := slices.Contains(authExcludedPaths, req.Method+req.URL.Path)
|
||||
|
||||
if !a.accessToken.Equals(authHeader) && !allowedPath {
|
||||
a.logger.Error().Msg("Trying to access secured envd without correct access token")
|
||||
|
||||
err := fmt.Errorf("unauthorized access, please provide a valid access token or method signing if supported")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *API) generateSignature(path string, username string, operation string, signatureExpiration *int64) (string, error) {
|
||||
tokenBytes, err := a.accessToken.Bytes()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("access token is not set: %w", err)
|
||||
}
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
var signature string
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
|
||||
if signatureExpiration == nil {
|
||||
signature = strings.Join([]string{path, operation, username, string(tokenBytes)}, ":")
|
||||
} else {
|
||||
signature = strings.Join([]string{path, operation, username, string(tokenBytes), strconv.FormatInt(*signatureExpiration, 10)}, ":")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(signature))), nil
|
||||
}
|
||||
|
||||
func (a *API) validateSigning(r *http.Request, signature *string, signatureExpiration *int, username *string, path string, operation string) (err error) {
|
||||
var expectedSignature string
|
||||
|
||||
// no need to validate signing key if access token is not set
|
||||
if !a.accessToken.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// check if access token is sent in the header
|
||||
tokenFromHeader := r.Header.Get(accessTokenHeader)
|
||||
if tokenFromHeader != "" {
|
||||
if !a.accessToken.Equals(tokenFromHeader) {
|
||||
return fmt.Errorf("access token present in header but does not match")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if signature == nil {
|
||||
return fmt.Errorf("missing signature query parameter")
|
||||
}
|
||||
|
||||
// Empty string is used when no username is provided and the default user should be used
|
||||
signatureUsername := ""
|
||||
if username != nil {
|
||||
signatureUsername = *username
|
||||
}
|
||||
|
||||
if signatureExpiration == nil {
|
||||
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, nil)
|
||||
} else {
|
||||
exp := int64(*signatureExpiration)
|
||||
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, &exp)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("error generating signing key")
|
||||
|
||||
return errors.New("invalid signature")
|
||||
}
|
||||
|
||||
// signature validation
|
||||
if expectedSignature != *signature {
|
||||
return fmt.Errorf("invalid signature")
|
||||
}
|
||||
|
||||
// signature expiration
|
||||
if signatureExpiration != nil {
|
||||
exp := int64(*signatureExpiration)
|
||||
if exp < time.Now().Unix() {
|
||||
return fmt.Errorf("signature is already expired")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
62
envd/internal/api/auth_test.go
Normal file
62
envd/internal/api/auth_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
func TestKeyGenerationAlgorithmIsStable(t *testing.T) {
|
||||
t.Parallel()
|
||||
apiToken := "secret-access-token"
|
||||
secureToken := &SecureToken{}
|
||||
err := secureToken.Set([]byte(apiToken))
|
||||
require.NoError(t, err)
|
||||
api := &API{accessToken: secureToken}
|
||||
|
||||
path := "/path/to/demo.txt"
|
||||
username := "root"
|
||||
operation := "write"
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
signature, err := api.generateSignature(path, username, operation, ×tamp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// locally generated signature
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s:%s", path, operation, username, apiToken, strconv.FormatInt(timestamp, 10))
|
||||
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
|
||||
|
||||
assert.Equal(t, localSignature, signature)
|
||||
}
|
||||
|
||||
func TestKeyGenerationAlgorithmWithoutExpirationIsStable(t *testing.T) {
|
||||
t.Parallel()
|
||||
apiToken := "secret-access-token"
|
||||
secureToken := &SecureToken{}
|
||||
err := secureToken.Set([]byte(apiToken))
|
||||
require.NoError(t, err)
|
||||
api := &API{accessToken: secureToken}
|
||||
|
||||
path := "/path/to/resource.txt"
|
||||
username := "user"
|
||||
operation := "read"
|
||||
|
||||
signature, err := api.generateSignature(path, username, operation, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// locally generated signature
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s", path, operation, username, apiToken)
|
||||
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
|
||||
|
||||
assert.Equal(t, localSignature, signature)
|
||||
}
|
||||
8
envd/internal/api/cfg.yaml
Normal file
8
envd/internal/api/cfg.yaml
Normal file
@ -0,0 +1,8 @@
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/deepmap/oapi-codegen/HEAD/configuration-schema.json
|
||||
|
||||
package: api
|
||||
output: api.gen.go
|
||||
generate:
|
||||
models: true
|
||||
chi-server: true
|
||||
client: false
|
||||
173
envd/internal/api/download.go
Normal file
173
envd/internal/api/download.go
Normal file
@ -0,0 +1,173 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
)
|
||||
|
||||
func (a *API) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
|
||||
defer r.Body.Close()
|
||||
|
||||
var errorCode int
|
||||
var errMsg error
|
||||
|
||||
var path string
|
||||
if params.Path != nil {
|
||||
path = *params.Path
|
||||
}
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
// signing authorization if needed
|
||||
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningReadOperation)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
|
||||
jsonError(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
l := a.logger.
|
||||
Err(errMsg).
|
||||
Str("method", r.Method+" "+r.URL.Path).
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("path", path).
|
||||
Str("username", username)
|
||||
|
||||
if errMsg != nil {
|
||||
l = l.Int("error_code", errorCode)
|
||||
}
|
||||
|
||||
l.Msg("File read")
|
||||
}()
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||
errorCode = http.StatusUnauthorized
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resolvedPath, err := permissions.ExpandAndResolve(path, u, a.defaults.Workdir)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error expanding and resolving path '%s': %w", path, err)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
stat, err := os.Stat(resolvedPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
errMsg = fmt.Errorf("path '%s' does not exist", resolvedPath)
|
||||
errorCode = http.StatusNotFound
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
errMsg = fmt.Errorf("error checking if path exists '%s': %w", resolvedPath, err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
errMsg = fmt.Errorf("path '%s' is a directory", resolvedPath)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Validate Accept-Encoding header
|
||||
encoding, err := parseAcceptEncoding(r)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error parsing Accept-Encoding: %w", err)
|
||||
errorCode = http.StatusNotAcceptable
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Tell caches to store separate variants for different Accept-Encoding values
|
||||
w.Header().Set("Vary", "Accept-Encoding")
|
||||
|
||||
// Fall back to identity for Range or conditional requests to preserve http.ServeContent
|
||||
// behavior (206 Partial Content, 304 Not Modified). However, we must check if identity
|
||||
// is acceptable per the Accept-Encoding header.
|
||||
hasRangeOrConditional := r.Header.Get("Range") != "" ||
|
||||
r.Header.Get("If-Modified-Since") != "" ||
|
||||
r.Header.Get("If-None-Match") != "" ||
|
||||
r.Header.Get("If-Range") != ""
|
||||
if hasRangeOrConditional {
|
||||
if !isIdentityAcceptable(r) {
|
||||
errMsg = fmt.Errorf("identity encoding not acceptable for Range or conditional request")
|
||||
errorCode = http.StatusNotAcceptable
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
encoding = EncodingIdentity
|
||||
}
|
||||
|
||||
file, err := os.Open(resolvedPath)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error opening file '%s': %w", resolvedPath, err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": filepath.Base(resolvedPath)}))
|
||||
|
||||
// Serve with gzip encoding if requested.
|
||||
if encoding == EncodingGzip {
|
||||
w.Header().Set("Content-Encoding", EncodingGzip)
|
||||
|
||||
// Set Content-Type based on file extension, preserving the original type
|
||||
contentType := mime.TypeByExtension(filepath.Ext(path))
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
|
||||
gw := gzip.NewWriter(w)
|
||||
defer gw.Close()
|
||||
|
||||
_, err = io.Copy(gw, file)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error writing gzip response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, path, stat.ModTime(), file)
|
||||
}
|
||||
401
envd/internal/api/download_test.go
Normal file
401
envd/internal/api/download_test.go
Normal file
@ -0,0 +1,401 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
func TestGetFilesContentDisposition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
name: "simple filename",
|
||||
filename: "test.txt",
|
||||
expectedHeader: `inline; filename=test.txt`,
|
||||
},
|
||||
{
|
||||
name: "filename with extension",
|
||||
filename: "presentation.pptx",
|
||||
expectedHeader: `inline; filename=presentation.pptx`,
|
||||
},
|
||||
{
|
||||
name: "filename with multiple dots",
|
||||
filename: "archive.tar.gz",
|
||||
expectedHeader: `inline; filename=archive.tar.gz`,
|
||||
},
|
||||
{
|
||||
name: "filename with spaces",
|
||||
filename: "my document.pdf",
|
||||
expectedHeader: `inline; filename="my document.pdf"`,
|
||||
},
|
||||
{
|
||||
name: "filename with quotes",
|
||||
filename: `file"name.txt`,
|
||||
expectedHeader: `inline; filename="file\"name.txt"`,
|
||||
},
|
||||
{
|
||||
name: "filename with backslash",
|
||||
filename: `file\name.txt`,
|
||||
expectedHeader: `inline; filename="file\\name.txt"`,
|
||||
},
|
||||
{
|
||||
name: "unicode filename",
|
||||
filename: "\u6587\u6863.pdf", // 文档.pdf in Chinese
|
||||
expectedHeader: "inline; filename*=utf-8''%E6%96%87%E6%A1%A3.pdf",
|
||||
},
|
||||
{
|
||||
name: "dotfile preserved",
|
||||
filename: ".env",
|
||||
expectedHeader: `inline; filename=.env`,
|
||||
},
|
||||
{
|
||||
name: "dotfile with extension preserved",
|
||||
filename: ".gitignore",
|
||||
expectedHeader: `inline; filename=.gitignore`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a temp directory and file
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, tt.filename)
|
||||
err := os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify Content-Disposition header
|
||||
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||
assert.Equal(t, tt.expectedHeader, contentDisposition, "Content-Disposition header should be set with correct filename")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a temp directory with nested structure
|
||||
tempDir := t.TempDir()
|
||||
nestedDir := filepath.Join(tempDir, "subdir", "another")
|
||||
err = os.MkdirAll(nestedDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
filename := "document.pdf"
|
||||
tempFile := filepath.Join(nestedDir, filename)
|
||||
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify Content-Disposition header uses only the base filename, not the full path
|
||||
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||
assert.Equal(t, `inline; filename=document.pdf`, contentDisposition, "Content-Disposition should contain only the filename, not the path")
|
||||
}
|
||||
|
||||
func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a temp directory with a test file
|
||||
tempDir := t.TempDir()
|
||||
filename := "document.pdf"
|
||||
tempFile := filepath.Join(tempDir, filename)
|
||||
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip; q=1,*; q=0")
|
||||
req.Header.Set("Range", "bytes=0-4") // Request first 5 bytes
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestGetFiles_GzipDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalContent := []byte("hello world, this is a test file for gzip compression")
|
||||
|
||||
// Create a temp file with known content
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test.txt")
|
||||
err = os.WriteFile(tempFile, originalContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
|
||||
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
|
||||
|
||||
// Decompress the gzip response body
|
||||
gzReader, err := gzip.NewReader(resp.Body)
|
||||
require.NoError(t, err)
|
||||
defer gzReader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(gzReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalContent, decompressed)
|
||||
}
|
||||
|
||||
func TestPostFiles_GzipUpload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalContent := []byte("hello world, this is a test file uploaded with gzip")
|
||||
|
||||
// Build a multipart body
|
||||
var multipartBuf bytes.Buffer
|
||||
mpWriter := multipart.NewWriter(&multipartBuf)
|
||||
part, err := mpWriter.CreateFormFile("file", "uploaded.txt")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = mpWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Gzip-compress the entire multipart body
|
||||
var gzBuf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&gzBuf)
|
||||
_, err = gzWriter.Write(multipartBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
err = gzWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
tempDir := t.TempDir()
|
||||
destPath := filepath.Join(tempDir, "uploaded.txt")
|
||||
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||
req.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
params := PostFilesParams{
|
||||
Path: &destPath,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.PostFiles(w, req, params)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify the file was written with the original (decompressed) content
|
||||
data, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, data)
|
||||
}
|
||||
|
||||
func TestGzipUploadThenGzipDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalContent := []byte("round-trip gzip test: upload compressed, download compressed, verify match")
|
||||
|
||||
// --- Upload with gzip ---
|
||||
|
||||
// Build a multipart body
|
||||
var multipartBuf bytes.Buffer
|
||||
mpWriter := multipart.NewWriter(&multipartBuf)
|
||||
part, err := mpWriter.CreateFormFile("file", "roundtrip.txt")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = mpWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Gzip-compress the entire multipart body
|
||||
var gzBuf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&gzBuf)
|
||||
_, err = gzWriter.Write(multipartBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
err = gzWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
destPath := filepath.Join(tempDir, "roundtrip.txt")
|
||||
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
|
||||
uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||
uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||
uploadReq.Header.Set("Content-Encoding", "gzip")
|
||||
uploadW := httptest.NewRecorder()
|
||||
|
||||
uploadParams := PostFilesParams{
|
||||
Path: &destPath,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.PostFiles(uploadW, uploadReq, uploadParams)
|
||||
|
||||
uploadResp := uploadW.Result()
|
||||
defer uploadResp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusOK, uploadResp.StatusCode)
|
||||
|
||||
// --- Download with gzip ---
|
||||
|
||||
downloadReq := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(destPath), nil)
|
||||
downloadReq.Header.Set("Accept-Encoding", "gzip")
|
||||
downloadW := httptest.NewRecorder()
|
||||
|
||||
downloadParams := GetFilesParams{
|
||||
Path: &destPath,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(downloadW, downloadReq, downloadParams)
|
||||
|
||||
downloadResp := downloadW.Result()
|
||||
defer downloadResp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusOK, downloadResp.StatusCode)
|
||||
assert.Equal(t, "gzip", downloadResp.Header.Get("Content-Encoding"))
|
||||
|
||||
// Decompress and verify content matches original
|
||||
gzReader, err := gzip.NewReader(downloadResp.Body)
|
||||
require.NoError(t, err)
|
||||
defer gzReader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(gzReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalContent, decompressed)
|
||||
}
|
||||
227
envd/internal/api/encoding.go
Normal file
227
envd/internal/api/encoding.go
Normal file
@ -0,0 +1,227 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// EncodingGzip is the gzip content encoding.
|
||||
EncodingGzip = "gzip"
|
||||
// EncodingIdentity means no encoding (passthrough).
|
||||
EncodingIdentity = "identity"
|
||||
// EncodingWildcard means any encoding is acceptable.
|
||||
EncodingWildcard = "*"
|
||||
)
|
||||
|
||||
// SupportedEncodings lists the content encodings supported for file transfer.
|
||||
// The order matters - encodings are checked in order of preference.
|
||||
var SupportedEncodings = []string{
|
||||
EncodingGzip,
|
||||
}
|
||||
|
||||
// encodingWithQuality holds an encoding name and its quality value.
|
||||
type encodingWithQuality struct {
|
||||
encoding string
|
||||
quality float64
|
||||
}
|
||||
|
||||
// isSupportedEncoding checks if the given encoding is in the supported list.
|
||||
// Per RFC 7231, content-coding values are case-insensitive.
|
||||
func isSupportedEncoding(encoding string) bool {
|
||||
return slices.Contains(SupportedEncodings, strings.ToLower(encoding))
|
||||
}
|
||||
|
||||
// parseEncodingWithQuality parses an encoding value and extracts the quality.
|
||||
// Returns the encoding name (lowercased) and quality value (default 1.0 if not specified).
|
||||
// Per RFC 7231, content-coding values are case-insensitive.
|
||||
func parseEncodingWithQuality(value string) encodingWithQuality {
|
||||
value = strings.TrimSpace(value)
|
||||
quality := 1.0
|
||||
|
||||
if idx := strings.Index(value, ";"); idx != -1 {
|
||||
params := value[idx+1:]
|
||||
value = strings.TrimSpace(value[:idx])
|
||||
|
||||
// Parse q=X.X parameter
|
||||
for param := range strings.SplitSeq(params, ";") {
|
||||
param = strings.TrimSpace(param)
|
||||
if strings.HasPrefix(strings.ToLower(param), "q=") {
|
||||
if q, err := strconv.ParseFloat(param[2:], 64); err == nil {
|
||||
quality = q
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize encoding to lowercase per RFC 7231
|
||||
return encodingWithQuality{encoding: strings.ToLower(value), quality: quality}
|
||||
}
|
||||
|
||||
// parseEncoding extracts the encoding name from a header value, stripping quality.
|
||||
func parseEncoding(value string) string {
|
||||
return parseEncodingWithQuality(value).encoding
|
||||
}
|
||||
|
||||
// parseContentEncoding parses the Content-Encoding header and returns the encoding.
|
||||
// Returns an error if an unsupported encoding is specified.
|
||||
// If no Content-Encoding header is present, returns empty string.
|
||||
func parseContentEncoding(r *http.Request) (string, error) {
|
||||
header := r.Header.Get("Content-Encoding")
|
||||
if header == "" {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
encoding := parseEncoding(header)
|
||||
|
||||
if encoding == EncodingIdentity {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
if !isSupportedEncoding(encoding) {
|
||||
return "", fmt.Errorf("unsupported Content-Encoding: %s, supported: %v", header, SupportedEncodings)
|
||||
}
|
||||
|
||||
return encoding, nil
|
||||
}
|
||||
|
||||
// parseAcceptEncodingHeader parses the Accept-Encoding header and returns
|
||||
// the parsed encodings along with the identity rejection state.
|
||||
// Per RFC 7231 Section 5.3.4, identity is acceptable unless excluded by
|
||||
// "identity;q=0" or "*;q=0" without a more specific entry for identity with q>0.
|
||||
func parseAcceptEncodingHeader(header string) ([]encodingWithQuality, bool) {
|
||||
if header == "" {
|
||||
return nil, false // identity not rejected when header is empty
|
||||
}
|
||||
|
||||
// Parse all encodings with their quality values
|
||||
var encodings []encodingWithQuality
|
||||
for value := range strings.SplitSeq(header, ",") {
|
||||
eq := parseEncodingWithQuality(value)
|
||||
encodings = append(encodings, eq)
|
||||
}
|
||||
|
||||
// Check if identity is rejected per RFC 7231 Section 5.3.4:
|
||||
// identity is acceptable unless excluded by "identity;q=0" or "*;q=0"
|
||||
// without a more specific entry for identity with q>0.
|
||||
identityRejected := false
|
||||
identityExplicitlyAccepted := false
|
||||
wildcardRejected := false
|
||||
|
||||
for _, eq := range encodings {
|
||||
switch eq.encoding {
|
||||
case EncodingIdentity:
|
||||
if eq.quality == 0 {
|
||||
identityRejected = true
|
||||
} else {
|
||||
identityExplicitlyAccepted = true
|
||||
}
|
||||
case EncodingWildcard:
|
||||
if eq.quality == 0 {
|
||||
wildcardRejected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if wildcardRejected && !identityExplicitlyAccepted {
|
||||
identityRejected = true
|
||||
}
|
||||
|
||||
return encodings, identityRejected
|
||||
}
|
||||
|
||||
// isIdentityAcceptable checks if identity encoding is acceptable based on the
|
||||
// Accept-Encoding header. Per RFC 7231 section 5.3.4, identity is always
|
||||
// implicitly acceptable unless explicitly rejected with q=0.
|
||||
func isIdentityAcceptable(r *http.Request) bool {
|
||||
header := r.Header.Get("Accept-Encoding")
|
||||
_, identityRejected := parseAcceptEncodingHeader(header)
|
||||
|
||||
return !identityRejected
|
||||
}
|
||||
|
||||
// parseAcceptEncoding parses the Accept-Encoding header and returns the best
|
||||
// supported encoding based on quality values. Per RFC 7231 section 5.3.4,
|
||||
// identity is always implicitly acceptable unless explicitly rejected with q=0.
|
||||
// If no Accept-Encoding header is present, returns empty string (identity).
|
||||
func parseAcceptEncoding(r *http.Request) (string, error) {
|
||||
header := r.Header.Get("Accept-Encoding")
|
||||
if header == "" {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
encodings, identityRejected := parseAcceptEncodingHeader(header)
|
||||
|
||||
// Sort by quality value (highest first)
|
||||
sort.Slice(encodings, func(i, j int) bool {
|
||||
return encodings[i].quality > encodings[j].quality
|
||||
})
|
||||
|
||||
// Find the best supported encoding
|
||||
for _, eq := range encodings {
|
||||
// Skip encodings with q=0 (explicitly rejected)
|
||||
if eq.quality == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if eq.encoding == EncodingIdentity {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
// Wildcard means any encoding is acceptable - return a supported encoding if identity is rejected
|
||||
if eq.encoding == EncodingWildcard {
|
||||
if identityRejected && len(SupportedEncodings) > 0 {
|
||||
return SupportedEncodings[0], nil
|
||||
}
|
||||
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
if isSupportedEncoding(eq.encoding) {
|
||||
return eq.encoding, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Per RFC 7231, identity is implicitly acceptable unless rejected
|
||||
if !identityRejected {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
// Identity rejected and no supported encodings found
|
||||
return "", fmt.Errorf("no acceptable encoding found, supported: %v", SupportedEncodings)
|
||||
}
|
||||
|
||||
// getDecompressedBody returns a reader that decompresses the request body based on
|
||||
// Content-Encoding header. Returns the original body if no encoding is specified.
|
||||
// Returns an error if an unsupported encoding is specified.
|
||||
// The caller is responsible for closing both the returned ReadCloser and the
|
||||
// original request body (r.Body) separately.
|
||||
func getDecompressedBody(r *http.Request) (io.ReadCloser, error) {
|
||||
encoding, err := parseContentEncoding(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if encoding == EncodingIdentity {
|
||||
return r.Body, nil
|
||||
}
|
||||
|
||||
switch encoding {
|
||||
case EncodingGzip:
|
||||
gzReader, err := gzip.NewReader(r.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
|
||||
return gzReader, nil
|
||||
default:
|
||||
// This shouldn't happen if isSupportedEncoding is correct
|
||||
return nil, fmt.Errorf("encoding %s is supported but not implemented", encoding)
|
||||
}
|
||||
}
|
||||
494
envd/internal/api/encoding_test.go
Normal file
494
envd/internal/api/encoding_test.go
Normal file
@ -0,0 +1,494 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsSupportedEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("gzip is supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("gzip"))
|
||||
})
|
||||
|
||||
t.Run("GZIP is supported (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("GZIP"))
|
||||
})
|
||||
|
||||
t.Run("Gzip is supported (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("Gzip"))
|
||||
})
|
||||
|
||||
t.Run("br is not supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.False(t, isSupportedEncoding("br"))
|
||||
})
|
||||
|
||||
t.Run("deflate is not supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.False(t, isSupportedEncoding("deflate"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseEncodingWithQuality(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns encoding with default quality 1.0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("parses quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=0.5")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.5, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("parses quality value with whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip ; q=0.8")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.8, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("handles q=0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=0")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("handles invalid quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=invalid")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001) // defaults to 1.0 on parse error
|
||||
})
|
||||
|
||||
t.Run("trims whitespace from encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality(" gzip ")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("normalizes encoding to lowercase", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("GZIP")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
})
|
||||
|
||||
t.Run("normalizes mixed case encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("Gzip;q=0.5")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.5, eq.quality, 0.001)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns encoding as-is", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip"))
|
||||
})
|
||||
|
||||
t.Run("trims whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding(" gzip "))
|
||||
})
|
||||
|
||||
t.Run("strips quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip;q=1.0"))
|
||||
})
|
||||
|
||||
t.Run("strips quality value with whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip ; q=0.5"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseContentEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns identity when no header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "GZIP")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is Gzip (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "Gzip")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "identity")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns error for unsupported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "br")
|
||||
|
||||
_, err := parseContentEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
|
||||
assert.Contains(t, err.Error(), "supported: [gzip]")
|
||||
})
|
||||
|
||||
t.Run("handles gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "gzip;q=1.0")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseAcceptEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns identity when no header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Accept-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Accept-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "GZIP")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when gzip is among multiple encodings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate, gzip, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=1.0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for wildcard encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity for unsupported encoding only", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity when only unsupported encodings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("selects gzip when it has highest quality", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br;q=0.5, gzip;q=1.0, deflate;q=0.8")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("selects gzip even with lower quality when others unsupported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br;q=1.0, gzip;q=0.5")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity when it has higher quality than gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0.5, identity;q=1.0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("skips encoding with q=0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0, identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity when gzip rejected and no other supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns error when identity explicitly rejected and no supported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, identity;q=0")
|
||||
|
||||
_, err := parseAcceptEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no acceptable encoding found")
|
||||
})
|
||||
|
||||
t.Run("returns gzip for wildcard when identity rejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*, identity;q=0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding) // wildcard with identity rejected returns supported encoding
|
||||
})
|
||||
|
||||
t.Run("returns error when wildcard rejected and no explicit identity", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0")
|
||||
|
||||
_, err := parseAcceptEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no acceptable encoding found")
|
||||
})
|
||||
|
||||
t.Run("returns identity when wildcard rejected but identity explicitly accepted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0, identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when wildcard rejected but gzip explicitly accepted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0, gzip")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingGzip, encoding)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetDecompressedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns original body when no Content-Encoding header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
content := []byte("test content")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, req.Body, body, "should return original body")
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
})
|
||||
|
||||
t.Run("decompresses gzip body when Content-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
originalContent := []byte("test content to compress")
|
||||
|
||||
var compressed bytes.Buffer
|
||||
gw := gzip.NewWriter(&compressed)
|
||||
_, err := gw.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = gw.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
defer body.Close()
|
||||
|
||||
assert.NotEqual(t, req.Body, body, "should return a new gzip reader")
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, data)
|
||||
})
|
||||
|
||||
t.Run("returns error for invalid gzip data", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
invalidGzip := []byte("this is not gzip data")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(invalidGzip))
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
_, err := getDecompressedBody(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to create gzip reader")
|
||||
})
|
||||
|
||||
t.Run("returns original body for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
content := []byte("test content")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
|
||||
req.Header.Set("Content-Encoding", "identity")
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, req.Body, body, "should return original body")
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
})
|
||||
|
||||
t.Run("returns error for unsupported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
content := []byte("test content")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
|
||||
req.Header.Set("Content-Encoding", "br")
|
||||
|
||||
_, err := getDecompressedBody(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
|
||||
})
|
||||
|
||||
t.Run("handles gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
originalContent := []byte("test content to compress")
|
||||
|
||||
var compressed bytes.Buffer
|
||||
gw := gzip.NewWriter(&compressed)
|
||||
_, err := gw.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = gw.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "gzip;q=1.0")
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
defer body.Close()
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, data)
|
||||
})
|
||||
}
|
||||
29
envd/internal/api/envs.go
Normal file
29
envd/internal/api/envs.go
Normal file
@ -0,0 +1,29 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
)
|
||||
|
||||
func (a *API) GetEnvs(w http.ResponseWriter, _ *http.Request) {
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
a.logger.Debug().Str(string(logs.OperationIDKey), operationID).Msg("Getting env vars")
|
||||
|
||||
envs := make(EnvVars)
|
||||
a.defaults.EnvVars.Range(func(key, value string) bool {
|
||||
envs[key] = value
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(envs); err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("Failed to encode env vars")
|
||||
}
|
||||
}
|
||||
21
envd/internal/api/error.go
Normal file
21
envd/internal/api/error.go
Normal file
@ -0,0 +1,21 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func jsonError(w http.ResponseWriter, code int, err error) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
w.WriteHeader(code)
|
||||
encodeErr := json.NewEncoder(w).Encode(Error{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
})
|
||||
if encodeErr != nil {
|
||||
http.Error(w, errors.Join(encodeErr, err).Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
3
envd/internal/api/generate.go
Normal file
3
envd/internal/api/generate.go
Normal file
@ -0,0 +1,3 @@
|
||||
package api
|
||||
|
||||
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config cfg.yaml ../../spec/envd.yaml
|
||||
314
envd/internal/api/init.go
Normal file
314
envd/internal/api/init.go
Normal file
@ -0,0 +1,314 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/txn2/txeh"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAccessTokenMismatch = errors.New("access token validation failed")
|
||||
ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
|
||||
)
|
||||
|
||||
const (
|
||||
maxTimeInPast = 50 * time.Millisecond
|
||||
maxTimeInFuture = 5 * time.Second
|
||||
)
|
||||
|
||||
// validateInitAccessToken validates the access token for /init requests.
|
||||
// Token is valid if it matches the existing token OR the MMDS hash.
|
||||
// If neither exists, first-time setup is allowed.
|
||||
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
|
||||
requestTokenSet := requestToken.IsSet()
|
||||
|
||||
// Fast path: token matches existing
|
||||
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check MMDS only if token didn't match existing
|
||||
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
|
||||
|
||||
switch {
|
||||
case matchesMMDS:
|
||||
return nil
|
||||
case !a.accessToken.IsSet() && !mmdsExists:
|
||||
return nil // first-time setup
|
||||
case !requestTokenSet:
|
||||
return ErrAccessTokenResetNotAuthorized
|
||||
default:
|
||||
return ErrAccessTokenMismatch
|
||||
}
|
||||
}
|
||||
|
||||
// checkMMDSHash checks if the request token matches the MMDS hash.
|
||||
// Returns (matches, mmdsExists).
|
||||
//
|
||||
// The MMDS hash is set by the orchestrator during Resume:
|
||||
// - hash(token): requires this specific token
|
||||
// - hash(""): explicitly allows nil token (token reset authorized)
|
||||
// - "": MMDS not properly configured, no authorization granted
|
||||
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
|
||||
if a.isNotFC {
|
||||
return false, false
|
||||
}
|
||||
|
||||
mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if mmdsHash == "" {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if !requestToken.IsSet() {
|
||||
return mmdsHash == keys.HashAccessToken(""), true
|
||||
}
|
||||
|
||||
tokenBytes, err := requestToken.Bytes()
|
||||
if err != nil {
|
||||
return false, true
|
||||
}
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
|
||||
}
|
||||
|
||||
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
|
||||
|
||||
if r.Body != nil {
|
||||
// Read raw body so we can wipe it after parsing
|
||||
body, err := io.ReadAll(r.Body)
|
||||
// Ensure body is wiped after we're done
|
||||
defer memguard.WipeBytes(body)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to read request body: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var initRequest PostInitJSONBody
|
||||
if len(body) > 0 {
|
||||
err = json.Unmarshal(body, &initRequest)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to decode request: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure request token is destroyed if not transferred via TakeFrom.
|
||||
// This handles: validation failures, timestamp-based skips, and any early returns.
|
||||
// Safe because Destroy() is nil-safe and TakeFrom clears the source.
|
||||
defer initRequest.AccessToken.Destroy()
|
||||
|
||||
a.initLock.Lock()
|
||||
defer a.initLock.Unlock()
|
||||
|
||||
// Update data only if the request is newer or if there's no timestamp at all
|
||||
if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
|
||||
err = a.SetData(ctx, logger, initRequest)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
default:
|
||||
logger.Error().Msgf("Failed to set data: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
w.Write([]byte(err.Error()))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() { //nolint:contextcheck // TODO: fix this later
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
|
||||
}()
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "")
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
|
||||
// Validate access token before proceeding with any action
|
||||
// The request must provide a token that is either:
|
||||
// 1. Matches the existing access token (if set), OR
|
||||
// 2. Matches the MMDS hash (for token change during resume)
|
||||
if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if data.Timestamp != nil {
|
||||
// Check if current time differs significantly from the received timestamp
|
||||
if shouldSetSystemTime(time.Now(), *data.Timestamp) {
|
||||
logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp)
|
||||
ts := unix.NsecToTimespec(data.Timestamp.UnixNano())
|
||||
err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to set system time: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
if data.EnvVars != nil {
|
||||
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
|
||||
|
||||
for key, value := range *data.EnvVars {
|
||||
logger.Debug().Msgf("Setting env var for %s", key)
|
||||
a.defaults.EnvVars.Store(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if data.AccessToken.IsSet() {
|
||||
logger.Debug().Msg("Setting access token")
|
||||
a.accessToken.TakeFrom(data.AccessToken)
|
||||
} else if a.accessToken.IsSet() {
|
||||
logger.Debug().Msg("Clearing access token")
|
||||
a.accessToken.Destroy()
|
||||
}
|
||||
|
||||
if data.HyperloopIP != nil {
|
||||
go a.SetupHyperloop(*data.HyperloopIP)
|
||||
}
|
||||
|
||||
if data.DefaultUser != nil && *data.DefaultUser != "" {
|
||||
logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
|
||||
a.defaults.User = *data.DefaultUser
|
||||
}
|
||||
|
||||
if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
|
||||
logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
|
||||
a.defaults.Workdir = data.DefaultWorkdir
|
||||
}
|
||||
|
||||
if data.VolumeMounts != nil {
|
||||
for _, volume := range *data.VolumeMounts {
|
||||
logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
|
||||
|
||||
go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
|
||||
commands := [][]string{
|
||||
{"mkdir", "-p", path},
|
||||
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
|
||||
}
|
||||
|
||||
for _, command := range commands {
|
||||
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
|
||||
|
||||
logger := a.getLogger(err)
|
||||
|
||||
logger.
|
||||
Strs("command", command).
|
||||
Str("output", string(data)).
|
||||
Msg("Mount NFS")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) SetupHyperloop(address string) {
|
||||
a.hyperloopLock.Lock()
|
||||
defer a.hyperloopLock.Unlock()
|
||||
|
||||
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
|
||||
a.logger.Error().Err(err).Msg("failed to modify hosts file")
|
||||
} else {
|
||||
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
|
||||
}
|
||||
}
|
||||
|
||||
const eventsHost = "events.wrenn.local"
|
||||
|
||||
func rewriteHostsFile(address, path string) error {
|
||||
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
|
||||
ReadFilePath: path,
|
||||
WriteFilePath: path,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create hosts: %w", err)
|
||||
}
|
||||
|
||||
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
|
||||
// This will remove any existing entries for events.wrenn.local first
|
||||
ipFamily, err := getIPFamily(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get ip family: %w", err)
|
||||
}
|
||||
|
||||
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
|
||||
return nil // nothing to be done
|
||||
}
|
||||
|
||||
hosts.AddHost(address, eventsHost)
|
||||
|
||||
return hosts.Save()
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidAddress = errors.New("invalid IP address")
|
||||
ErrUnknownAddressFormat = errors.New("unknown IP address format")
|
||||
)
|
||||
|
||||
func getIPFamily(address string) (txeh.IPFamily, error) {
|
||||
addressIP, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case addressIP.Is4():
|
||||
return txeh.IPFamilyV4, nil
|
||||
case addressIP.Is6():
|
||||
return txeh.IPFamilyV6, nil
|
||||
default:
|
||||
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp,
|
||||
// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than
|
||||
// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime.
|
||||
func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool {
|
||||
return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture))
|
||||
}
|
||||
587
envd/internal/api/init_test.go
Normal file
587
envd/internal/api/init_test.go
Normal file
@ -0,0 +1,587 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
utilsShared "git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
|
||||
)
|
||||
|
||||
func TestSimpleCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := map[string]func(string) string{
|
||||
"both newlines": func(s string) string { return s },
|
||||
"no newline prefix": func(s string) string { return strings.TrimPrefix(s, "\n") },
|
||||
"no newline suffix": func(s string) string { return strings.TrimSuffix(s, "\n") },
|
||||
"no newline prefix or suffix": strings.TrimSpace,
|
||||
}
|
||||
|
||||
for name, preprocessor := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
value := `
|
||||
# comment
|
||||
127.0.0.1 one.host
|
||||
127.0.0.2 two.host
|
||||
`
|
||||
value = preprocessor(value)
|
||||
inputPath := filepath.Join(tempDir, "hosts")
|
||||
err := os.WriteFile(inputPath, []byte(value), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = rewriteHostsFile("127.0.0.3", inputPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := os.ReadFile(inputPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, `# comment
|
||||
127.0.0.1 one.host
|
||||
127.0.0.2 two.host
|
||||
127.0.0.3 events.wrenn.local`, strings.TrimSpace(string(data)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSetSystemTime(t *testing.T) {
|
||||
t.Parallel()
|
||||
sandboxTime := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hostTime time.Time
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "sandbox time far ahead of host time (should set)",
|
||||
hostTime: sandboxTime.Add(-10 * time.Second),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "sandbox time at maxTimeInPast boundary ahead of host time (should not set)",
|
||||
hostTime: sandboxTime.Add(-50 * time.Millisecond),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time just within maxTimeInPast ahead of host time (should not set)",
|
||||
hostTime: sandboxTime.Add(-40 * time.Millisecond),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time slightly ahead of host time (should not set)",
|
||||
hostTime: sandboxTime.Add(-10 * time.Millisecond),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time equals host time (should not set)",
|
||||
hostTime: sandboxTime,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time slightly behind host time (should not set)",
|
||||
hostTime: sandboxTime.Add(1 * time.Second),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time just within maxTimeInFuture behind host time (should not set)",
|
||||
hostTime: sandboxTime.Add(4 * time.Second),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time at maxTimeInFuture boundary behind host time (should not set)",
|
||||
hostTime: sandboxTime.Add(5 * time.Second),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "sandbox time far behind host time (should set)",
|
||||
hostTime: sandboxTime.Add(1 * time.Minute),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := shouldSetSystemTime(tt.hostTime, sandboxTime)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func secureTokenPtr(s string) *SecureToken {
|
||||
token := &SecureToken{}
|
||||
_ = token.Set([]byte(s))
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
type mockMMDSClient struct {
|
||||
hash string
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockMMDSClient) GetAccessTokenHash(_ context.Context) (string, error) {
|
||||
return m.hash, m.err
|
||||
}
|
||||
|
||||
func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
}
|
||||
api := New(&logger, defaults, nil, false)
|
||||
if accessToken != nil {
|
||||
api.accessToken.TakeFrom(accessToken)
|
||||
}
|
||||
api.mmdsClient = mmdsClient
|
||||
|
||||
return api
|
||||
}
|
||||
|
||||
func TestValidateInitAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken *SecureToken
|
||||
requestToken *SecureToken
|
||||
mmdsHash string
|
||||
mmdsErr error
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "fast path: token matches existing",
|
||||
accessToken: secureTokenPtr("secret-token"),
|
||||
requestToken: secureTokenPtr("secret-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "MMDS match: token hash matches MMDS hash",
|
||||
accessToken: secureTokenPtr("old-token"),
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: keys.HashAccessToken("new-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "first-time setup: no existing token, MMDS error",
|
||||
accessToken: nil,
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "first-time setup: no existing token, empty MMDS hash",
|
||||
accessToken: nil,
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "first-time setup: both tokens nil, no MMDS",
|
||||
accessToken: nil,
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "mismatch: existing token differs from request, no MMDS",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: secureTokenPtr("wrong-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
},
|
||||
{
|
||||
name: "mismatch: existing token differs from request, MMDS hash mismatch",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: secureTokenPtr("wrong-token"),
|
||||
mmdsHash: keys.HashAccessToken("different-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
},
|
||||
{
|
||||
name: "conflict: existing token, nil request, MMDS exists",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: keys.HashAccessToken("some-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
},
|
||||
{
|
||||
name: "conflict: existing token, nil request, no MMDS",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
|
||||
api := newTestAPI(tt.accessToken, mmdsClient)
|
||||
|
||||
err := api.validateInitAccessToken(ctx, tt.requestToken)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckMMDSHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
|
||||
t.Run("returns match when token hash equals MMDS hash", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
token := "my-secret-token"
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(token), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr(token))
|
||||
|
||||
assert.True(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns no match when token hash differs from MMDS hash", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("different-token"), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("my-token"))
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns exists but no match when request token is nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("some-token"), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, nil)
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns false, false when MMDS returns error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns false, false when MMDS returns empty hash with non-nil request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns false, false when MMDS returns empty hash with nil request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, nil)
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns true, true when MMDS returns hash of empty string with nil request (explicit reset)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(""), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, nil)
|
||||
|
||||
assert.True(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetData(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := zerolog.Nop()
|
||||
|
||||
t.Run("access token updates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
existingToken *SecureToken
|
||||
requestToken *SecureToken
|
||||
mmdsHash string
|
||||
mmdsErr error
|
||||
wantErr error
|
||||
wantFinalToken *SecureToken
|
||||
}{
|
||||
{
|
||||
name: "first-time setup: sets initial token",
|
||||
existingToken: nil,
|
||||
requestToken: secureTokenPtr("initial-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
wantFinalToken: secureTokenPtr("initial-token"),
|
||||
},
|
||||
{
|
||||
name: "first-time setup: nil request token leaves token unset",
|
||||
existingToken: nil,
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
wantFinalToken: nil,
|
||||
},
|
||||
{
|
||||
name: "re-init with same token: token unchanged",
|
||||
existingToken: secureTokenPtr("same-token"),
|
||||
requestToken: secureTokenPtr("same-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
wantFinalToken: secureTokenPtr("same-token"),
|
||||
},
|
||||
{
|
||||
name: "resume with MMDS: updates token when hash matches",
|
||||
existingToken: secureTokenPtr("old-token"),
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: keys.HashAccessToken("new-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
wantFinalToken: secureTokenPtr("new-token"),
|
||||
},
|
||||
{
|
||||
name: "resume with MMDS: fails when hash doesn't match",
|
||||
existingToken: secureTokenPtr("old-token"),
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: keys.HashAccessToken("different-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
wantFinalToken: secureTokenPtr("old-token"),
|
||||
},
|
||||
{
|
||||
name: "fails when existing token and request token mismatch without MMDS",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: secureTokenPtr("wrong-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "conflict when existing token but nil request token",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "conflict when existing token but nil request with MMDS present",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: keys.HashAccessToken("some-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "conflict when MMDS returns empty hash and request is nil (prevents unauthorized reset)",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "resets token when MMDS returns hash of empty string and request is nil (explicit reset)",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: keys.HashAccessToken(""),
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
wantFinalToken: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
|
||||
api := newTestAPI(tt.existingToken, mmdsClient)
|
||||
|
||||
data := PostInitJSONBody{
|
||||
AccessToken: tt.requestToken,
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if tt.wantFinalToken == nil {
|
||||
assert.False(t, api.accessToken.IsSet(), "expected token to not be set")
|
||||
} else {
|
||||
require.True(t, api.accessToken.IsSet(), "expected token to be set")
|
||||
assert.True(t, api.accessToken.EqualsSecure(tt.wantFinalToken), "expected token to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets environment variables", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
envVars := EnvVars{"FOO": "bar", "BAZ": "qux"}
|
||||
data := PostInitJSONBody{
|
||||
EnvVars: &envVars,
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
val, ok := api.defaults.EnvVars.Load("FOO")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "bar", val)
|
||||
val, ok = api.defaults.EnvVars.Load("BAZ")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "qux", val)
|
||||
})
|
||||
|
||||
t.Run("sets default user", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultUser: utilsShared.ToPtr("testuser"),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testuser", api.defaults.User)
|
||||
})
|
||||
|
||||
t.Run("does not set default user when empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
api.defaults.User = "original"
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultUser: utilsShared.ToPtr(""),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "original", api.defaults.User)
|
||||
})
|
||||
|
||||
t.Run("sets default workdir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultWorkdir: utilsShared.ToPtr("/home/user"),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, api.defaults.Workdir)
|
||||
assert.Equal(t, "/home/user", *api.defaults.Workdir)
|
||||
})
|
||||
|
||||
t.Run("does not set default workdir when empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
originalWorkdir := "/original"
|
||||
api.defaults.Workdir = &originalWorkdir
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultWorkdir: utilsShared.ToPtr(""),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, api.defaults.Workdir)
|
||||
assert.Equal(t, "/original", *api.defaults.Workdir)
|
||||
})
|
||||
|
||||
t.Run("sets multiple fields at once", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
envVars := EnvVars{"KEY": "value"}
|
||||
data := PostInitJSONBody{
|
||||
AccessToken: secureTokenPtr("token"),
|
||||
DefaultUser: utilsShared.ToPtr("user"),
|
||||
DefaultWorkdir: utilsShared.ToPtr("/workdir"),
|
||||
EnvVars: &envVars,
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, api.accessToken.Equals("token"), "expected token to match")
|
||||
assert.Equal(t, "user", api.defaults.User)
|
||||
assert.Equal(t, "/workdir", *api.defaults.Workdir)
|
||||
val, ok := api.defaults.EnvVars.Load("KEY")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value", val)
|
||||
})
|
||||
}
|
||||
212
envd/internal/api/secure_token.go
Normal file
212
envd/internal/api/secure_token.go
Normal file
@ -0,0 +1,212 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenNotSet = errors.New("access token not set")
|
||||
ErrTokenEmpty = errors.New("empty token not allowed")
|
||||
)
|
||||
|
||||
// SecureToken wraps memguard for secure token storage.
|
||||
// It uses LockedBuffer which provides memory locking, guard pages,
|
||||
// and secure zeroing on destroy.
|
||||
type SecureToken struct {
|
||||
mu sync.RWMutex
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// Set securely replaces the token, destroying the old one first.
|
||||
// The old token memory is zeroed before the new token is stored.
|
||||
// The input byte slice is wiped after copying to secure memory.
|
||||
// Returns ErrTokenEmpty if token is empty - use Destroy() to clear the token instead.
|
||||
func (s *SecureToken) Set(token []byte) error {
|
||||
if len(token) == 0 {
|
||||
return ErrTokenEmpty
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Destroy old token first (zeros memory)
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
|
||||
// Create new LockedBuffer from bytes (source slice is wiped by memguard)
|
||||
s.buffer = memguard.NewBufferFromBytes(token)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler to securely parse a JSON string
|
||||
// directly into memguard, wiping the input bytes after copying.
|
||||
//
|
||||
// Access tokens are hex-encoded HMAC-SHA256 hashes (64 chars of [0-9a-f]),
|
||||
// so they never contain JSON escape sequences.
|
||||
func (s *SecureToken) UnmarshalJSON(data []byte) error {
|
||||
// JSON strings are quoted, so minimum valid is `""` (2 bytes).
|
||||
if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return errors.New("invalid secure token JSON string")
|
||||
}
|
||||
|
||||
content := data[1 : len(data)-1]
|
||||
|
||||
// Access tokens are hex strings - reject if contains backslash
|
||||
if bytes.ContainsRune(content, '\\') {
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return errors.New("invalid secure token: unexpected escape sequence")
|
||||
}
|
||||
|
||||
if len(content) == 0 {
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return ErrTokenEmpty
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
|
||||
// Allocate secure buffer and copy directly into it
|
||||
s.buffer = memguard.NewBuffer(len(content))
|
||||
copy(s.buffer.Bytes(), content)
|
||||
|
||||
// Wipe the input data
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TakeFrom transfers the token from src to this SecureToken, destroying any
|
||||
// existing token. The source token is cleared after transfer.
|
||||
// This avoids copying the underlying bytes.
|
||||
func (s *SecureToken) TakeFrom(src *SecureToken) {
|
||||
if src == nil || s == src {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract buffer from source
|
||||
src.mu.Lock()
|
||||
buffer := src.buffer
|
||||
src.buffer = nil
|
||||
src.mu.Unlock()
|
||||
|
||||
// Install buffer in destination
|
||||
s.mu.Lock()
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
}
|
||||
s.buffer = buffer
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Equals checks if token matches using constant-time comparison.
|
||||
// Returns false if the receiver is nil.
|
||||
func (s *SecureToken) Equals(token string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return false
|
||||
}
|
||||
|
||||
return s.buffer.EqualTo([]byte(token))
|
||||
}
|
||||
|
||||
// EqualsSecure compares this token with another SecureToken using constant-time comparison.
|
||||
// Returns false if either receiver or other is nil.
|
||||
func (s *SecureToken) EqualsSecure(other *SecureToken) bool {
|
||||
if s == nil || other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s == other {
|
||||
return s.IsSet()
|
||||
}
|
||||
|
||||
// Get a copy of other's bytes (avoids holding two locks simultaneously)
|
||||
otherBytes, err := other.Bytes()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer memguard.WipeBytes(otherBytes)
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return false
|
||||
}
|
||||
|
||||
return s.buffer.EqualTo(otherBytes)
|
||||
}
|
||||
|
||||
// IsSet returns true if a token is stored.
|
||||
// Returns false if the receiver is nil.
|
||||
func (s *SecureToken) IsSet() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.buffer != nil && s.buffer.IsAlive()
|
||||
}
|
||||
|
||||
// Bytes returns a copy of the token bytes (for signature generation).
|
||||
// The caller should zero the returned slice after use.
|
||||
// Returns ErrTokenNotSet if the receiver is nil.
|
||||
func (s *SecureToken) Bytes() ([]byte, error) {
|
||||
if s == nil {
|
||||
return nil, ErrTokenNotSet
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return nil, ErrTokenNotSet
|
||||
}
|
||||
|
||||
// Return a copy (unavoidable for signature generation)
|
||||
src := s.buffer.Bytes()
|
||||
result := make([]byte, len(src))
|
||||
copy(result, src)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Destroy securely wipes the token from memory.
|
||||
// No-op if the receiver is nil.
|
||||
func (s *SecureToken) Destroy() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
}
|
||||
461
envd/internal/api/secure_token_test.go
Normal file
461
envd/internal/api/secure_token_test.go
Normal file
@ -0,0 +1,461 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSecureTokenSetAndEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Initially not set
|
||||
assert.False(t, st.IsSet(), "token should not be set initially")
|
||||
assert.False(t, st.Equals("any-token"), "equals should return false when not set")
|
||||
|
||||
// Set token
|
||||
err := st.Set([]byte("test-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet(), "token should be set after Set()")
|
||||
assert.True(t, st.Equals("test-token"), "equals should return true for correct token")
|
||||
assert.False(t, st.Equals("wrong-token"), "equals should return false for wrong token")
|
||||
assert.False(t, st.Equals(""), "equals should return false for empty token")
|
||||
}
|
||||
|
||||
func TestSecureTokenReplace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Set initial token
|
||||
err := st.Set([]byte("first-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.Equals("first-token"))
|
||||
|
||||
// Replace with new token (old one should be destroyed)
|
||||
err = st.Set([]byte("second-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.Equals("second-token"), "should match new token")
|
||||
assert.False(t, st.Equals("first-token"), "should not match old token")
|
||||
}
|
||||
|
||||
func TestSecureTokenDestroy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Set and then destroy
|
||||
err := st.Set([]byte("test-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet())
|
||||
|
||||
st.Destroy()
|
||||
assert.False(t, st.IsSet(), "token should not be set after Destroy()")
|
||||
assert.False(t, st.Equals("test-token"), "equals should return false after Destroy()")
|
||||
|
||||
// Destroy on already destroyed should be safe
|
||||
st.Destroy()
|
||||
assert.False(t, st.IsSet())
|
||||
|
||||
// Nil receiver should be safe
|
||||
var nilToken *SecureToken
|
||||
assert.False(t, nilToken.IsSet(), "nil receiver should return false for IsSet()")
|
||||
assert.False(t, nilToken.Equals("anything"), "nil receiver should return false for Equals()")
|
||||
assert.False(t, nilToken.EqualsSecure(st), "nil receiver should return false for EqualsSecure()")
|
||||
nilToken.Destroy() // should not panic
|
||||
|
||||
_, err = nilToken.Bytes()
|
||||
require.ErrorIs(t, err, ErrTokenNotSet, "nil receiver should return ErrTokenNotSet for Bytes()")
|
||||
}
|
||||
|
||||
func TestSecureTokenBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Bytes should return error when not set
|
||||
_, err := st.Bytes()
|
||||
require.ErrorIs(t, err, ErrTokenNotSet)
|
||||
|
||||
// Set token and get bytes
|
||||
err = st.Set([]byte("test-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
bytes, err := st.Bytes()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("test-token"), bytes)
|
||||
|
||||
// Zero out the bytes (as caller should do)
|
||||
memguard.WipeBytes(bytes)
|
||||
|
||||
// Original should still be intact
|
||||
assert.True(t, st.Equals("test-token"), "original token should still work after zeroing copy")
|
||||
|
||||
// After destroy, bytes should fail
|
||||
st.Destroy()
|
||||
_, err = st.Bytes()
|
||||
assert.ErrorIs(t, err, ErrTokenNotSet)
|
||||
}
|
||||
|
||||
func TestSecureTokenConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("initial-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 100
|
||||
|
||||
// Concurrent reads
|
||||
for range numGoroutines {
|
||||
wg.Go(func() {
|
||||
st.IsSet()
|
||||
st.Equals("initial-token")
|
||||
})
|
||||
}
|
||||
|
||||
// Concurrent writes
|
||||
for i := range 10 {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
st.Set([]byte("token-" + string(rune('a'+idx))))
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should still be in a valid state
|
||||
assert.True(t, st.IsSet())
|
||||
}
|
||||
|
||||
func TestSecureTokenEmptyToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Setting empty token should return an error
|
||||
err := st.Set([]byte{})
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.False(t, st.IsSet(), "token should not be set after empty token error")
|
||||
|
||||
// Setting nil should also return an error
|
||||
err = st.Set(nil)
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.False(t, st.IsSet(), "token should not be set after nil token error")
|
||||
}
|
||||
|
||||
func TestSecureTokenEmptyTokenDoesNotClearExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Set a valid token first
|
||||
err := st.Set([]byte("valid-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet())
|
||||
|
||||
// Attempting to set empty token should fail and preserve existing token
|
||||
err = st.Set([]byte{})
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.True(t, st.IsSet(), "existing token should be preserved after empty token error")
|
||||
assert.True(t, st.Equals("valid-token"), "existing token value should be unchanged")
|
||||
}
|
||||
|
||||
func TestSecureTokenUnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("unmarshals valid JSON string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`"my-secret-token"`))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet())
|
||||
assert.True(t, st.Equals("my-secret-token"))
|
||||
})
|
||||
|
||||
t.Run("returns error for empty string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`""`))
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.False(t, st.IsSet())
|
||||
})
|
||||
|
||||
t.Run("returns error for invalid JSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`not-valid-json`))
|
||||
require.Error(t, err)
|
||||
assert.False(t, st.IsSet())
|
||||
})
|
||||
|
||||
t.Run("replaces existing token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("old-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = st.UnmarshalJSON([]byte(`"new-token"`))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.Equals("new-token"))
|
||||
assert.False(t, st.Equals("old-token"))
|
||||
})
|
||||
|
||||
t.Run("wipes input buffer after parsing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a buffer with a known token
|
||||
input := []byte(`"secret-token-12345"`)
|
||||
original := make([]byte, len(input))
|
||||
copy(original, input)
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the token was stored correctly
|
||||
assert.True(t, st.Equals("secret-token-12345"))
|
||||
|
||||
// Verify the input buffer was wiped (all zeros)
|
||||
for i, b := range input {
|
||||
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wipes input buffer on error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a buffer with an empty token (will error)
|
||||
input := []byte(`""`)
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON(input)
|
||||
require.Error(t, err)
|
||||
|
||||
// Verify the input buffer was still wiped
|
||||
for i, b := range input {
|
||||
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects escape sequences", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`"token\nwith\nnewlines"`))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "escape sequence")
|
||||
assert.False(t, st.IsSet())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureTokenSetWipesInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("wipes input buffer after storing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a buffer with a known token
|
||||
input := []byte("my-secret-token")
|
||||
original := make([]byte, len(input))
|
||||
copy(original, input)
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.Set(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the token was stored correctly
|
||||
assert.True(t, st.Equals("my-secret-token"))
|
||||
|
||||
// Verify the input buffer was wiped (all zeros)
|
||||
for i, b := range input {
|
||||
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureTokenTakeFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("transfers token from source to destination", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := &SecureToken{}
|
||||
err := src.Set([]byte("source-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst := &SecureToken{}
|
||||
dst.TakeFrom(src)
|
||||
|
||||
assert.True(t, dst.IsSet())
|
||||
assert.True(t, dst.Equals("source-token"))
|
||||
assert.False(t, src.IsSet(), "source should be empty after transfer")
|
||||
})
|
||||
|
||||
t.Run("replaces existing destination token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := &SecureToken{}
|
||||
err := src.Set([]byte("new-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst := &SecureToken{}
|
||||
err = dst.Set([]byte("old-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst.TakeFrom(src)
|
||||
|
||||
assert.True(t, dst.Equals("new-token"))
|
||||
assert.False(t, dst.Equals("old-token"))
|
||||
assert.False(t, src.IsSet())
|
||||
})
|
||||
|
||||
t.Run("handles nil source", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dst := &SecureToken{}
|
||||
err := dst.Set([]byte("existing-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst.TakeFrom(nil)
|
||||
|
||||
assert.True(t, dst.IsSet(), "destination should be unchanged with nil source")
|
||||
assert.True(t, dst.Equals("existing-token"))
|
||||
})
|
||||
|
||||
t.Run("handles empty source", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := &SecureToken{}
|
||||
dst := &SecureToken{}
|
||||
err := dst.Set([]byte("existing-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst.TakeFrom(src)
|
||||
|
||||
assert.False(t, dst.IsSet(), "destination should be cleared when source is empty")
|
||||
})
|
||||
|
||||
t.Run("self-transfer is no-op and does not deadlock", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st.TakeFrom(st)
|
||||
|
||||
assert.True(t, st.IsSet(), "token should remain set after self-transfer")
|
||||
assert.True(t, st.Equals("token"), "token value should be unchanged")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureTokenEqualsSecure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns true for matching tokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
err := st1.Set([]byte("same-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st2 := &SecureToken{}
|
||||
err = st2.Set([]byte("same-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, st1.EqualsSecure(st2))
|
||||
assert.True(t, st2.EqualsSecure(st1))
|
||||
})
|
||||
|
||||
t.Run("concurrent TakeFrom and EqualsSecure do not deadlock", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// This test verifies the fix for the lock ordering deadlock bug.
|
||||
|
||||
const iterations = 100
|
||||
|
||||
for range iterations {
|
||||
a := &SecureToken{}
|
||||
err := a.Set([]byte("token-a"))
|
||||
require.NoError(t, err)
|
||||
|
||||
b := &SecureToken{}
|
||||
err = b.Set([]byte("token-b"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Goroutine 1: a.TakeFrom(b)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a.TakeFrom(b)
|
||||
}()
|
||||
|
||||
// Goroutine 2: b.EqualsSecure(a)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b.EqualsSecure(a)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false for different tokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
err := st1.Set([]byte("token-a"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st2 := &SecureToken{}
|
||||
err = st2.Set([]byte("token-b"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, st1.EqualsSecure(st2))
|
||||
})
|
||||
|
||||
t.Run("returns false when comparing with nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, st.EqualsSecure(nil))
|
||||
})
|
||||
|
||||
t.Run("returns false when other is not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
err := st1.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st2 := &SecureToken{}
|
||||
|
||||
assert.False(t, st1.EqualsSecure(st2))
|
||||
})
|
||||
|
||||
t.Run("returns false when self is not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
|
||||
st2 := &SecureToken{}
|
||||
err := st2.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, st1.EqualsSecure(st2))
|
||||
})
|
||||
|
||||
t.Run("self-comparison returns true when set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, st.EqualsSecure(st), "self-comparison should return true and not deadlock")
|
||||
})
|
||||
|
||||
t.Run("self-comparison returns false when not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
|
||||
assert.False(t, st.EqualsSecure(st), "self-comparison on unset token should return false")
|
||||
})
|
||||
}
|
||||
93
envd/internal/api/store.go
Normal file
93
envd/internal/api/store.go
Normal file
@ -0,0 +1,93 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
// MMDSClient provides access to MMDS metadata.
|
||||
type MMDSClient interface {
|
||||
GetAccessTokenHash(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
// DefaultMMDSClient is the production implementation that calls the real MMDS endpoint.
|
||||
type DefaultMMDSClient struct{}
|
||||
|
||||
func (c *DefaultMMDSClient) GetAccessTokenHash(ctx context.Context) (string, error) {
|
||||
return host.GetAccessTokenHashFromMMDS(ctx)
|
||||
}
|
||||
|
||||
type API struct {
|
||||
isNotFC bool
|
||||
logger *zerolog.Logger
|
||||
accessToken *SecureToken
|
||||
defaults *execcontext.Defaults
|
||||
|
||||
mmdsChan chan *host.MMDSOpts
|
||||
hyperloopLock sync.Mutex
|
||||
mmdsClient MMDSClient
|
||||
|
||||
lastSetTime *utils.AtomicMax
|
||||
initLock sync.Mutex
|
||||
}
|
||||
|
||||
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool) *API {
|
||||
return &API{
|
||||
logger: l,
|
||||
defaults: defaults,
|
||||
mmdsChan: mmdsChan,
|
||||
isNotFC: isNotFC,
|
||||
mmdsClient: &DefaultMMDSClient{},
|
||||
lastSetTime: utils.NewAtomicMax(),
|
||||
accessToken: &SecureToken{},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
a.logger.Trace().Msg("Health check")
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "")
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *API) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
a.logger.Trace().Msg("Get metrics")
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
metrics, err := host.GetMetrics()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to get metrics")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(metrics); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to encode metrics")
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) getLogger(err error) *zerolog.Event {
|
||||
if err != nil {
|
||||
return a.logger.Error().Err(err) //nolint:zerologlint // this is only prep
|
||||
}
|
||||
|
||||
return a.logger.Info() //nolint:zerologlint // this is only prep
|
||||
}
|
||||
309
envd/internal/api/upload.go
Normal file
309
envd/internal/api/upload.go
Normal file
@ -0,0 +1,309 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
var ErrNoDiskSpace = fmt.Errorf("not enough disk space available")
|
||||
|
||||
func processFile(r *http.Request, path string, part io.Reader, uid, gid int, logger zerolog.Logger) (int, error) {
|
||||
logger.Debug().
|
||||
Str("path", path).
|
||||
Msg("File processing")
|
||||
|
||||
err := permissions.EnsureDirs(filepath.Dir(path), uid, gid)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error ensuring directories: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
canBePreChowned := false
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
errMsg := fmt.Errorf("error getting file info: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, errMsg
|
||||
} else if err == nil {
|
||||
if stat.IsDir() {
|
||||
err := fmt.Errorf("path is a directory: %s", path)
|
||||
|
||||
return http.StatusBadRequest, err
|
||||
}
|
||||
canBePreChowned = true
|
||||
}
|
||||
|
||||
hasBeenChowned := false
|
||||
if canBePreChowned {
|
||||
err = os.Chown(path, uid, gid)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
err = fmt.Errorf("error changing file ownership: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
} else {
|
||||
hasBeenChowned = true
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o666)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.ENOSPC) {
|
||||
err = fmt.Errorf("not enough inodes available: %w", err)
|
||||
|
||||
return http.StatusInsufficientStorage, err
|
||||
}
|
||||
|
||||
err := fmt.Errorf("error opening file: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
defer file.Close()
|
||||
|
||||
if !hasBeenChowned {
|
||||
err = os.Chown(path, uid, gid)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error changing file ownership: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = file.ReadFrom(part)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.ENOSPC) {
|
||||
err = ErrNoDiskSpace
|
||||
if r.ContentLength > 0 {
|
||||
err = fmt.Errorf("attempted to write %d bytes: %w", r.ContentLength, err)
|
||||
}
|
||||
|
||||
return http.StatusInsufficientStorage, err
|
||||
}
|
||||
|
||||
err = fmt.Errorf("error writing file: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
return http.StatusNoContent, nil
|
||||
}
|
||||
|
||||
func resolvePath(part *multipart.Part, paths *UploadSuccess, u *user.User, defaultPath *string, params PostFilesParams) (string, error) {
|
||||
var pathToResolve string
|
||||
|
||||
if params.Path != nil {
|
||||
pathToResolve = *params.Path
|
||||
} else {
|
||||
var err error
|
||||
customPart := utils.NewCustomPart(part)
|
||||
pathToResolve, err = customPart.FileNameWithPath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting multipart custom part file name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
filePath, err := permissions.ExpandAndResolve(pathToResolve, u, defaultPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error resolving path: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range *paths {
|
||||
if entry.Path == filePath {
|
||||
var alreadyUploaded []string
|
||||
for _, uploadedFile := range *paths {
|
||||
if uploadedFile.Path != filePath {
|
||||
alreadyUploaded = append(alreadyUploaded, uploadedFile.Path)
|
||||
}
|
||||
}
|
||||
|
||||
errMsg := fmt.Errorf("you cannot upload multiple files to the same path '%s' in one upload request, only the first specified file was uploaded", filePath)
|
||||
|
||||
if len(alreadyUploaded) > 1 {
|
||||
errMsg = fmt.Errorf("%w, also the following files were uploaded: %v", errMsg, strings.Join(alreadyUploaded, ", "))
|
||||
}
|
||||
|
||||
return "", errMsg
|
||||
}
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func (a *API) handlePart(r *http.Request, part *multipart.Part, paths UploadSuccess, u *user.User, uid, gid int, operationID string, params PostFilesParams) (*EntryInfo, int, error) {
|
||||
defer part.Close()
|
||||
|
||||
if part.FormName() != "file" {
|
||||
return nil, http.StatusOK, nil
|
||||
}
|
||||
|
||||
filePath, err := resolvePath(part, &paths, u, a.defaults.Workdir, params)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, err
|
||||
}
|
||||
|
||||
logger := a.logger.
|
||||
With().
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("event_type", "file_processing").
|
||||
Logger()
|
||||
|
||||
status, err := processFile(r, filePath, part, uid, gid, logger)
|
||||
if err != nil {
|
||||
return nil, status, err
|
||||
}
|
||||
|
||||
return &EntryInfo{
|
||||
Path: filePath,
|
||||
Name: filepath.Base(filePath),
|
||||
Type: File,
|
||||
}, http.StatusOK, nil
|
||||
}
|
||||
|
||||
func (a *API) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
|
||||
// Capture original body to ensure it's always closed
|
||||
originalBody := r.Body
|
||||
defer originalBody.Close()
|
||||
|
||||
var errorCode int
|
||||
var errMsg error
|
||||
|
||||
var path string
|
||||
if params.Path != nil {
|
||||
path = *params.Path
|
||||
}
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
// signing authorization if needed
|
||||
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningWriteOperation)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
|
||||
jsonError(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
l := a.logger.
|
||||
Err(errMsg).
|
||||
Str("method", r.Method+" "+r.URL.Path).
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("path", path).
|
||||
Str("username", username)
|
||||
|
||||
if errMsg != nil {
|
||||
l = l.Int("error_code", errorCode)
|
||||
}
|
||||
|
||||
l.Msg("File write")
|
||||
}()
|
||||
|
||||
// Handle gzip-encoded request body
|
||||
body, err := getDecompressedBody(r)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error decompressing request body: %w", err)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
r.Body = body
|
||||
|
||||
f, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error parsing multipart form: %w", err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||
errorCode = http.StatusUnauthorized
|
||||
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
uid, gid, err := permissions.GetUserIdInts(u)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error getting user ids: %w", err)
|
||||
|
||||
jsonError(w, http.StatusInternalServerError, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
paths := UploadSuccess{}
|
||||
|
||||
for {
|
||||
part, partErr := f.NextPart()
|
||||
|
||||
if partErr == io.EOF {
|
||||
// We're done reading the parts.
|
||||
break
|
||||
} else if partErr != nil {
|
||||
errMsg = fmt.Errorf("error reading form: %w", partErr)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
entry, status, err := a.handlePart(r, part, paths, u, uid, gid, operationID, params)
|
||||
if err != nil {
|
||||
errorCode = status
|
||||
errMsg = err
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if entry != nil {
|
||||
paths = append(paths, *entry)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(paths)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error marshaling response: %w", err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
249
envd/internal/api/upload_test.go
Normal file
249
envd/internal/api/upload_test.go
Normal file
@ -0,0 +1,249 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProcessFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
uid := os.Getuid()
|
||||
gid := os.Getgid()
|
||||
|
||||
newRequest := func(content []byte) (*http.Request, io.Reader) {
|
||||
request := &http.Request{
|
||||
ContentLength: int64(len(content)),
|
||||
}
|
||||
buffer := bytes.NewBuffer(content)
|
||||
|
||||
return request, buffer
|
||||
}
|
||||
|
||||
var emptyReq http.Request
|
||||
var emptyPart *bytes.Buffer
|
||||
var emptyLogger zerolog.Logger
|
||||
|
||||
t.Run("failed to ensure directories", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
httpStatus, err := processFile(&emptyReq, "/proc/invalid/not-real", emptyPart, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, httpStatus)
|
||||
assert.ErrorContains(t, err, "error ensuring directories: ")
|
||||
})
|
||||
|
||||
t.Run("attempt to replace directory with a file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
httpStatus, err := processFile(&emptyReq, tempDir, emptyPart, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, httpStatus, err.Error())
|
||||
assert.ErrorContains(t, err, "path is a directory: ")
|
||||
})
|
||||
|
||||
t.Run("fail to create file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
httpStatus, err := processFile(&emptyReq, "/proc/invalid-filename", emptyPart, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, httpStatus)
|
||||
assert.ErrorContains(t, err, "error opening file: ")
|
||||
})
|
||||
|
||||
t.Run("out of disk space", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
mountSize := 1024
|
||||
tempDir := createTmpfsMount(t, mountSize)
|
||||
|
||||
// create test file
|
||||
firstFileSize := mountSize / 2
|
||||
tempFile1 := filepath.Join(tempDir, "test-file-1")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(),
|
||||
"dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", firstFileSize), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a new file that would fill up the
|
||||
secondFileContents := make([]byte, mountSize*2)
|
||||
for index := range secondFileContents {
|
||||
secondFileContents[index] = 'a'
|
||||
}
|
||||
|
||||
// try to replace it
|
||||
request, buffer := newRequest(secondFileContents)
|
||||
tempFile2 := filepath.Join(tempDir, "test-file-2")
|
||||
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
|
||||
assert.ErrorContains(t, err, "attempted to write 2048 bytes: not enough disk space")
|
||||
})
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test-file")
|
||||
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
|
||||
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
|
||||
data, err := os.ReadFile(tempFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
})
|
||||
|
||||
t.Run("overwrite file on full disk", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
sizeInBytes := 1024
|
||||
tempDir := createTmpfsMount(t, 1024)
|
||||
|
||||
// create test file
|
||||
tempFile := filepath.Join(tempDir, "test-file")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// try to replace it
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
})
|
||||
|
||||
t.Run("write new file on full disk", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
sizeInBytes := 1024
|
||||
tempDir := createTmpfsMount(t, 1024)
|
||||
|
||||
// create test file
|
||||
tempFile1 := filepath.Join(tempDir, "test-file")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// try to write a new file
|
||||
tempFile2 := filepath.Join(tempDir, "test-file-2")
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
|
||||
require.ErrorContains(t, err, "not enough disk space available")
|
||||
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
|
||||
})
|
||||
|
||||
t.Run("write new file with no inodes available", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
tempDir := createTmpfsMountWithInodes(t, 1024, 2)
|
||||
|
||||
// create test file
|
||||
tempFile1 := filepath.Join(tempDir, "test-file")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", 100), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// try to write a new file
|
||||
tempFile2 := filepath.Join(tempDir, "test-file-2")
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
|
||||
require.ErrorContains(t, err, "not enough inodes available")
|
||||
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
|
||||
})
|
||||
|
||||
t.Run("update sysfs or other virtual fs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
|
||||
}
|
||||
|
||||
filePath := "/sys/fs/cgroup/user.slice/cpu.weight"
|
||||
newContent := []byte("102\n")
|
||||
request, buffer := newRequest(newContent)
|
||||
|
||||
httpStatus, err := processFile(request, filePath, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newContent, data)
|
||||
})
|
||||
|
||||
t.Run("replace file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test-file")
|
||||
|
||||
err := os.WriteFile(tempFile, []byte("old-contents"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
newContent := []byte("new-file-contents")
|
||||
request, buffer := newRequest(newContent)
|
||||
|
||||
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
|
||||
data, err := os.ReadFile(tempFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newContent, data)
|
||||
})
|
||||
}
|
||||
|
||||
func createTmpfsMount(t *testing.T, sizeInBytes int) string {
|
||||
t.Helper()
|
||||
|
||||
return createTmpfsMountWithInodes(t, sizeInBytes, 5)
|
||||
}
|
||||
|
||||
func createTmpfsMountWithInodes(t *testing.T, sizeInBytes, inodesCount int) string {
|
||||
t.Helper()
|
||||
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
|
||||
}
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cmd := exec.CommandContext(t.Context(),
|
||||
"mount",
|
||||
"tmpfs",
|
||||
tempDir,
|
||||
"-t", "tmpfs",
|
||||
"-o", fmt.Sprintf("size=%d,nr_inodes=%d", sizeInBytes, inodesCount))
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
ctx := context.WithoutCancel(t.Context())
|
||||
cmd := exec.CommandContext(ctx, "umount", tempDir)
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return tempDir
|
||||
}
|
||||
Reference in New Issue
Block a user