diff --git a/.env.example b/.env.example index cce316d..c52e46f 100644 --- a/.env.example +++ b/.env.example @@ -23,3 +23,6 @@ S3_REGION=fsn1 S3_ENDPOINT=https://fsn1.your-objectstorage.com AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= + +# Auth +JWT_SECRET= diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index 80638bb..2b22e65 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -24,6 +24,11 @@ func main() { cfg := config.Load() + if len(cfg.JWTSecret) < 32 { + slog.Error("JWT_SECRET must be at least 32 characters") + os.Exit(1) + } + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -51,7 +56,7 @@ func main() { ) // API server. - srv := api.New(queries, agentClient) + srv := api.New(queries, agentClient, pool, []byte(cfg.JWTSecret)) // Start reconciler. reconciler := api.NewReconciler(queries, agentClient, "default", 30*time.Second) diff --git a/db/migrations/20260313210608_auth.sql b/db/migrations/20260313210608_auth.sql new file mode 100644 index 0000000..03970a8 --- /dev/null +++ b/db/migrations/20260313210608_auth.sql @@ -0,0 +1,46 @@ +-- +goose Up + +CREATE TABLE users ( + id TEXT PRIMARY KEY, + email TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE teams ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE users_teams ( + user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE, + is_default BOOLEAN NOT NULL DEFAULT TRUE, + role TEXT NOT NULL DEFAULT 'owner', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (team_id, user_id) +); + +CREATE INDEX idx_users_teams_user ON users_teams(user_id); + +CREATE TABLE team_api_keys ( + id TEXT PRIMARY KEY, + team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE, + name TEXT NOT NULL DEFAULT '', + key_hash TEXT NOT NULL UNIQUE, + key_prefix TEXT NOT NULL DEFAULT '', + created_by TEXT NOT NULL REFERENCES users(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_used TIMESTAMPTZ +); + +CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id); + +-- +goose Down + +DROP TABLE team_api_keys; +DROP TABLE users_teams; +DROP TABLE teams; +DROP TABLE users; diff --git a/db/migrations/20260313210611_team_ownership.sql b/db/migrations/20260313210611_team_ownership.sql new file mode 100644 index 0000000..849e781 --- /dev/null +++ b/db/migrations/20260313210611_team_ownership.sql @@ -0,0 +1,31 @@ +-- +goose Up + +ALTER TABLE sandboxes + ADD COLUMN team_id TEXT NOT NULL DEFAULT ''; + +UPDATE sandboxes SET team_id = owner_id WHERE owner_id != ''; + +ALTER TABLE sandboxes + DROP COLUMN owner_id; + +ALTER TABLE templates + ADD COLUMN team_id TEXT NOT NULL DEFAULT ''; + +CREATE INDEX idx_sandboxes_team ON sandboxes(team_id); +CREATE INDEX idx_templates_team ON templates(team_id); + +-- +goose Down + +ALTER TABLE sandboxes + ADD COLUMN owner_id TEXT NOT NULL DEFAULT ''; + +UPDATE sandboxes SET owner_id = team_id WHERE team_id != ''; + +ALTER TABLE sandboxes + DROP COLUMN team_id; + +ALTER TABLE templates + DROP COLUMN team_id; + +DROP INDEX IF EXISTS idx_sandboxes_team; +DROP INDEX IF EXISTS idx_templates_team; diff --git a/db/queries/api_keys.sql b/db/queries/api_keys.sql new file mode 100644 index 0000000..0580518 --- /dev/null +++ b/db/queries/api_keys.sql @@ -0,0 +1,16 @@ +-- name: InsertAPIKey :one +INSERT INTO team_api_keys (id, team_id, name, key_hash, key_prefix, created_by) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING *; + +-- name: GetAPIKeyByHash :one +SELECT * FROM team_api_keys WHERE key_hash = $1; + +-- name: ListAPIKeysByTeam :many +SELECT * FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC; + +-- name: DeleteAPIKey :exec +DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2; + +-- name: UpdateAPIKeyLastUsed :exec +UPDATE team_api_keys SET last_used = NOW() WHERE id = $1; diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql index 7a964a7..33203f6 100644 --- a/db/queries/sandboxes.sql +++ b/db/queries/sandboxes.sql @@ -1,14 +1,20 @@ -- name: InsertSandbox :one -INSERT INTO sandboxes (id, owner_id, host_id, template, status, vcpus, memory_mb, timeout_sec) +INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *; -- name: GetSandbox :one SELECT * FROM sandboxes WHERE id = $1; +-- name: GetSandboxByTeam :one +SELECT * FROM sandboxes WHERE id = $1 AND team_id = $2; + -- name: ListSandboxes :many SELECT * FROM sandboxes ORDER BY created_at DESC; +-- name: ListSandboxesByTeam :many +SELECT * FROM sandboxes WHERE team_id = $1 ORDER BY created_at DESC; + -- name: ListSandboxesByHostAndStatus :many SELECT * FROM sandboxes WHERE host_id = $1 AND status = ANY($2::text[]) diff --git a/db/queries/teams.sql b/db/queries/teams.sql new file mode 100644 index 0000000..f4c4633 --- /dev/null +++ b/db/queries/teams.sql @@ -0,0 +1,17 @@ +-- name: InsertTeam :one +INSERT INTO teams (id, name) +VALUES ($1, $2) +RETURNING *; + +-- name: GetTeam :one +SELECT * FROM teams WHERE id = $1; + +-- name: InsertTeamMember :exec +INSERT INTO users_teams (user_id, team_id, is_default, role) +VALUES ($1, $2, $3, $4); + +-- name: GetDefaultTeamForUser :one +SELECT t.* FROM teams t +JOIN users_teams ut ON ut.team_id = t.id +WHERE ut.user_id = $1 AND ut.is_default = TRUE +LIMIT 1; diff --git a/db/queries/templates.sql b/db/queries/templates.sql index 4a438d7..b17abc3 100644 --- a/db/queries/templates.sql +++ b/db/queries/templates.sql @@ -1,16 +1,28 @@ -- name: InsertTemplate :one -INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes) -VALUES ($1, $2, $3, $4, $5) +INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id) +VALUES ($1, $2, $3, $4, $5, $6) RETURNING *; -- name: GetTemplate :one SELECT * FROM templates WHERE name = $1; +-- name: GetTemplateByTeam :one +SELECT * FROM templates WHERE name = $1 AND team_id = $2; + -- name: ListTemplates :many SELECT * FROM templates ORDER BY created_at DESC; -- name: ListTemplatesByType :many SELECT * FROM templates WHERE type = $1 ORDER BY created_at DESC; +-- name: ListTemplatesByTeam :many +SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC; + +-- name: ListTemplatesByTeamAndType :many +SELECT * FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC; + -- name: DeleteTemplate :exec DELETE FROM templates WHERE name = $1; + +-- name: DeleteTemplateByTeam :exec +DELETE FROM templates WHERE name = $1 AND team_id = $2; diff --git a/db/queries/users.sql b/db/queries/users.sql new file mode 100644 index 0000000..c1f61f0 --- /dev/null +++ b/db/queries/users.sql @@ -0,0 +1,10 @@ +-- name: InsertUser :one +INSERT INTO users (id, email, password_hash) +VALUES ($1, $2, $3) +RETURNING *; + +-- name: GetUserByEmail :one +SELECT * FROM users WHERE email = $1; + +-- name: GetUserByID :one +SELECT * FROM users WHERE id = $1; diff --git a/go.mod b/go.mod index 11c8bcb..82bd17a 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,13 @@ go 1.25.0 require ( connectrpc.com/connect v1.19.1 github.com/go-chi/chi/v5 v5.2.5 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/jackc/pgx/v5 v5.8.0 github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5 github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f + golang.org/x/crypto v0.49.0 golang.org/x/sys v0.42.0 google.golang.org/protobuf v1.36.11 ) @@ -18,6 +20,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - golang.org/x/sync v0.17.0 // indirect - golang.org/x/text v0.29.0 // indirect + github.com/joho/godotenv v1.5.1 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/text v0.35.0 // indirect ) diff --git a/go.sum b/go.sum index 4997587..d8ac123 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -19,6 +21,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -31,14 +35,16 @@ github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5/go.mod h1:tw github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f h1:p4VB7kIXpOQvVn1ZaTIVp+3vuYAXFe3OJEvjbUYJLaA= github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= -golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/api/handlers_apikeys.go b/internal/api/handlers_apikeys.go new file mode 100644 index 0000000..b8a5ead --- /dev/null +++ b/internal/api/handlers_apikeys.go @@ -0,0 +1,125 @@ +package api + +import ( + "net/http" + "time" + + "github.com/go-chi/chi/v5" + + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" +) + +type apiKeyHandler struct { + db *db.Queries +} + +func newAPIKeyHandler(db *db.Queries) *apiKeyHandler { + return &apiKeyHandler{db: db} +} + +type createAPIKeyRequest struct { + Name string `json:"name"` +} + +type apiKeyResponse struct { + ID string `json:"id"` + TeamID string `json:"team_id"` + Name string `json:"name"` + KeyPrefix string `json:"key_prefix"` + CreatedAt string `json:"created_at"` + LastUsed *string `json:"last_used,omitempty"` + Key *string `json:"key,omitempty"` // only populated on Create +} + +func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse { + resp := apiKeyResponse{ + ID: k.ID, + TeamID: k.TeamID, + Name: k.Name, + KeyPrefix: k.KeyPrefix, + } + if k.CreatedAt.Valid { + resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339) + } + if k.LastUsed.Valid { + s := k.LastUsed.Time.Format(time.RFC3339) + resp.LastUsed = &s + } + return resp +} + +// Create handles POST /v1/api-keys. +func (h *apiKeyHandler) Create(w http.ResponseWriter, r *http.Request) { + ac := auth.MustFromContext(r.Context()) + + var req createAPIKeyRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + + if req.Name == "" { + req.Name = "Unnamed API Key" + } + + plaintext, hash, err := auth.GenerateAPIKey() + if err != nil { + writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate API key") + return + } + + keyID := id.NewAPIKeyID() + row, err := h.db.InsertAPIKey(r.Context(), db.InsertAPIKeyParams{ + ID: keyID, + TeamID: ac.TeamID, + Name: req.Name, + KeyHash: hash, + KeyPrefix: auth.APIKeyPrefix(plaintext), + CreatedBy: ac.UserID, + }) + if err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to create API key") + return + } + + resp := apiKeyToResponse(row) + resp.Key = &plaintext + + writeJSON(w, http.StatusCreated, resp) +} + +// List handles GET /v1/api-keys. +func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) { + ac := auth.MustFromContext(r.Context()) + + keys, err := h.db.ListAPIKeysByTeam(r.Context(), ac.TeamID) + if err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to list API keys") + return + } + + resp := make([]apiKeyResponse, len(keys)) + for i, k := range keys { + resp[i] = apiKeyToResponse(k) + } + + writeJSON(w, http.StatusOK, resp) +} + +// Delete handles DELETE /v1/api-keys/{id}. +func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) { + ac := auth.MustFromContext(r.Context()) + keyID := chi.URLParam(r, "id") + + if err := h.db.DeleteAPIKey(r.Context(), db.DeleteAPIKeyParams{ + ID: keyID, + TeamID: ac.TeamID, + }); err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key") + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/handlers_auth.go b/internal/api/handlers_auth.go new file mode 100644 index 0000000..2fbe1db --- /dev/null +++ b/internal/api/handlers_auth.go @@ -0,0 +1,184 @@ +package api + +import ( + "errors" + "net/http" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" +) + +type authHandler struct { + db *db.Queries + pool *pgxpool.Pool + jwtSecret []byte +} + +func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authHandler { + return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret} +} + +type signupRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type loginRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type authResponse struct { + Token string `json:"token"` + UserID string `json:"user_id"` + TeamID string `json:"team_id"` + Email string `json:"email"` +} + +// Signup handles POST /v1/auth/signup. +func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) { + var req signupRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + + req.Email = strings.TrimSpace(strings.ToLower(req.Email)) + if !strings.Contains(req.Email, "@") || len(req.Email) < 3 { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid email address") + return + } + if len(req.Password) < 8 { + writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters") + return + } + + ctx := r.Context() + + passwordHash, err := auth.HashPassword(req.Password) + if err != nil { + writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password") + return + } + + // Use a transaction to atomically create user + team + membership. + tx, err := h.pool.Begin(ctx) + if err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to begin transaction") + return + } + defer tx.Rollback(ctx) //nolint:errcheck + + qtx := h.db.WithTx(tx) + + userID := id.NewUserID() + _, err = qtx.InsertUser(ctx, db.InsertUserParams{ + ID: userID, + Email: req.Email, + PasswordHash: passwordHash, + }) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + writeError(w, http.StatusConflict, "email_taken", "an account with this email already exists") + return + } + writeError(w, http.StatusInternalServerError, "db_error", "failed to create user") + return + } + + // Create default team. + teamID := id.NewTeamID() + if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{ + ID: teamID, + Name: req.Email + "'s Team", + }); err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to create team") + return + } + + if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{ + UserID: userID, + TeamID: teamID, + IsDefault: true, + Role: "owner", + }); err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to add user to team") + return + } + + if err := tx.Commit(ctx); err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to commit signup") + return + } + + token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email) + if err != nil { + writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token") + return + } + + writeJSON(w, http.StatusCreated, authResponse{ + Token: token, + UserID: userID, + TeamID: teamID, + Email: req.Email, + }) +} + +// Login handles POST /v1/auth/login. +func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) { + var req loginRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + + req.Email = strings.TrimSpace(strings.ToLower(req.Email)) + if req.Email == "" || req.Password == "" { + writeError(w, http.StatusBadRequest, "invalid_request", "email and password are required") + return + } + + ctx := r.Context() + + user, err := h.db.GetUserByEmail(ctx, req.Email) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password") + return + } + writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user") + return + } + + if err := auth.CheckPassword(user.PasswordHash, req.Password); err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password") + return + } + + team, err := h.db.GetDefaultTeamForUser(ctx, user.ID) + if err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team") + return + } + + token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email) + if err != nil { + writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token") + return + } + + writeJSON(w, http.StatusOK, authResponse{ + Token: token, + UserID: user.ID, + TeamID: team.ID, + Email: user.Email, + }) +} diff --git a/internal/api/handlers_exec.go b/internal/api/handlers_exec.go index 8323df4..9307a67 100644 --- a/internal/api/handlers_exec.go +++ b/internal/api/handlers_exec.go @@ -12,6 +12,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -47,8 +48,9 @@ type execResponse struct { func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") ctx := r.Context() + ac := auth.MustFromContext(ctx) - sb, err := h.db.GetSandbox(ctx, sandboxID) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return diff --git a/internal/api/handlers_exec_stream.go b/internal/api/handlers_exec_stream.go index a2be27d..009f41b 100644 --- a/internal/api/handlers_exec_stream.go +++ b/internal/api/handlers_exec_stream.go @@ -12,6 +12,7 @@ import ( "github.com/gorilla/websocket" "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -49,8 +50,9 @@ type wsOutMsg struct { func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") ctx := r.Context() + ac := auth.MustFromContext(ctx) - sb, err := h.db.GetSandbox(ctx, sandboxID) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return diff --git a/internal/api/handlers_files.go b/internal/api/handlers_files.go index 71a3aea..c1c0291 100644 --- a/internal/api/handlers_files.go +++ b/internal/api/handlers_files.go @@ -9,6 +9,7 @@ import ( "connectrpc.com/connect" "github.com/go-chi/chi/v5" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -30,8 +31,9 @@ func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceCl func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") ctx := r.Context() + ac := auth.MustFromContext(ctx) - sb, err := h.db.GetSandbox(ctx, sandboxID) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return @@ -95,8 +97,9 @@ type readFileRequest struct { func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") ctx := r.Context() + ac := auth.MustFromContext(ctx) - sb, err := h.db.GetSandbox(ctx, sandboxID) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return diff --git a/internal/api/handlers_files_stream.go b/internal/api/handlers_files_stream.go index 7999a2f..66a3c5b 100644 --- a/internal/api/handlers_files_stream.go +++ b/internal/api/handlers_files_stream.go @@ -10,6 +10,7 @@ import ( "connectrpc.com/connect" "github.com/go-chi/chi/v5" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -30,8 +31,9 @@ func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentSer func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") ctx := r.Context() + ac := auth.MustFromContext(ctx) - sb, err := h.db.GetSandbox(ctx, sandboxID) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return @@ -140,8 +142,9 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") ctx := r.Context() + ac := auth.MustFromContext(ctx) - sb, err := h.db.GetSandbox(ctx, sandboxID) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return diff --git a/internal/api/handlers_sandbox.go b/internal/api/handlers_sandbox.go index bb06a5f..5ffd008 100644 --- a/internal/api/handlers_sandbox.go +++ b/internal/api/handlers_sandbox.go @@ -11,6 +11,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/validate" @@ -103,10 +104,11 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() + ac := auth.MustFromContext(ctx) // If the template is a snapshot, use its baked-in vcpus/memory // (they cannot be changed since the VM state is frozen). - if tmpl, err := h.db.GetTemplate(ctx, req.Template); err == nil && tmpl.Type == "snapshot" { + if tmpl, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Template, TeamID: ac.TeamID}); err == nil && tmpl.Type == "snapshot" { if tmpl.Vcpus.Valid { req.VCPUs = tmpl.Vcpus.Int32 } @@ -119,7 +121,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) { // Insert pending record. _, err := h.db.InsertSandbox(ctx, db.InsertSandboxParams{ ID: sandboxID, - OwnerID: "", + TeamID: ac.TeamID, HostID: "default", Template: req.Template, Status: "pending", @@ -173,7 +175,8 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) { // List handles GET /v1/sandboxes. func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) { - sandboxes, err := h.db.ListSandboxes(r.Context()) + ac := auth.MustFromContext(r.Context()) + sandboxes, err := h.db.ListSandboxesByTeam(r.Context(), ac.TeamID) if err != nil { writeError(w, http.StatusInternalServerError, "db_error", "failed to list sandboxes") return @@ -190,8 +193,9 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) { // Get handles GET /v1/sandboxes/{id}. func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") + ac := auth.MustFromContext(r.Context()) - sb, err := h.db.GetSandbox(r.Context(), sandboxID) + sb, err := h.db.GetSandboxByTeam(r.Context(), db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return @@ -206,8 +210,9 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) { func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") ctx := r.Context() + ac := auth.MustFromContext(ctx) - sb, err := h.db.GetSandbox(ctx, sandboxID) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return @@ -241,8 +246,9 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) { func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") ctx := r.Context() + ac := auth.MustFromContext(ctx) - sb, err := h.db.GetSandbox(ctx, sandboxID) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return @@ -283,8 +289,9 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) { func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) { sandboxID := chi.URLParam(r, "id") ctx := r.Context() + ac := auth.MustFromContext(ctx) - _, err := h.db.GetSandbox(ctx, sandboxID) + _, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return diff --git a/internal/api/handlers_snapshots.go b/internal/api/handlers_snapshots.go index 8e6b36f..20cd99f 100644 --- a/internal/api/handlers_snapshots.go +++ b/internal/api/handlers_snapshots.go @@ -11,6 +11,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/validate" @@ -81,22 +82,23 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() + ac := auth.MustFromContext(ctx) overwrite := r.URL.Query().Get("overwrite") == "true" - // Check if name already exists. - if _, err := h.db.GetTemplate(ctx, req.Name); err == nil { + // Check if name already exists for this team. + if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil { if !overwrite { writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace") return } // Delete existing template record and files. - if err := h.db.DeleteTemplate(ctx, req.Name); err != nil { + if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil { slog.Warn("failed to delete existing template", "name", req.Name, "error", err) } } - // Verify sandbox exists and is running or paused. - sb, err := h.db.GetSandbox(ctx, req.SandboxID) + // Verify sandbox exists, belongs to team, and is running or paused. + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return @@ -134,6 +136,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true}, MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true}, SizeBytes: resp.Msg.SizeBytes, + TeamID: ac.TeamID, }) if err != nil { slog.Error("failed to insert template record", "name", req.Name, "error", err) @@ -147,14 +150,15 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { // List handles GET /v1/snapshots. func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + ac := auth.MustFromContext(ctx) typeFilter := r.URL.Query().Get("type") var templates []db.Template var err error if typeFilter != "" { - templates, err = h.db.ListTemplatesByType(ctx, typeFilter) + templates, err = h.db.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{TeamID: ac.TeamID, Type: typeFilter}) } else { - templates, err = h.db.ListTemplates(ctx) + templates, err = h.db.ListTemplatesByTeam(ctx, ac.TeamID) } if err != nil { writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates") @@ -177,8 +181,9 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) { return } ctx := r.Context() + ac := auth.MustFromContext(ctx) - if _, err := h.db.GetTemplate(ctx, name); err != nil { + if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil { writeError(w, http.StatusNotFound, "not_found", "template not found") return } @@ -190,7 +195,7 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) { slog.Warn("delete snapshot: agent RPC failed", "name", name, "error", err) } - if err := h.db.DeleteTemplate(ctx, name); err != nil { + if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil { writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record") return } diff --git a/internal/api/handlers_test_ui.go b/internal/api/handlers_test_ui.go index 7ac3c89..1161866 100644 --- a/internal/api/handlers_test_ui.go +++ b/internal/api/handlers_test_ui.go @@ -109,13 +109,63 @@ const testUIHTML = ` } .clickable { cursor: pointer; color: #89a785; text-decoration: underline; } .clickable:hover { color: #aacdaa; } + .auth-badge { + display: inline-block; + padding: 2px 8px; + border-radius: 10px; + font-size: 11px; + font-weight: 600; + margin-left: 8px; + } + .auth-badge.authed { background: rgba(94,140,88,0.15); color: #89a785; } + .auth-badge.unauthed { background: rgba(179,85,68,0.15); color: #c27b6d; } + .key-display { + background: #1b201e; + border: 1px solid #5e8c58; + border-radius: 4px; + padding: 8px; + margin-top: 8px; + font-size: 12px; + word-break: break-all; + color: #89a785; + } -

