Port envd from e2b with internalized shared packages and Connect RPC
- Copy envd source from e2b-dev/infra, internalize shared dependencies
into envd/internal/shared/ (keys, filesystem, id, smap, utils)
- Switch from gRPC to Connect RPC for all envd services
- Update module paths to git.omukk.dev/wrenn/{sandbox,sandbox/envd}
- Add proto specs (process, filesystem) with buf-based code generation
- Implement full envd: process exec, filesystem ops, port forwarding,
cgroup management, MMDS integration, and HTTP API
- Update main module dependencies (firecracker SDK, pgx, goose, etc.)
- Remove placeholder .gitkeep files replaced by real implementations
This commit is contained in:
108
envd/internal/shared/filesystem/entry.go
Normal file
108
envd/internal/shared/filesystem/entry.go
Normal file
@ -0,0 +1,108 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetEntryFromPath(path string) (EntryInfo, error) {
|
||||
fileInfo, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return EntryInfo{}, err
|
||||
}
|
||||
|
||||
return GetEntryInfo(path, fileInfo), nil
|
||||
}
|
||||
|
||||
func GetEntryInfo(path string, fileInfo os.FileInfo) EntryInfo {
|
||||
fileMode := fileInfo.Mode()
|
||||
|
||||
var symlinkTarget *string
|
||||
if fileMode&os.ModeSymlink != 0 {
|
||||
// If we can't resolve the symlink target, we won't set the target
|
||||
target := followSymlink(path)
|
||||
symlinkTarget = &target
|
||||
}
|
||||
|
||||
var entryType FileType
|
||||
var mode os.FileMode
|
||||
|
||||
if symlinkTarget == nil {
|
||||
entryType = getEntryType(fileMode)
|
||||
mode = fileMode.Perm()
|
||||
} else {
|
||||
// If it's a symlink, we need to determine the type of the target
|
||||
targetInfo, err := os.Stat(*symlinkTarget)
|
||||
if err != nil {
|
||||
entryType = UnknownFileType
|
||||
} else {
|
||||
entryType = getEntryType(targetInfo.Mode())
|
||||
mode = targetInfo.Mode().Perm()
|
||||
}
|
||||
}
|
||||
|
||||
entry := EntryInfo{
|
||||
Name: fileInfo.Name(),
|
||||
Path: path,
|
||||
Type: entryType,
|
||||
Size: fileInfo.Size(),
|
||||
Mode: mode,
|
||||
Permissions: fileMode.String(),
|
||||
ModifiedTime: fileInfo.ModTime(),
|
||||
SymlinkTarget: symlinkTarget,
|
||||
}
|
||||
|
||||
if base := getBase(fileInfo.Sys()); base != nil {
|
||||
entry.AccessedTime = toTimestamp(base.Atim)
|
||||
entry.CreatedTime = toTimestamp(base.Ctim)
|
||||
entry.ModifiedTime = toTimestamp(base.Mtim)
|
||||
entry.UID = base.Uid
|
||||
entry.GID = base.Gid
|
||||
} else if !fileInfo.ModTime().IsZero() {
|
||||
entry.ModifiedTime = fileInfo.ModTime()
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
// getEntryType determines the type of file entry based on its mode and path.
|
||||
// If the file is a symlink, it follows the symlink to determine the actual type.
|
||||
func getEntryType(mode os.FileMode) FileType {
|
||||
switch {
|
||||
case mode.IsRegular():
|
||||
return FileFileType
|
||||
case mode.IsDir():
|
||||
return DirectoryFileType
|
||||
case mode&os.ModeSymlink == os.ModeSymlink:
|
||||
return SymlinkFileType
|
||||
default:
|
||||
return UnknownFileType
|
||||
}
|
||||
}
|
||||
|
||||
// followSymlink resolves a symbolic link to its target path.
|
||||
func followSymlink(path string) string {
|
||||
// Resolve symlinks
|
||||
resolvedPath, err := filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return path
|
||||
}
|
||||
|
||||
return resolvedPath
|
||||
}
|
||||
|
||||
func toTimestamp(spec syscall.Timespec) time.Time {
|
||||
if spec.Sec == 0 && spec.Nsec == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
return time.Unix(spec.Sec, spec.Nsec)
|
||||
}
|
||||
|
||||
func getBase(sys any) *syscall.Stat_t {
|
||||
st, _ := sys.(*syscall.Stat_t)
|
||||
|
||||
return st
|
||||
}
|
||||
264
envd/internal/shared/filesystem/entry_test.go
Normal file
264
envd/internal/shared/filesystem/entry_test.go
Normal file
@ -0,0 +1,264 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetEntryType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create test files
|
||||
regularFile := filepath.Join(tempDir, "regular.txt")
|
||||
require.NoError(t, os.WriteFile(regularFile, []byte("test content"), 0o644))
|
||||
|
||||
testDir := filepath.Join(tempDir, "testdir")
|
||||
require.NoError(t, os.MkdirAll(testDir, 0o755))
|
||||
|
||||
symlink := filepath.Join(tempDir, "symlink")
|
||||
require.NoError(t, os.Symlink(regularFile, symlink))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected FileType
|
||||
}{
|
||||
{
|
||||
name: "regular file",
|
||||
path: regularFile,
|
||||
expected: FileFileType,
|
||||
},
|
||||
{
|
||||
name: "directory",
|
||||
path: testDir,
|
||||
expected: DirectoryFileType,
|
||||
},
|
||||
{
|
||||
name: "symlink to file",
|
||||
path: symlink,
|
||||
expected: SymlinkFileType,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
info, err := os.Lstat(tt.path)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := getEntryType(info.Mode())
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_SymlinkChain(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Base temporary directory. On macOS this lives under /var/folders/…
|
||||
// which itself is a symlink to /private/var/folders/….
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create final target
|
||||
target := filepath.Join(tempDir, "target")
|
||||
require.NoError(t, os.MkdirAll(target, 0o755))
|
||||
|
||||
// Create a chain: link1 → link2 → target
|
||||
link2 := filepath.Join(tempDir, "link2")
|
||||
require.NoError(t, os.Symlink(target, link2))
|
||||
|
||||
link1 := filepath.Join(tempDir, "link1")
|
||||
require.NoError(t, os.Symlink(link2, link1))
|
||||
|
||||
// run the test
|
||||
result, err := GetEntryFromPath(link1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify the results
|
||||
assert.Equal(t, "link1", result.Name)
|
||||
assert.Equal(t, link1, result.Path)
|
||||
assert.Equal(t, DirectoryFileType, result.Type) // Should resolve to final target type
|
||||
assert.Contains(t, result.Permissions, "L")
|
||||
|
||||
// Canonicalize the expected target path to handle macOS symlink indirections
|
||||
expectedTarget, err := filepath.EvalSymlinks(link1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedTarget, *result.SymlinkTarget)
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_DifferentPermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
permissions os.FileMode
|
||||
expectedMode os.FileMode
|
||||
expectedString string
|
||||
}{
|
||||
{"read-only", 0o444, 0o444, "-r--r--r--"},
|
||||
{"executable", 0o755, 0o755, "-rwxr-xr-x"},
|
||||
{"write-only", 0o200, 0o200, "--w-------"},
|
||||
{"no permissions", 0o000, 0o000, "----------"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testFile := filepath.Join(tempDir, tc.name+".txt")
|
||||
require.NoError(t, os.WriteFile(testFile, []byte("test"), tc.permissions))
|
||||
|
||||
result, err := GetEntryFromPath(testFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedMode, result.Mode)
|
||||
assert.Equal(t, tc.expectedString, result.Permissions)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_EmptyFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
emptyFile := filepath.Join(tempDir, "empty.txt")
|
||||
require.NoError(t, os.WriteFile(emptyFile, []byte{}, 0o600))
|
||||
|
||||
result, err := GetEntryFromPath(emptyFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "empty.txt", result.Name)
|
||||
assert.Equal(t, int64(0), result.Size)
|
||||
assert.Equal(t, os.FileMode(0o600), result.Mode)
|
||||
assert.Equal(t, FileFileType, result.Type)
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_CyclicSymlink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create cyclic symlink
|
||||
cyclicSymlink := filepath.Join(tempDir, "cyclic")
|
||||
require.NoError(t, os.Symlink(cyclicSymlink, cyclicSymlink))
|
||||
|
||||
result, err := GetEntryFromPath(cyclicSymlink)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "cyclic", result.Name)
|
||||
assert.Equal(t, cyclicSymlink, result.Path)
|
||||
assert.Equal(t, UnknownFileType, result.Type)
|
||||
assert.Contains(t, result.Permissions, "L")
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_BrokenSymlink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create broken symlink
|
||||
brokenSymlink := filepath.Join(tempDir, "broken")
|
||||
require.NoError(t, os.Symlink("/nonexistent", brokenSymlink))
|
||||
|
||||
result, err := GetEntryFromPath(brokenSymlink)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "broken", result.Name)
|
||||
assert.Equal(t, brokenSymlink, result.Path)
|
||||
assert.Equal(t, UnknownFileType, result.Type)
|
||||
assert.Contains(t, result.Permissions, "L")
|
||||
// SymlinkTarget might be empty if followSymlink fails
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a regular file with known content and permissions
|
||||
testFile := filepath.Join(tempDir, "test.txt")
|
||||
testContent := []byte("Hello, World!")
|
||||
require.NoError(t, os.WriteFile(testFile, testContent, 0o644))
|
||||
|
||||
// Get current user for ownership comparison
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := GetEntryFromPath(testFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Basic assertions
|
||||
assert.Equal(t, "test.txt", result.Name)
|
||||
assert.Equal(t, testFile, result.Path)
|
||||
assert.Equal(t, int64(len(testContent)), result.Size)
|
||||
assert.Equal(t, FileFileType, result.Type)
|
||||
assert.Equal(t, os.FileMode(0o644), result.Mode)
|
||||
assert.Contains(t, result.Permissions, "-rw-r--r--")
|
||||
assert.Equal(t, currentUser.Uid, strconv.Itoa(int(result.UID)))
|
||||
assert.Equal(t, currentUser.Gid, strconv.Itoa(int(result.GID)))
|
||||
assert.NotNil(t, result.ModifiedTime)
|
||||
assert.Empty(t, result.SymlinkTarget)
|
||||
|
||||
// Check that modified time is reasonable (within last minute)
|
||||
modTime := result.ModifiedTime
|
||||
assert.WithinDuration(t, time.Now(), modTime, time.Minute)
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_Directory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
testDir := filepath.Join(tempDir, "testdir")
|
||||
require.NoError(t, os.MkdirAll(testDir, 0o755))
|
||||
|
||||
result, err := GetEntryFromPath(testDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "testdir", result.Name)
|
||||
assert.Equal(t, testDir, result.Path)
|
||||
assert.Equal(t, DirectoryFileType, result.Type)
|
||||
assert.Equal(t, os.FileMode(0o755), result.Mode)
|
||||
assert.Equal(t, "drwxr-xr-x", result.Permissions)
|
||||
assert.Empty(t, result.SymlinkTarget)
|
||||
}
|
||||
|
||||
func TestEntryInfoFromFileInfo_Symlink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Base temporary directory. On macOS this lives under /var/folders/…
|
||||
// which itself is a symlink to /private/var/folders/….
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create target file
|
||||
targetFile := filepath.Join(tempDir, "target.txt")
|
||||
require.NoError(t, os.WriteFile(targetFile, []byte("target content"), 0o644))
|
||||
|
||||
// Create symlink
|
||||
symlinkPath := filepath.Join(tempDir, "symlink")
|
||||
require.NoError(t, os.Symlink(targetFile, symlinkPath))
|
||||
|
||||
// Use Lstat to get symlink info (not the target)
|
||||
result, err := GetEntryFromPath(symlinkPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "symlink", result.Name)
|
||||
assert.Equal(t, symlinkPath, result.Path)
|
||||
assert.Equal(t, FileFileType, result.Type) // Should resolve to target type
|
||||
assert.Contains(t, result.Permissions, "L") // Should show as symlink in permissions
|
||||
|
||||
// Canonicalize the expected target path to handle macOS /var → /private/var symlink
|
||||
expectedTarget, err := filepath.EvalSymlinks(symlinkPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedTarget, *result.SymlinkTarget)
|
||||
}
|
||||
30
envd/internal/shared/filesystem/model.go
Normal file
30
envd/internal/shared/filesystem/model.go
Normal file
@ -0,0 +1,30 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EntryInfo struct {
|
||||
Name string
|
||||
Type FileType
|
||||
Path string
|
||||
Size int64
|
||||
Mode os.FileMode
|
||||
Permissions string
|
||||
UID uint32
|
||||
GID uint32
|
||||
AccessedTime time.Time
|
||||
CreatedTime time.Time
|
||||
ModifiedTime time.Time
|
||||
SymlinkTarget *string
|
||||
}
|
||||
|
||||
type FileType int32
|
||||
|
||||
const (
|
||||
UnknownFileType FileType = 0
|
||||
FileFileType FileType = 1
|
||||
DirectoryFileType FileType = 2
|
||||
SymlinkFileType FileType = 3
|
||||
)
|
||||
164
envd/internal/shared/id/id.go
Normal file
164
envd/internal/shared/id/id.go
Normal file
@ -0,0 +1,164 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
caseInsensitiveAlphabet = []byte("abcdefghijklmnopqrstuvwxyz1234567890")
|
||||
identifierRegex = regexp.MustCompile(`^[a-z0-9-_]+$`)
|
||||
tagRegex = regexp.MustCompile(`^[a-z0-9-_.]+$`)
|
||||
sandboxIDRegex = regexp.MustCompile(`^[a-z0-9]+$`)
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultTag = "default"
|
||||
TagSeparator = ":"
|
||||
NamespaceSeparator = "/"
|
||||
)
|
||||
|
||||
func Generate() string {
|
||||
return uniuri.NewLenChars(uniuri.UUIDLen, caseInsensitiveAlphabet)
|
||||
}
|
||||
|
||||
// ValidateSandboxID checks that a sandbox ID contains only lowercase alphanumeric characters.
|
||||
func ValidateSandboxID(sandboxID string) error {
|
||||
if !sandboxIDRegex.MatchString(sandboxID) {
|
||||
return fmt.Errorf("invalid sandbox ID: %q", sandboxID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanAndValidate(value, name string, re *regexp.Regexp) (string, error) {
|
||||
cleaned := strings.ToLower(strings.TrimSpace(value))
|
||||
if !re.MatchString(cleaned) {
|
||||
return "", fmt.Errorf("invalid %s: %s", name, value)
|
||||
}
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
func validateTag(tag string) (string, error) {
|
||||
cleanedTag, err := cleanAndValidate(tag, "tag", tagRegex)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prevent tags from being a UUID
|
||||
_, err = uuid.Parse(cleanedTag)
|
||||
if err == nil {
|
||||
return "", errors.New("tag cannot be a UUID")
|
||||
}
|
||||
|
||||
return cleanedTag, nil
|
||||
}
|
||||
|
||||
func ValidateAndDeduplicateTags(tags []string) ([]string, error) {
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
for _, tag := range tags {
|
||||
cleanedTag, err := validateTag(tag)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid tag '%s': %w", tag, err)
|
||||
}
|
||||
|
||||
seen[cleanedTag] = struct{}{}
|
||||
}
|
||||
|
||||
return slices.Collect(maps.Keys(seen)), nil
|
||||
}
|
||||
|
||||
// SplitIdentifier splits "namespace/alias" into its parts.
|
||||
// Returns nil namespace for bare aliases, pointer for explicit namespace.
|
||||
func SplitIdentifier(identifier string) (namespace *string, alias string) {
|
||||
before, after, found := strings.Cut(identifier, NamespaceSeparator)
|
||||
if !found {
|
||||
return nil, before
|
||||
}
|
||||
|
||||
return &before, after
|
||||
}
|
||||
|
||||
// ParseName parses and validates "namespace/alias:tag" or "alias:tag".
|
||||
// Returns the cleaned identifier (namespace/alias or alias) and optional tag.
|
||||
// All components are validated and normalized (lowercase, trimmed).
|
||||
func ParseName(input string) (identifier string, tag *string, err error) {
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
// Extract raw parts
|
||||
identifierPart, tagPart, hasTag := strings.Cut(input, TagSeparator)
|
||||
namespacePart, aliasPart := SplitIdentifier(identifierPart)
|
||||
|
||||
// Validate tag
|
||||
if hasTag {
|
||||
validated, err := cleanAndValidate(tagPart, "tag", tagRegex)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if !strings.EqualFold(validated, DefaultTag) {
|
||||
tag = &validated
|
||||
}
|
||||
}
|
||||
|
||||
// Validate namespace
|
||||
if namespacePart != nil {
|
||||
validated, err := cleanAndValidate(*namespacePart, "namespace", identifierRegex)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
namespacePart = &validated
|
||||
}
|
||||
|
||||
// Validate alias
|
||||
aliasPart, err = cleanAndValidate(aliasPart, "template ID", identifierRegex)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Build identifier
|
||||
if namespacePart != nil {
|
||||
identifier = WithNamespace(*namespacePart, aliasPart)
|
||||
} else {
|
||||
identifier = aliasPart
|
||||
}
|
||||
|
||||
return identifier, tag, nil
|
||||
}
|
||||
|
||||
// WithTag returns the identifier with the given tag appended (e.g. "templateID:tag").
|
||||
func WithTag(identifier, tag string) string {
|
||||
return identifier + TagSeparator + tag
|
||||
}
|
||||
|
||||
// WithNamespace returns identifier with the given namespace prefix.
|
||||
func WithNamespace(namespace, alias string) string {
|
||||
return namespace + NamespaceSeparator + alias
|
||||
}
|
||||
|
||||
// ExtractAlias returns just the alias portion from an identifier (namespace/alias or alias).
|
||||
func ExtractAlias(identifier string) string {
|
||||
_, alias := SplitIdentifier(identifier)
|
||||
|
||||
return alias
|
||||
}
|
||||
|
||||
// ValidateNamespaceMatchesTeam checks if an explicit namespace in the identifier matches the team's slug.
|
||||
// Returns an error if the namespace doesn't match.
|
||||
// If the identifier has no explicit namespace, returns nil (valid).
|
||||
func ValidateNamespaceMatchesTeam(identifier, teamSlug string) error {
|
||||
namespace, _ := SplitIdentifier(identifier)
|
||||
if namespace != nil && *namespace != teamSlug {
|
||||
return fmt.Errorf("namespace '%s' must match your team '%s'", *namespace, teamSlug)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
380
envd/internal/shared/id/id_test.go
Normal file
380
envd/internal/shared/id/id_test.go
Normal file
@ -0,0 +1,380 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
|
||||
)
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantIdentifier string
|
||||
wantTag *string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "bare alias only",
|
||||
input: "my-template",
|
||||
wantIdentifier: "my-template",
|
||||
wantTag: nil,
|
||||
},
|
||||
{
|
||||
name: "alias with tag",
|
||||
input: "my-template:v1",
|
||||
wantIdentifier: "my-template",
|
||||
wantTag: utils.ToPtr("v1"),
|
||||
},
|
||||
{
|
||||
name: "namespace and alias",
|
||||
input: "acme/my-template",
|
||||
wantIdentifier: "acme/my-template",
|
||||
wantTag: nil,
|
||||
},
|
||||
{
|
||||
name: "namespace, alias and tag",
|
||||
input: "acme/my-template:v1",
|
||||
wantIdentifier: "acme/my-template",
|
||||
wantTag: utils.ToPtr("v1"),
|
||||
},
|
||||
{
|
||||
name: "namespace with hyphens",
|
||||
input: "my-team/my-template:prod",
|
||||
wantIdentifier: "my-team/my-template",
|
||||
wantTag: utils.ToPtr("prod"),
|
||||
},
|
||||
{
|
||||
name: "default tag normalized to nil",
|
||||
input: "my-template:default",
|
||||
wantIdentifier: "my-template",
|
||||
wantTag: nil,
|
||||
},
|
||||
{
|
||||
name: "uppercase converted to lowercase",
|
||||
input: "MyTemplate:Prod",
|
||||
wantIdentifier: "mytemplate",
|
||||
wantTag: utils.ToPtr("prod"),
|
||||
},
|
||||
{
|
||||
name: "whitespace trimmed",
|
||||
input: " my-template : v1 ",
|
||||
wantIdentifier: "my-template",
|
||||
wantTag: utils.ToPtr("v1"),
|
||||
},
|
||||
{
|
||||
name: "invalid - empty namespace",
|
||||
input: "/my-template",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - empty tag after colon",
|
||||
input: "my-template:",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - special characters in alias",
|
||||
input: "my template!",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - special characters in namespace",
|
||||
input: "my team!/my-template",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotIdentifier, gotTag, err := ParseName(tt.input)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err, "Expected ParseName() to return error, got")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "Expected ParseName() not to return error, got: %v", err)
|
||||
assert.Equal(t, tt.wantIdentifier, gotIdentifier, "ParseName() identifier = %v, want %v", gotIdentifier, tt.wantIdentifier)
|
||||
assert.Equal(t, tt.wantTag, gotTag, "ParseName() tag = %v, want %v", utils.Sprintp(gotTag), utils.Sprintp(tt.wantTag))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithNamespace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := WithNamespace("acme", "my-template")
|
||||
want := "acme/my-template"
|
||||
assert.Equal(t, want, got, "WithNamespace() = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
func TestSplitIdentifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
identifier string
|
||||
wantNamespace *string
|
||||
wantAlias string
|
||||
}{
|
||||
{
|
||||
name: "bare alias",
|
||||
identifier: "my-template",
|
||||
wantNamespace: nil,
|
||||
wantAlias: "my-template",
|
||||
},
|
||||
{
|
||||
name: "with namespace",
|
||||
identifier: "acme/my-template",
|
||||
wantNamespace: ptrStr("acme"),
|
||||
wantAlias: "my-template",
|
||||
},
|
||||
{
|
||||
name: "empty namespace prefix",
|
||||
identifier: "/my-template",
|
||||
wantNamespace: ptrStr(""),
|
||||
wantAlias: "my-template",
|
||||
},
|
||||
{
|
||||
name: "multiple slashes - only first split",
|
||||
identifier: "a/b/c",
|
||||
wantNamespace: ptrStr("a"),
|
||||
wantAlias: "b/c",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotNamespace, gotAlias := SplitIdentifier(tt.identifier)
|
||||
|
||||
if tt.wantNamespace == nil {
|
||||
assert.Nil(t, gotNamespace)
|
||||
} else {
|
||||
require.NotNil(t, gotNamespace)
|
||||
assert.Equal(t, *tt.wantNamespace, *gotNamespace)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.wantAlias, gotAlias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ptrStr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func TestValidateAndDeduplicateTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tags []string
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "single valid tag",
|
||||
tags: []string{"v1"},
|
||||
want: []string{"v1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple unique tags",
|
||||
tags: []string{"v1", "prod", "latest"},
|
||||
want: []string{"v1", "prod", "latest"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate tags deduplicated",
|
||||
tags: []string{"v1", "V1", "v1"},
|
||||
want: []string{"v1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tags with dots and underscores",
|
||||
tags: []string{"v1.0", "v1_1"},
|
||||
want: []string{"v1.0", "v1_1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid - UUID tag rejected",
|
||||
tags: []string{"550e8400-e29b-41d4-a716-446655440000"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - special characters",
|
||||
tags: []string{"v1!", "v2@"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty list returns empty",
|
||||
tags: []string{},
|
||||
want: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := ValidateAndDeduplicateTags(tt.tags)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSandboxID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "canonical sandbox ID",
|
||||
input: "i1a2b3c4d5e6f7g8h9j0k",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "short alphanumeric",
|
||||
input: "abc123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "all digits",
|
||||
input: "1234567890",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "all lowercase letters",
|
||||
input: "abcdefghijklmnopqrst",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid - empty",
|
||||
input: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains colon (Redis separator)",
|
||||
input: "abc:def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains open brace (Redis hash slot)",
|
||||
input: "abc{def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains close brace (Redis hash slot)",
|
||||
input: "abc}def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains newline",
|
||||
input: "abc\ndef",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains space",
|
||||
input: "abc def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains hyphen",
|
||||
input: "abc-def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains uppercase",
|
||||
input: "abcDEF",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains slash",
|
||||
input: "abc/def",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - contains null byte",
|
||||
input: "abc\x00def",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := ValidateSandboxID(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNamespaceMatchesTeam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
identifier string
|
||||
teamSlug string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "bare alias - no namespace",
|
||||
identifier: "my-template",
|
||||
teamSlug: "acme",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "matching namespace",
|
||||
identifier: "acme/my-template",
|
||||
teamSlug: "acme",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mismatched namespace",
|
||||
identifier: "other-team/my-template",
|
||||
teamSlug: "acme",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := ValidateNamespaceMatchesTeam(tt.identifier, tt.teamSlug)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
6
envd/internal/shared/keys/constants.go
Normal file
6
envd/internal/shared/keys/constants.go
Normal file
@ -0,0 +1,6 @@
|
||||
package keys
|
||||
|
||||
const (
|
||||
ApiKeyPrefix = "wrn_"
|
||||
AccessTokenPrefix = "sk_wrn_"
|
||||
)
|
||||
5
envd/internal/shared/keys/hashing.go
Normal file
5
envd/internal/shared/keys/hashing.go
Normal file
@ -0,0 +1,5 @@
|
||||
package keys
|
||||
|
||||
type Hasher interface {
|
||||
Hash(key []byte) string
|
||||
}
|
||||
25
envd/internal/shared/keys/hmac_sha256.go
Normal file
25
envd/internal/shared/keys/hmac_sha256.go
Normal file
@ -0,0 +1,25 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
type HMACSha256Hashing struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
func NewHMACSHA256Hashing(key []byte) *HMACSha256Hashing {
|
||||
return &HMACSha256Hashing{key: key}
|
||||
}
|
||||
|
||||
func (h *HMACSha256Hashing) Hash(content []byte) (string, error) {
|
||||
mac := hmac.New(sha256.New, h.key)
|
||||
_, err := mac.Write(content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return hex.EncodeToString(mac.Sum(nil)), nil
|
||||
}
|
||||
74
envd/internal/shared/keys/hmac_sha256_test.go
Normal file
74
envd/internal/shared/keys/hmac_sha256_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHMACSha256Hashing_ValidHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := []byte("test-key")
|
||||
hasher := NewHMACSHA256Hashing(key)
|
||||
content := []byte("hello world")
|
||||
expectedHash := "18c4b268f0bbf8471eda56af3e70b1d4613d734dc538b4940b59931c412a1591"
|
||||
actualHash, err := hasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
if actualHash != expectedHash {
|
||||
t.Errorf("expected %s, got %s", expectedHash, actualHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHMACSha256Hashing_EmptyContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := []byte("test-key")
|
||||
hasher := NewHMACSHA256Hashing(key)
|
||||
content := []byte("")
|
||||
expectedHash := "2711cc23e9ab1b8a9bc0fe991238da92671624a9ebdaf1c1abec06e7e9a14f9b"
|
||||
actualHash, err := hasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
if actualHash != expectedHash {
|
||||
t.Errorf("expected %s, got %s", expectedHash, actualHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHMACSha256Hashing_DifferentKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := []byte("test-key")
|
||||
hasher := NewHMACSHA256Hashing(key)
|
||||
differentKeyHasher := NewHMACSHA256Hashing([]byte("different-key"))
|
||||
content := []byte("hello world")
|
||||
|
||||
hashWithOriginalKey, err := hasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
hashWithDifferentKey, err := differentKeyHasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
if hashWithOriginalKey == hashWithDifferentKey {
|
||||
t.Errorf("hashes with different keys should not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHMACSha256Hashing_IdenticalResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := []byte("placeholder-hashing-key")
|
||||
content := []byte("test content for hashing")
|
||||
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(content)
|
||||
expectedResult := hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
hasher := NewHMACSHA256Hashing(key)
|
||||
actualResult, err := hasher.Hash(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
if actualResult != expectedResult {
|
||||
t.Errorf("expected %s, got %s", expectedResult, actualResult)
|
||||
}
|
||||
}
|
||||
99
envd/internal/shared/keys/key.go
Normal file
99
envd/internal/shared/keys/key.go
Normal file
@ -0,0 +1,99 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
identifierValueSuffixLength = 4
|
||||
identifierValuePrefixLength = 2
|
||||
|
||||
keyLength = 20
|
||||
)
|
||||
|
||||
var hasher Hasher = NewSHA256Hashing()
|
||||
|
||||
type Key struct {
|
||||
PrefixedRawValue string
|
||||
HashedValue string
|
||||
Masked MaskedIdentifier
|
||||
}
|
||||
|
||||
type MaskedIdentifier struct {
|
||||
Prefix string
|
||||
ValueLength int
|
||||
MaskedValuePrefix string
|
||||
MaskedValueSuffix string
|
||||
}
|
||||
|
||||
// MaskKey returns identifier masking properties in accordance to the OpenAPI response spec
|
||||
func MaskKey(prefix, value string) (MaskedIdentifier, error) {
|
||||
valueLength := len(value)
|
||||
|
||||
suffixOffset := valueLength - identifierValueSuffixLength
|
||||
prefixOffset := identifierValuePrefixLength
|
||||
|
||||
if suffixOffset < 0 {
|
||||
return MaskedIdentifier{}, fmt.Errorf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength)
|
||||
}
|
||||
|
||||
if suffixOffset == 0 {
|
||||
return MaskedIdentifier{}, fmt.Errorf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength)
|
||||
}
|
||||
|
||||
// cap prefixOffset by suffixOffset to prevent overlap with the suffix.
|
||||
if prefixOffset > suffixOffset {
|
||||
prefixOffset = suffixOffset
|
||||
}
|
||||
|
||||
maskPrefix := value[:prefixOffset]
|
||||
maskSuffix := value[suffixOffset:]
|
||||
|
||||
maskedIdentifierProperties := MaskedIdentifier{
|
||||
Prefix: prefix,
|
||||
ValueLength: valueLength,
|
||||
MaskedValuePrefix: maskPrefix,
|
||||
MaskedValueSuffix: maskSuffix,
|
||||
}
|
||||
|
||||
return maskedIdentifierProperties, nil
|
||||
}
|
||||
|
||||
func GenerateKey(prefix string) (Key, error) {
|
||||
keyBytes := make([]byte, keyLength)
|
||||
|
||||
_, err := rand.Read(keyBytes)
|
||||
if err != nil {
|
||||
return Key{}, err
|
||||
}
|
||||
|
||||
generatedIdentifier := hex.EncodeToString(keyBytes)
|
||||
|
||||
mask, err := MaskKey(prefix, generatedIdentifier)
|
||||
if err != nil {
|
||||
return Key{}, err
|
||||
}
|
||||
|
||||
return Key{
|
||||
PrefixedRawValue: prefix + generatedIdentifier,
|
||||
HashedValue: hasher.Hash(keyBytes),
|
||||
Masked: mask,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func VerifyKey(prefix string, key string) (string, error) {
|
||||
if !strings.HasPrefix(key, prefix) {
|
||||
return "", fmt.Errorf("invalid key prefix")
|
||||
}
|
||||
|
||||
keyValue := key[len(prefix):]
|
||||
keyBytes, err := hex.DecodeString(keyValue)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid key")
|
||||
}
|
||||
|
||||
return hasher.Hash(keyBytes), nil
|
||||
}
|
||||
160
envd/internal/shared/keys/key_test.go
Normal file
160
envd/internal/shared/keys/key_test.go
Normal file
@ -0,0 +1,160 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMaskKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("succeeds: value longer than suffix length", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
masked, err := MaskKey("test_", "1234567890")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test_", masked.Prefix)
|
||||
assert.Equal(t, "12", masked.MaskedValuePrefix)
|
||||
assert.Equal(t, "7890", masked.MaskedValueSuffix)
|
||||
})
|
||||
|
||||
t.Run("succeeds: empty prefix, value longer than suffix length", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
masked, err := MaskKey("", "1234567890")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, masked.Prefix)
|
||||
assert.Equal(t, "12", masked.MaskedValuePrefix)
|
||||
assert.Equal(t, "7890", masked.MaskedValueSuffix)
|
||||
})
|
||||
|
||||
t.Run("error: value length less than suffix length", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := MaskKey("test", "123")
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength))
|
||||
})
|
||||
|
||||
t.Run("error: value length equals suffix length", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := MaskKey("test", "1234")
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, fmt.Sprintf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength))
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
keyLength := 40
|
||||
|
||||
t.Run("succeeds", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := GenerateKey("test_")
|
||||
require.NoError(t, err)
|
||||
assert.Regexp(t, "^test_.*", key.PrefixedRawValue)
|
||||
assert.Equal(t, "test_", key.Masked.Prefix)
|
||||
assert.Equal(t, keyLength, key.Masked.ValueLength)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValuePrefixLength)+"}$", key.Masked.MaskedValuePrefix)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValueSuffixLength)+"}$", key.Masked.MaskedValueSuffix)
|
||||
assert.Regexp(t, "^\\$sha256\\$.*", key.HashedValue)
|
||||
})
|
||||
|
||||
t.Run("no prefix", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := GenerateKey("")
|
||||
require.NoError(t, err)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(keyLength)+"}$", key.PrefixedRawValue)
|
||||
assert.Empty(t, key.Masked.Prefix)
|
||||
assert.Equal(t, keyLength, key.Masked.ValueLength)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValuePrefixLength)+"}$", key.Masked.MaskedValuePrefix)
|
||||
assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValueSuffixLength)+"}$", key.Masked.MaskedValueSuffix)
|
||||
assert.Regexp(t, "^\\$sha256\\$.*", key.HashedValue)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetMaskedIdentifierProperties(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
prefix string
|
||||
value string
|
||||
expectedResult MaskedIdentifier
|
||||
expectedErrString string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
// --- ERROR CASES (value's length <= identifierValueSuffixLength) ---
|
||||
{
|
||||
name: "error: value length < suffix length (3 vs 4)",
|
||||
prefix: "pk_",
|
||||
value: "abc",
|
||||
expectedResult: MaskedIdentifier{},
|
||||
expectedErrString: fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength),
|
||||
},
|
||||
{
|
||||
name: "error: value length == suffix length (4 vs 4)",
|
||||
prefix: "sk_",
|
||||
value: "abcd",
|
||||
expectedResult: MaskedIdentifier{},
|
||||
expectedErrString: fmt.Sprintf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength),
|
||||
},
|
||||
{
|
||||
name: "error: value length < suffix length (0 vs 4, empty value)",
|
||||
prefix: "err_",
|
||||
value: "",
|
||||
expectedResult: MaskedIdentifier{},
|
||||
expectedErrString: fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength),
|
||||
},
|
||||
|
||||
// --- SUCCESS CASES (value's length > identifierValueSuffixLength) ---
|
||||
{
|
||||
name: "success: value long (10), prefix val len fully used",
|
||||
prefix: "pk_",
|
||||
value: "abcdefghij",
|
||||
expectedResult: MaskedIdentifier{
|
||||
Prefix: "pk_",
|
||||
ValueLength: 10,
|
||||
MaskedValuePrefix: "ab",
|
||||
MaskedValueSuffix: "ghij",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success: value medium (5), prefix val len truncated by overlap",
|
||||
prefix: "",
|
||||
value: "abcde",
|
||||
expectedResult: MaskedIdentifier{
|
||||
Prefix: "",
|
||||
ValueLength: 5,
|
||||
MaskedValuePrefix: "a",
|
||||
MaskedValueSuffix: "bcde",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success: value medium (6), prefix val len fits exactly",
|
||||
prefix: "pk_",
|
||||
value: "abcdef",
|
||||
expectedResult: MaskedIdentifier{
|
||||
Prefix: "pk_",
|
||||
ValueLength: 6,
|
||||
MaskedValuePrefix: "ab",
|
||||
MaskedValueSuffix: "cdef",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result, err := MaskKey(tc.prefix, tc.value)
|
||||
|
||||
if tc.expectedErrString != "" {
|
||||
require.EqualError(t, err, tc.expectedErrString)
|
||||
assert.Equal(t, tc.expectedResult, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
30
envd/internal/shared/keys/sha256.go
Normal file
30
envd/internal/shared/keys/sha256.go
Normal file
@ -0,0 +1,30 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Sha256Hashing struct{}
|
||||
|
||||
func NewSHA256Hashing() *Sha256Hashing {
|
||||
return &Sha256Hashing{}
|
||||
}
|
||||
|
||||
func (h *Sha256Hashing) Hash(key []byte) string {
|
||||
hashBytes := sha256.Sum256(key)
|
||||
|
||||
hash64 := base64.RawStdEncoding.EncodeToString(hashBytes[:])
|
||||
|
||||
return fmt.Sprintf(
|
||||
"$sha256$%s",
|
||||
hash64,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *Sha256Hashing) HashWithoutPrefix(key []byte) string {
|
||||
hashBytes := sha256.Sum256(key)
|
||||
|
||||
return base64.RawStdEncoding.EncodeToString(hashBytes[:])
|
||||
}
|
||||
15
envd/internal/shared/keys/sha256_test.go
Normal file
15
envd/internal/shared/keys/sha256_test.go
Normal file
@ -0,0 +1,15 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSHA256Hashing(t *testing.T) {
|
||||
t.Parallel()
|
||||
hasher := NewSHA256Hashing()
|
||||
|
||||
hashed := hasher.Hash([]byte("test"))
|
||||
assert.Regexp(t, "^\\$sha256\\$.*", hashed)
|
||||
}
|
||||
20
envd/internal/shared/keys/sha512.go
Normal file
20
envd/internal/shared/keys/sha512.go
Normal file
@ -0,0 +1,20 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// HashAccessToken computes the SHA-512 hash of an access token.
|
||||
func HashAccessToken(token string) string {
|
||||
h := sha512.Sum512([]byte(token))
|
||||
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// HashAccessTokenBytes computes the SHA-512 hash of an access token from bytes.
|
||||
func HashAccessTokenBytes(token []byte) string {
|
||||
h := sha512.Sum512(token)
|
||||
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
47
envd/internal/shared/smap/smap.go
Normal file
47
envd/internal/shared/smap/smap.go
Normal file
@ -0,0 +1,47 @@
|
||||
package smap
|
||||
|
||||
import (
|
||||
cmap "github.com/orcaman/concurrent-map/v2"
|
||||
)
|
||||
|
||||
type Map[V any] struct {
|
||||
m cmap.ConcurrentMap[string, V]
|
||||
}
|
||||
|
||||
func New[V any]() *Map[V] {
|
||||
return &Map[V]{
|
||||
m: cmap.New[V](),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Map[V]) Remove(key string) {
|
||||
m.m.Remove(key)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Get(key string) (V, bool) {
|
||||
return m.m.Get(key)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Insert(key string, value V) {
|
||||
m.m.Set(key, value)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Upsert(key string, value V, cb cmap.UpsertCb[V]) V {
|
||||
return m.m.Upsert(key, value, cb)
|
||||
}
|
||||
|
||||
func (m *Map[V]) InsertIfAbsent(key string, value V) bool {
|
||||
return m.m.SetIfAbsent(key, value)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Items() map[string]V {
|
||||
return m.m.Items()
|
||||
}
|
||||
|
||||
func (m *Map[V]) RemoveCb(key string, cb func(key string, v V, exists bool) bool) bool {
|
||||
return m.m.RemoveCb(key, cb)
|
||||
}
|
||||
|
||||
func (m *Map[V]) Count() int {
|
||||
return m.m.Count()
|
||||
}
|
||||
43
envd/internal/shared/utils/ptr.go
Normal file
43
envd/internal/shared/utils/ptr.go
Normal file
@ -0,0 +1,43 @@
|
||||
package utils
|
||||
|
||||
import "fmt"
|
||||
|
||||
func ToPtr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func FromPtr[T any](s *T) T {
|
||||
if s == nil {
|
||||
var zero T
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
return *s
|
||||
}
|
||||
|
||||
func Sprintp[T any](s *T) string {
|
||||
if s == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", *s)
|
||||
}
|
||||
|
||||
func DerefOrDefault[T any](s *T, defaultValue T) T {
|
||||
if s == nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return *s
|
||||
}
|
||||
|
||||
func CastPtr[S any, T any](s *S, castFunc func(S) T) *T {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
t := castFunc(*s)
|
||||
|
||||
return &t
|
||||
}
|
||||
Reference in New Issue
Block a user