Wrenn Sandbox Test Console

+

Wrenn Sandbox Test Console not authenticated

+ +
+

Authentication

+ + + + +
+ + + +
+
+
+ + +
+

API Keys

+ + +
+ + +
+ +
+ + +
+

Create Sandbox

@@ -189,6 +239,8 @@ const testUIHTML = ` ` diff --git a/internal/api/middleware_apikey.go b/internal/api/middleware_apikey.go new file mode 100644 index 0000000..8a53506 --- /dev/null +++ b/internal/api/middleware_apikey.go @@ -0,0 +1,38 @@ +package api + +import ( + "log/slog" + "net/http" + + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/db" +) + +// requireAPIKey validates the X-API-Key header, looks up the SHA-256 hash in DB, +// and stamps TeamID into the request context. +func requireAPIKey(queries *db.Queries) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := r.Header.Get("X-API-Key") + if key == "" { + writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key header required") + return + } + + hash := auth.HashAPIKey(key) + row, err := queries.GetAPIKeyByHash(r.Context(), hash) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key") + return + } + + // Best-effort update of last_used timestamp. + if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil { + slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err) + } + + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID}) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/internal/api/middleware_jwt.go b/internal/api/middleware_jwt.go new file mode 100644 index 0000000..c071064 --- /dev/null +++ b/internal/api/middleware_jwt.go @@ -0,0 +1,36 @@ +package api + +import ( + "net/http" + "strings" + + "git.omukk.dev/wrenn/sandbox/internal/auth" +) + +// requireJWT validates the Authorization: Bearer header, verifies the JWT +// signature and expiry, and stamps UserID + TeamID + Email into the request context. +func requireJWT(secret []byte) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header := r.Header.Get("Authorization") + if !strings.HasPrefix(header, "Bearer ") { + writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer required") + return + } + + tokenStr := strings.TrimPrefix(header, "Bearer ") + claims, err := auth.VerifyJWT(secret, tokenStr) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token") + return + } + + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ + TeamID: claims.TeamID, + UserID: claims.Subject, + Email: claims.Email, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/internal/api/openapi.yaml b/internal/api/openapi.yaml index b67c693..6ca4ff1 100644 --- a/internal/api/openapi.yaml +++ b/internal/api/openapi.yaml @@ -8,11 +8,133 @@ servers: - url: http://localhost:8080 description: Local development +security: [] + paths: + /v1/auth/signup: + post: + summary: Create a new account + operationId: signup + tags: [auth] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/SignupRequest" + responses: + "201": + description: Account created + content: + application/json: + schema: + $ref: "#/components/schemas/AuthResponse" + "400": + description: Invalid request (bad email, short password) + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Email already registered + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/auth/login: + post: + summary: Log in with email and password + operationId: login + tags: [auth] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/LoginRequest" + responses: + "200": + description: Login successful + content: + application/json: + schema: + $ref: "#/components/schemas/AuthResponse" + "401": + description: Invalid credentials + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/api-keys: + post: + summary: Create an API key + operationId: createAPIKey + tags: [api-keys] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateAPIKeyRequest" + responses: + "201": + description: API key created (plaintext key only shown once) + content: + application/json: + schema: + $ref: "#/components/schemas/APIKeyResponse" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + get: + summary: List API keys for your team + operationId: listAPIKeys + tags: [api-keys] + security: + - bearerAuth: [] + responses: + "200": + description: List of API keys (plaintext keys are never returned) + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/APIKeyResponse" + + /v1/api-keys/{id}: + parameters: + - name: id + in: path + required: true + schema: + type: string + + delete: + summary: Delete an API key + operationId: deleteAPIKey + tags: [api-keys] + security: + - bearerAuth: [] + responses: + "204": + description: API key deleted + /v1/sandboxes: post: summary: Create a sandbox operationId: createSandbox + tags: [sandboxes] + security: + - apiKeyAuth: [] requestBody: required: true content: @@ -34,8 +156,11 @@ paths: $ref: "#/components/schemas/Error" get: - summary: List all sandboxes + summary: List sandboxes for your team operationId: listSandboxes + tags: [sandboxes] + security: + - apiKeyAuth: [] responses: "200": description: List of sandboxes @@ -57,6 +182,9 @@ paths: get: summary: Get sandbox details operationId: getSandbox + tags: [sandboxes] + security: + - apiKeyAuth: [] responses: "200": description: Sandbox details @@ -74,6 +202,9 @@ paths: delete: summary: Destroy a sandbox operationId: destroySandbox + tags: [sandboxes] + security: + - apiKeyAuth: [] responses: "204": description: Sandbox destroyed @@ -89,6 +220,9 @@ paths: post: summary: Execute a command operationId: execCommand + tags: [sandboxes] + security: + - apiKeyAuth: [] requestBody: required: true content: @@ -126,6 +260,9 @@ paths: post: summary: Pause a running sandbox operationId: pauseSandbox + tags: [sandboxes] + security: + - apiKeyAuth: [] description: | Takes a snapshot of the sandbox (VM state + memory + rootfs), then destroys all running resources. The sandbox exists only as files on @@ -155,6 +292,9 @@ paths: post: summary: Resume a paused sandbox operationId: resumeSandbox + tags: [sandboxes] + security: + - apiKeyAuth: [] description: | Restores a paused sandbox from its snapshot using UFFD for lazy memory loading. Boots a fresh Firecracker process, sets up a new @@ -177,6 +317,9 @@ paths: post: summary: Create a snapshot template operationId: createSnapshot + tags: [snapshots] + security: + - apiKeyAuth: [] description: | Pauses a running sandbox, takes a full snapshot, copies the snapshot files to the images directory as a reusable template, then destroys @@ -210,8 +353,11 @@ paths: $ref: "#/components/schemas/Error" get: - summary: List templates + summary: List templates for your team operationId: listSnapshots + tags: [snapshots] + security: + - apiKeyAuth: [] parameters: - name: type in: query @@ -241,6 +387,9 @@ paths: delete: summary: Delete a snapshot template operationId: deleteSnapshot + tags: [snapshots] + security: + - apiKeyAuth: [] description: Removes the snapshot files from disk and deletes the database record. responses: "204": @@ -263,6 +412,9 @@ paths: post: summary: Upload a file operationId: uploadFile + tags: [sandboxes] + security: + - apiKeyAuth: [] requestBody: required: true content: @@ -305,6 +457,9 @@ paths: post: summary: Download a file operationId: downloadFile + tags: [sandboxes] + security: + - apiKeyAuth: [] requestBody: required: true content: @@ -337,6 +492,9 @@ paths: get: summary: Stream command execution via WebSocket operationId: execStream + tags: [sandboxes] + security: + - apiKeyAuth: [] description: | Opens a WebSocket connection for streaming command execution. @@ -387,6 +545,9 @@ paths: post: summary: Upload a file (streaming) operationId: streamUploadFile + tags: [sandboxes] + security: + - apiKeyAuth: [] description: | Streams file content to the sandbox without buffering in memory. Suitable for large files. Uses the same multipart/form-data format @@ -433,6 +594,9 @@ paths: post: summary: Download a file (streaming) operationId: streamDownloadFile + tags: [sandboxes] + security: + - apiKeyAuth: [] description: | Streams file content from the sandbox without buffering in memory. Suitable for large files. Returns raw bytes with chunked transfer encoding. @@ -464,7 +628,85 @@ paths: $ref: "#/components/schemas/Error" components: + securitySchemes: + apiKeyAuth: + type: apiKey + in: header + name: X-API-Key + description: API key for sandbox lifecycle operations. Create via POST /v1/api-keys. + + bearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + description: JWT token from /v1/auth/login or /v1/auth/signup. Valid for 6 hours. + schemas: + SignupRequest: + type: object + required: [email, password] + properties: + email: + type: string + format: email + password: + type: string + minLength: 8 + + LoginRequest: + type: object + required: [email, password] + properties: + email: + type: string + format: email + password: + type: string + + AuthResponse: + type: object + properties: + token: + type: string + description: JWT token (valid for 6 hours) + user_id: + type: string + team_id: + type: string + email: + type: string + + CreateAPIKeyRequest: + type: object + properties: + name: + type: string + default: Unnamed API Key + + APIKeyResponse: + type: object + properties: + id: + type: string + team_id: + type: string + name: + type: string + key_prefix: + type: string + description: Display prefix (e.g. "wrn_ab12cd34...") + created_at: + type: string + format: date-time + last_used: + type: string + format: date-time + nullable: true + key: + type: string + description: Full plaintext key. Only returned on creation, never again. + nullable: true + CreateSandboxRequest: type: object properties: diff --git a/internal/api/server.go b/internal/api/server.go index 78dc26b..286bf3e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgxpool" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -20,7 +21,7 @@ type Server struct { } // New constructs the chi router and registers all routes. -func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *Server { +func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte) *Server { r := chi.NewRouter() r.Use(requestLogger()) @@ -30,6 +31,8 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient) * files := newFilesHandler(queries, agent) filesStream := newFilesStreamHandler(queries, agent) snapshots := newSnapshotHandler(queries, agent) + authH := newAuthHandler(queries, pool, jwtSecret) + apiKeys := newAPIKeyHandler(queries) // OpenAPI spec and docs. r.Get("/openapi.yaml", serveOpenAPI) @@ -38,8 +41,21 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient) * // Test UI for sandbox lifecycle management. r.Get("/test", serveTestUI) - // Sandbox CRUD. + // Unauthenticated auth endpoints. + r.Post("/v1/auth/signup", authH.Signup) + r.Post("/v1/auth/login", authH.Login) + + // JWT-authenticated: API key management. + r.Route("/v1/api-keys", func(r chi.Router) { + r.Use(requireJWT(jwtSecret)) + r.Post("/", apiKeys.Create) + r.Get("/", apiKeys.List) + r.Delete("/{id}", apiKeys.Delete) + }) + + // API-key-authenticated: sandbox lifecycle. r.Route("/v1/sandboxes", func(r chi.Router) { + r.Use(requireAPIKey(queries)) r.Post("/", sandbox.Create) r.Get("/", sandbox.List) @@ -57,8 +73,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient) * }) }) - // Snapshot / template management. + // API-key-authenticated: snapshot / template management. r.Route("/v1/snapshots", func(r chi.Router) { + r.Use(requireAPIKey(queries)) r.Post("/", snapshots.Create) r.Get("/", snapshots.List) r.Delete("/{name}", snapshots.Delete) diff --git a/internal/auth/apikey.go b/internal/auth/apikey.go index 8832b06..7e315ee 100644 --- a/internal/auth/apikey.go +++ b/internal/auth/apikey.go @@ -1 +1,35 @@ package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" +) + +// GenerateAPIKey returns a plaintext key in the form "wrn_" + 32 random hex chars +// and its SHA-256 hash. The caller must show the plaintext to the user exactly once; +// only the hash is stored. +func GenerateAPIKey() (plaintext, hash string, err error) { + b := make([]byte, 16) // 16 bytes → 32 hex chars + if _, err = rand.Read(b); err != nil { + return "", "", fmt.Errorf("generate api key: %w", err) + } + plaintext = "wrn_" + hex.EncodeToString(b) + hash = HashAPIKey(plaintext) + return plaintext, hash, nil +} + +// HashAPIKey returns the hex-encoded SHA-256 hash of a plaintext API key. +func HashAPIKey(plaintext string) string { + sum := sha256.Sum256([]byte(plaintext)) + return hex.EncodeToString(sum[:]) +} + +// APIKeyPrefix returns the displayable prefix of an API key (e.g. "wrn_ab12..."). +func APIKeyPrefix(plaintext string) string { + if len(plaintext) > 12 { + return plaintext[:12] + "..." + } + return plaintext +} diff --git a/internal/auth/context.go b/internal/auth/context.go new file mode 100644 index 0000000..cab29ed --- /dev/null +++ b/internal/auth/context.go @@ -0,0 +1,35 @@ +package auth + +import "context" + +type contextKey int + +const authCtxKey contextKey = 0 + +// AuthContext is stamped into request context by auth middleware. +type AuthContext struct { + TeamID string + UserID string // empty when authenticated via API key + Email string // empty when authenticated via API key +} + +// WithAuthContext returns a new context with the given AuthContext. +func WithAuthContext(ctx context.Context, a AuthContext) context.Context { + return context.WithValue(ctx, authCtxKey, a) +} + +// FromContext retrieves the AuthContext. Returns zero value and false if absent. +func FromContext(ctx context.Context) (AuthContext, bool) { + a, ok := ctx.Value(authCtxKey).(AuthContext) + return a, ok +} + +// MustFromContext retrieves the AuthContext. Panics if absent — only call +// inside handlers behind auth middleware. +func MustFromContext(ctx context.Context) AuthContext { + a, ok := FromContext(ctx) + if !ok { + panic("auth: MustFromContext called on unauthenticated request") + } + return a +} diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go new file mode 100644 index 0000000..4015f2c --- /dev/null +++ b/internal/auth/jwt.go @@ -0,0 +1,51 @@ +package auth + +import ( + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const jwtExpiry = 6 * time.Hour + +// Claims are the JWT payload. +type Claims struct { + TeamID string `json:"team_id"` + Email string `json:"email"` + jwt.RegisteredClaims +} + +// SignJWT signs a new 6-hour JWT for the given user. +func SignJWT(secret []byte, userID, teamID, email string) (string, error) { + now := time.Now() + claims := Claims{ + TeamID: teamID, + Email: email, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(secret) +} + +// VerifyJWT parses and validates a JWT, returning the claims on success. +func VerifyJWT(secret []byte, tokenStr string) (Claims, error) { + token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return secret, nil + }) + if err != nil { + return Claims{}, fmt.Errorf("invalid token: %w", err) + } + c, ok := token.Claims.(*Claims) + if !ok || !token.Valid { + return Claims{}, fmt.Errorf("invalid token claims") + } + return *c, nil +} diff --git a/internal/auth/password.go b/internal/auth/password.go new file mode 100644 index 0000000..0c285a6 --- /dev/null +++ b/internal/auth/password.go @@ -0,0 +1,16 @@ +package auth + +import "golang.org/x/crypto/bcrypt" + +const bcryptCost = 12 + +// HashPassword returns the bcrypt hash of a plaintext password. +func HashPassword(plaintext string) (string, error) { + b, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcryptCost) + return string(b), err +} + +// CheckPassword returns nil if plaintext matches the stored hash. +func CheckPassword(hash, plaintext string) error { + return bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintext)) +} diff --git a/internal/auth/ratelimit.go b/internal/auth/ratelimit.go deleted file mode 100644 index 8832b06..0000000 --- a/internal/auth/ratelimit.go +++ /dev/null @@ -1 +0,0 @@ -package auth diff --git a/internal/config/config.go b/internal/config/config.go index ebf81c1..29a5b08 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,8 @@ package config import ( "os" "strings" + + "github.com/joho/godotenv" ) // Config holds the control plane configuration. @@ -10,14 +12,20 @@ type Config struct { DatabaseURL string ListenAddr string HostAgentAddr string + JWTSecret string } -// Load reads configuration from environment variables. +// Load reads configuration from a .env file (if present) and environment variables. +// Real environment variables take precedence over .env values. func Load() Config { + // Best-effort load — missing .env file is fine. + _ = godotenv.Load() + cfg := Config{ DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"), ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"), HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"), + JWTSecret: os.Getenv("JWT_SECRET"), } // Ensure the host agent address has a scheme. diff --git a/internal/db/api_keys.sql.go b/internal/db/api_keys.sql.go new file mode 100644 index 0000000..5af21ff --- /dev/null +++ b/internal/db/api_keys.sql.go @@ -0,0 +1,124 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: api_keys.sql + +package db + +import ( + "context" +) + +const deleteAPIKey = `-- name: DeleteAPIKey :exec +DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2 +` + +type DeleteAPIKeyParams struct { + ID string `json:"id"` + TeamID string `json:"team_id"` +} + +func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error { + _, err := q.db.Exec(ctx, deleteAPIKey, arg.ID, arg.TeamID) + return err +} + +const getAPIKeyByHash = `-- name: GetAPIKeyByHash :one +SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE key_hash = $1 +` + +func (q *Queries) GetAPIKeyByHash(ctx context.Context, keyHash string) (TeamApiKey, error) { + row := q.db.QueryRow(ctx, getAPIKeyByHash, keyHash) + var i TeamApiKey + err := row.Scan( + &i.ID, + &i.TeamID, + &i.Name, + &i.KeyHash, + &i.KeyPrefix, + &i.CreatedBy, + &i.CreatedAt, + &i.LastUsed, + ) + return i, err +} + +const insertAPIKey = `-- name: InsertAPIKey :one +INSERT INTO team_api_keys (id, team_id, name, key_hash, key_prefix, created_by) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used +` + +type InsertAPIKeyParams struct { + ID string `json:"id"` + TeamID string `json:"team_id"` + Name string `json:"name"` + KeyHash string `json:"key_hash"` + KeyPrefix string `json:"key_prefix"` + CreatedBy string `json:"created_by"` +} + +func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) { + row := q.db.QueryRow(ctx, insertAPIKey, + arg.ID, + arg.TeamID, + arg.Name, + arg.KeyHash, + arg.KeyPrefix, + arg.CreatedBy, + ) + var i TeamApiKey + err := row.Scan( + &i.ID, + &i.TeamID, + &i.Name, + &i.KeyHash, + &i.KeyPrefix, + &i.CreatedBy, + &i.CreatedAt, + &i.LastUsed, + ) + return i, err +} + +const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many +SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC +` + +func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID string) ([]TeamApiKey, error) { + rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TeamApiKey + for rows.Next() { + var i TeamApiKey + if err := rows.Scan( + &i.ID, + &i.TeamID, + &i.Name, + &i.KeyHash, + &i.KeyPrefix, + &i.CreatedBy, + &i.CreatedAt, + &i.LastUsed, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec +UPDATE team_api_keys SET last_used = NOW() WHERE id = $1 +` + +func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id string) error { + _, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id) + return err +} diff --git a/internal/db/models.go b/internal/db/models.go index 0c992e5..fc5bbe8 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -10,7 +10,6 @@ import ( type Sandbox struct { ID string `json:"id"` - OwnerID string `json:"owner_id"` HostID string `json:"host_id"` Template string `json:"template"` Status string `json:"status"` @@ -23,6 +22,24 @@ type Sandbox struct { StartedAt pgtype.Timestamptz `json:"started_at"` LastActiveAt pgtype.Timestamptz `json:"last_active_at"` LastUpdated pgtype.Timestamptz `json:"last_updated"` + TeamID string `json:"team_id"` +} + +type Team struct { + ID string `json:"id"` + Name string `json:"name"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +type TeamApiKey struct { + ID string `json:"id"` + TeamID string `json:"team_id"` + Name string `json:"name"` + KeyHash string `json:"key_hash"` + KeyPrefix string `json:"key_prefix"` + CreatedBy string `json:"created_by"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + LastUsed pgtype.Timestamptz `json:"last_used"` } type Template struct { @@ -32,4 +49,21 @@ type Template struct { MemoryMb pgtype.Int4 `json:"memory_mb"` SizeBytes int64 `json:"size_bytes"` CreatedAt pgtype.Timestamptz `json:"created_at"` + TeamID string `json:"team_id"` +} + +type User struct { + ID string `json:"id"` + Email string `json:"email"` + PasswordHash string `json:"password_hash"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +type UsersTeam struct { + UserID string `json:"user_id"` + TeamID string `json:"team_id"` + IsDefault bool `json:"is_default"` + Role string `json:"role"` + CreatedAt pgtype.Timestamptz `json:"created_at"` } diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go index e11f8f2..577f1d0 100644 --- a/internal/db/sandboxes.sql.go +++ b/internal/db/sandboxes.sql.go @@ -29,7 +29,7 @@ func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatu } const getSandbox = `-- name: GetSandbox :one -SELECT id, owner_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE id = $1 +SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 ` func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) { @@ -37,7 +37,6 @@ func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) { var i Sandbox err := row.Scan( &i.ID, - &i.OwnerID, &i.HostID, &i.Template, &i.Status, @@ -50,19 +49,51 @@ func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) { &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TeamID, + ) + return i, err +} + +const getSandboxByTeam = `-- name: GetSandboxByTeam :one +SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 AND team_id = $2 +` + +type GetSandboxByTeamParams struct { + ID string `json:"id"` + TeamID string `json:"team_id"` +} + +func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) { + row := q.db.QueryRow(ctx, getSandboxByTeam, arg.ID, arg.TeamID) + var i Sandbox + err := row.Scan( + &i.ID, + &i.HostID, + &i.Template, + &i.Status, + &i.Vcpus, + &i.MemoryMb, + &i.TimeoutSec, + &i.GuestIp, + &i.HostIp, + &i.CreatedAt, + &i.StartedAt, + &i.LastActiveAt, + &i.LastUpdated, + &i.TeamID, ) return i, err } const insertSandbox = `-- name: InsertSandbox :one -INSERT INTO sandboxes (id, owner_id, host_id, template, status, vcpus, memory_mb, timeout_sec) +INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) -RETURNING id, owner_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated +RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id ` type InsertSandboxParams struct { ID string `json:"id"` - OwnerID string `json:"owner_id"` + TeamID string `json:"team_id"` HostID string `json:"host_id"` Template string `json:"template"` Status string `json:"status"` @@ -74,7 +105,7 @@ type InsertSandboxParams struct { func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) { row := q.db.QueryRow(ctx, insertSandbox, arg.ID, - arg.OwnerID, + arg.TeamID, arg.HostID, arg.Template, arg.Status, @@ -85,7 +116,6 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S var i Sandbox err := row.Scan( &i.ID, - &i.OwnerID, &i.HostID, &i.Template, &i.Status, @@ -98,12 +128,13 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TeamID, ) return i, err } const listSandboxes = `-- name: ListSandboxes :many -SELECT id, owner_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes ORDER BY created_at DESC +SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes ORDER BY created_at DESC ` func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { @@ -117,7 +148,6 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { var i Sandbox if err := rows.Scan( &i.ID, - &i.OwnerID, &i.HostID, &i.Template, &i.Status, @@ -130,6 +160,7 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TeamID, ); err != nil { return nil, err } @@ -142,7 +173,7 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { } const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many -SELECT id, owner_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes +SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE host_id = $1 AND status = ANY($2::text[]) ORDER BY created_at DESC ` @@ -163,7 +194,6 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand var i Sandbox if err := rows.Scan( &i.ID, - &i.OwnerID, &i.HostID, &i.Template, &i.Status, @@ -176,6 +206,46 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TeamID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many +SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE team_id = $1 ORDER BY created_at DESC +` + +func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) { + rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Sandbox + for rows.Next() { + var i Sandbox + if err := rows.Scan( + &i.ID, + &i.HostID, + &i.Template, + &i.Status, + &i.Vcpus, + &i.MemoryMb, + &i.TimeoutSec, + &i.GuestIp, + &i.HostIp, + &i.CreatedAt, + &i.StartedAt, + &i.LastActiveAt, + &i.LastUpdated, + &i.TeamID, ); err != nil { return nil, err } @@ -213,7 +283,7 @@ SET status = 'running', last_active_at = $4, last_updated = NOW() WHERE id = $1 -RETURNING id, owner_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated +RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id ` type UpdateSandboxRunningParams struct { @@ -233,7 +303,6 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun var i Sandbox err := row.Scan( &i.ID, - &i.OwnerID, &i.HostID, &i.Template, &i.Status, @@ -246,6 +315,7 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TeamID, ) return i, err } @@ -255,7 +325,7 @@ UPDATE sandboxes SET status = $2, last_updated = NOW() WHERE id = $1 -RETURNING id, owner_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated +RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id ` type UpdateSandboxStatusParams struct { @@ -268,7 +338,6 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat var i Sandbox err := row.Scan( &i.ID, - &i.OwnerID, &i.HostID, &i.Template, &i.Status, @@ -281,6 +350,7 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TeamID, ) return i, err } diff --git a/internal/db/teams.sql.go b/internal/db/teams.sql.go new file mode 100644 index 0000000..61d03bb --- /dev/null +++ b/internal/db/teams.sql.go @@ -0,0 +1,75 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: teams.sql + +package db + +import ( + "context" +) + +const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one +SELECT t.id, t.name, t.created_at FROM teams t +JOIN users_teams ut ON ut.team_id = t.id +WHERE ut.user_id = $1 AND ut.is_default = TRUE +LIMIT 1 +` + +func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID string) (Team, error) { + row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID) + var i Team + err := row.Scan(&i.ID, &i.Name, &i.CreatedAt) + return i, err +} + +const getTeam = `-- name: GetTeam :one +SELECT id, name, created_at FROM teams WHERE id = $1 +` + +func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) { + row := q.db.QueryRow(ctx, getTeam, id) + var i Team + err := row.Scan(&i.ID, &i.Name, &i.CreatedAt) + return i, err +} + +const insertTeam = `-- name: InsertTeam :one +INSERT INTO teams (id, name) +VALUES ($1, $2) +RETURNING id, name, created_at +` + +type InsertTeamParams struct { + ID string `json:"id"` + Name string `json:"name"` +} + +func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) { + row := q.db.QueryRow(ctx, insertTeam, arg.ID, arg.Name) + var i Team + err := row.Scan(&i.ID, &i.Name, &i.CreatedAt) + return i, err +} + +const insertTeamMember = `-- name: InsertTeamMember :exec +INSERT INTO users_teams (user_id, team_id, is_default, role) +VALUES ($1, $2, $3, $4) +` + +type InsertTeamMemberParams struct { + UserID string `json:"user_id"` + TeamID string `json:"team_id"` + IsDefault bool `json:"is_default"` + Role string `json:"role"` +} + +func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error { + _, err := q.db.Exec(ctx, insertTeamMember, + arg.UserID, + arg.TeamID, + arg.IsDefault, + arg.Role, + ) + return err +} diff --git a/internal/db/templates.sql.go b/internal/db/templates.sql.go index 6e1653e..cafae69 100644 --- a/internal/db/templates.sql.go +++ b/internal/db/templates.sql.go @@ -20,8 +20,22 @@ func (q *Queries) DeleteTemplate(ctx context.Context, name string) error { return err } +const deleteTemplateByTeam = `-- name: DeleteTemplateByTeam :exec +DELETE FROM templates WHERE name = $1 AND team_id = $2 +` + +type DeleteTemplateByTeamParams struct { + Name string `json:"name"` + TeamID string `json:"team_id"` +} + +func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error { + _, err := q.db.Exec(ctx, deleteTemplateByTeam, arg.Name, arg.TeamID) + return err +} + const getTemplate = `-- name: GetTemplate :one -SELECT name, type, vcpus, memory_mb, size_bytes, created_at FROM templates WHERE name = $1 +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 ` func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error) { @@ -34,14 +48,39 @@ func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error &i.MemoryMb, &i.SizeBytes, &i.CreatedAt, + &i.TeamID, + ) + return i, err +} + +const getTemplateByTeam = `-- name: GetTemplateByTeam :one +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND team_id = $2 +` + +type GetTemplateByTeamParams struct { + Name string `json:"name"` + TeamID string `json:"team_id"` +} + +func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) { + row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID) + var i Template + err := row.Scan( + &i.Name, + &i.Type, + &i.Vcpus, + &i.MemoryMb, + &i.SizeBytes, + &i.CreatedAt, + &i.TeamID, ) return i, err } const insertTemplate = `-- name: InsertTemplate :one -INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes) -VALUES ($1, $2, $3, $4, $5) -RETURNING name, type, vcpus, memory_mb, size_bytes, created_at +INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id ` type InsertTemplateParams struct { @@ -50,6 +89,7 @@ type InsertTemplateParams struct { Vcpus pgtype.Int4 `json:"vcpus"` MemoryMb pgtype.Int4 `json:"memory_mb"` SizeBytes int64 `json:"size_bytes"` + TeamID string `json:"team_id"` } func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) { @@ -59,6 +99,7 @@ func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) arg.Vcpus, arg.MemoryMb, arg.SizeBytes, + arg.TeamID, ) var i Template err := row.Scan( @@ -68,12 +109,13 @@ func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) &i.MemoryMb, &i.SizeBytes, &i.CreatedAt, + &i.TeamID, ) return i, err } const listTemplates = `-- name: ListTemplates :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at FROM templates ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates ORDER BY created_at DESC ` func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { @@ -92,6 +134,76 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { &i.MemoryMb, &i.SizeBytes, &i.CreatedAt, + &i.TeamID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 ORDER BY created_at DESC +` + +func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Template, error) { + rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Template + for rows.Next() { + var i Template + if err := rows.Scan( + &i.Name, + &i.Type, + &i.Vcpus, + &i.MemoryMb, + &i.SizeBytes, + &i.CreatedAt, + &i.TeamID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC +` + +type ListTemplatesByTeamAndTypeParams struct { + TeamID string `json:"team_id"` + Type string `json:"type"` +} + +func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) { + rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Template + for rows.Next() { + var i Template + if err := rows.Scan( + &i.Name, + &i.Type, + &i.Vcpus, + &i.MemoryMb, + &i.SizeBytes, + &i.CreatedAt, + &i.TeamID, ); err != nil { return nil, err } @@ -104,7 +216,7 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { } const listTemplatesByType = `-- name: ListTemplatesByType :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at FROM templates WHERE type = $1 ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE type = $1 ORDER BY created_at DESC ` func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) { @@ -123,6 +235,7 @@ func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Temp &i.MemoryMb, &i.SizeBytes, &i.CreatedAt, + &i.TeamID, ); err != nil { return nil, err } diff --git a/internal/db/users.sql.go b/internal/db/users.sql.go new file mode 100644 index 0000000..d6277a1 --- /dev/null +++ b/internal/db/users.sql.go @@ -0,0 +1,69 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: users.sql + +package db + +import ( + "context" +) + +const getUserByEmail = `-- name: GetUserByEmail :one +SELECT id, email, password_hash, created_at, updated_at FROM users WHERE email = $1 +` + +func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) { + row := q.db.QueryRow(ctx, getUserByEmail, email) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.PasswordHash, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getUserByID = `-- name: GetUserByID :one +SELECT id, email, password_hash, created_at, updated_at FROM users WHERE id = $1 +` + +func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) { + row := q.db.QueryRow(ctx, getUserByID, id) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.PasswordHash, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const insertUser = `-- name: InsertUser :one +INSERT INTO users (id, email, password_hash) +VALUES ($1, $2, $3) +RETURNING id, email, password_hash, created_at, updated_at +` + +type InsertUserParams struct { + ID string `json:"id"` + Email string `json:"email"` + PasswordHash string `json:"password_hash"` +} + +func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) { + row := q.db.QueryRow(ctx, insertUser, arg.ID, arg.Email, arg.PasswordHash) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.PasswordHash, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/id/id.go b/internal/id/id.go index 5b561a5..eedf5f4 100644 --- a/internal/id/id.go +++ b/internal/id/id.go @@ -6,20 +6,35 @@ import ( "fmt" ) -// NewSandboxID generates a new sandbox ID in the format "sb-" + 8 hex chars. -func NewSandboxID() string { +func hex8() string { b := make([]byte, 4) if _, err := rand.Read(b); err != nil { panic(fmt.Sprintf("crypto/rand failed: %v", err)) } - return "sb-" + hex.EncodeToString(b) + return hex.EncodeToString(b) +} + +// NewSandboxID generates a new sandbox ID in the format "sb-" + 8 hex chars. +func NewSandboxID() string { + return "sb-" + hex8() } // NewSnapshotName generates a snapshot name in the format "template-" + 8 hex chars. func NewSnapshotName() string { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(fmt.Sprintf("crypto/rand failed: %v", err)) - } - return "template-" + hex.EncodeToString(b) + return "template-" + hex8() +} + +// NewUserID generates a new user ID in the format "usr-" + 8 hex chars. +func NewUserID() string { + return "usr-" + hex8() +} + +// NewTeamID generates a new team ID in the format "team-" + 8 hex chars. +func NewTeamID() string { + return "team-" + hex8() +} + +// NewAPIKeyID generates a new API key ID in the format "key-" + 8 hex chars. +func NewAPIKeyID() string { + return "key-" + hex8() }