forked from wrenn/wrenn
v0.0.1 (#8)
Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com> Reviewed-on: wrenn/sandbox#8
This commit is contained in:
22
internal/api/agent_helper.go
Normal file
22
internal/api/agent_helper.go
Normal file
@ -0,0 +1,22 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// agentForHost looks up the host record and returns a Connect RPC client for it.
|
||||
// Returns an error if the host is not found or has no address.
|
||||
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) {
|
||||
host, err := queries.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
return pool.GetForHost(host)
|
||||
}
|
||||
230
internal/api/handler_sandbox_proxy.go
Normal file
230
internal/api/handler_sandbox_proxy.go
Normal file
@ -0,0 +1,230 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"path"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by proxyTarget, used to map to HTTP status codes
|
||||
// without relying on error message text.
|
||||
var (
|
||||
errProxySandboxNotFound = errors.New("sandbox not found")
|
||||
errProxyNoHostAddress = errors.New("host agent has no address")
|
||||
)
|
||||
|
||||
const proxyCacheTTL = 120 * time.Second
|
||||
|
||||
// sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or
|
||||
// "49999-cl-abcd1234.example.com". Captures: port, sandbox ID.
|
||||
var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`)
|
||||
|
||||
// errProxySandboxNotRunning carries the sandbox status so callers can include
|
||||
// it in the HTTP response without parsing error strings.
|
||||
type errProxySandboxNotRunning struct{ status string }
|
||||
|
||||
func (e errProxySandboxNotRunning) Error() string {
|
||||
return fmt.Sprintf("sandbox is not running (status: %s)", e.status)
|
||||
}
|
||||
|
||||
// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair.
|
||||
// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure
|
||||
// can capture the correct port without the cache key needing to include it.
|
||||
type proxyCacheEntry struct {
|
||||
agentURL *url.URL
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation.
|
||||
type proxyCacheKey [32]byte
|
||||
|
||||
func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey {
|
||||
var k proxyCacheKey
|
||||
copy(k[:16], sandboxID.Bytes[:])
|
||||
copy(k[16:], teamID.Bytes[:])
|
||||
return k
|
||||
}
|
||||
|
||||
// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
|
||||
// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
|
||||
// requests are reverse-proxied through the host agent that owns the sandbox.
|
||||
// All other requests are passed through to the inner handler.
|
||||
//
|
||||
// Authentication is via X-API-Key header only (no JWT). The API key's team
|
||||
// must own the sandbox.
|
||||
type SandboxProxyWrapper struct {
|
||||
inner http.Handler
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
transport http.RoundTripper
|
||||
|
||||
cacheMu sync.Mutex
|
||||
cache map[proxyCacheKey]proxyCacheEntry
|
||||
}
|
||||
|
||||
// NewSandboxProxyWrapper creates a new proxy wrapper.
|
||||
func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifecycle.HostClientPool) *SandboxProxyWrapper {
|
||||
return &SandboxProxyWrapper{
|
||||
inner: inner,
|
||||
db: queries,
|
||||
pool: pool,
|
||||
transport: pool.Transport(),
|
||||
cache: make(map[proxyCacheKey]proxyCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// proxyTarget looks up the cached agent URL for (sandboxID, teamID).
|
||||
// On a miss it queries the DB, resolves the address, and populates the cache.
|
||||
// The *httputil.ReverseProxy is built by the caller so the Director closure
|
||||
// captures the correct port without the cache key needing to include it.
|
||||
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) {
|
||||
cacheKey := makeProxyCacheKey(sandboxID, teamID)
|
||||
|
||||
h.cacheMu.Lock()
|
||||
entry, ok := h.cache[cacheKey]
|
||||
h.cacheMu.Unlock()
|
||||
|
||||
if ok && time.Now().Before(entry.expiresAt) {
|
||||
return entry.agentURL, nil
|
||||
}
|
||||
|
||||
// Cache miss or expired — query DB.
|
||||
target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{
|
||||
ID: sandboxID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errProxySandboxNotFound
|
||||
}
|
||||
if target.Status != "running" {
|
||||
return nil, errProxySandboxNotRunning{status: target.Status}
|
||||
}
|
||||
if target.HostAddress == "" {
|
||||
return nil, errProxyNoHostAddress
|
||||
}
|
||||
|
||||
agentURL, err := url.Parse(h.pool.ResolveAddr(target.HostAddress))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid host agent address: %w", err)
|
||||
}
|
||||
|
||||
h.cacheMu.Lock()
|
||||
h.cache[cacheKey] = proxyCacheEntry{
|
||||
agentURL: agentURL,
|
||||
expiresAt: time.Now().Add(proxyCacheTTL),
|
||||
}
|
||||
h.cacheMu.Unlock()
|
||||
|
||||
return agentURL, nil
|
||||
}
|
||||
|
||||
// evictProxyCache removes the cached entry for a (sandbox, team) pair.
|
||||
// Called on 502 so a stopped/moved sandbox is re-resolved on the next request.
|
||||
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) {
|
||||
h.cacheMu.Lock()
|
||||
delete(h.cache, makeProxyCacheKey(sandboxID, teamID))
|
||||
h.cacheMu.Unlock()
|
||||
}
|
||||
|
||||
func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.Host
|
||||
// Strip port from Host header (e.g. "49999-cl-abcd1234.localhost:8000" → "49999-cl-abcd1234.localhost")
|
||||
if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 {
|
||||
host = host[:colonIdx]
|
||||
}
|
||||
|
||||
matches := sandboxHostPattern.FindStringSubmatch(host)
|
||||
if matches == nil {
|
||||
h.inner.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
port := matches[1]
|
||||
sandboxIDStr := matches[2]
|
||||
|
||||
// Validate port.
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil || portNum < 1 || portNum > 65535 {
|
||||
http.Error(w, "invalid port", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate: require API key or JWT, extract team ID.
|
||||
teamID, err := h.authenticateRequest(r)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errProxySandboxNotFound):
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
case errors.As(err, new(errProxySandboxNotRunning)):
|
||||
http.Error(w, err.Error(), http.StatusConflict)
|
||||
default:
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Transport: h.transport,
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = agentURL.Scheme
|
||||
req.URL.Host = agentURL.Host
|
||||
req.URL.Path = path.Join("/proxy", sandboxIDStr, port, path.Clean("/"+req.URL.Path))
|
||||
req.Host = agentURL.Host
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
slog.Debug("sandbox proxy error",
|
||||
"sandbox_id", sandboxIDStr,
|
||||
"port", port,
|
||||
"error", err,
|
||||
)
|
||||
h.evictProxyCache(sandboxID, teamID)
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// authenticateRequest validates the request's API key and returns the team ID.
|
||||
// Only API key authentication is supported for sandbox proxy requests (not JWT).
|
||||
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("X-API-Key header required")
|
||||
}
|
||||
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := h.db.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid API key")
|
||||
}
|
||||
return row.TeamID, nil
|
||||
}
|
||||
@ -6,17 +6,20 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type apiKeyHandler struct {
|
||||
svc *service.APIKeyService
|
||||
svc *service.APIKeyService
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newAPIKeyHandler(svc *service.APIKeyService) *apiKeyHandler {
|
||||
return &apiKeyHandler{svc: svc}
|
||||
func newAPIKeyHandler(svc *service.APIKeyService, al *audit.AuditLogger) *apiKeyHandler {
|
||||
return &apiKeyHandler{svc: svc, audit: al}
|
||||
}
|
||||
|
||||
type createAPIKeyRequest struct {
|
||||
@ -37,11 +40,11 @@ type apiKeyResponse struct {
|
||||
|
||||
func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
|
||||
resp := apiKeyResponse{
|
||||
ID: k.ID,
|
||||
TeamID: k.TeamID,
|
||||
ID: id.FormatAPIKeyID(k.ID),
|
||||
TeamID: id.FormatTeamID(k.TeamID),
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: k.CreatedBy,
|
||||
CreatedBy: id.FormatUserID(k.CreatedBy),
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
|
||||
@ -55,11 +58,11 @@ func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
|
||||
|
||||
func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse {
|
||||
resp := apiKeyResponse{
|
||||
ID: k.ID,
|
||||
TeamID: k.TeamID,
|
||||
ID: id.FormatAPIKeyID(k.ID),
|
||||
TeamID: id.FormatTeamID(k.TeamID),
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: k.CreatedBy,
|
||||
CreatedBy: id.FormatUserID(k.CreatedBy),
|
||||
CreatorEmail: k.CreatorEmail,
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
@ -91,6 +94,7 @@ func (h *apiKeyHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
resp := apiKeyToResponse(result.Row)
|
||||
resp.Key = &result.Plaintext
|
||||
|
||||
h.audit.LogAPIKeyCreate(r.Context(), ac, result.Row.ID, result.Row.Name)
|
||||
writeJSON(w, http.StatusCreated, resp)
|
||||
}
|
||||
|
||||
@ -115,12 +119,19 @@ func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
// Delete handles DELETE /v1/api-keys/{id}.
|
||||
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
keyID := chi.URLParam(r, "id")
|
||||
keyIDStr := chi.URLParam(r, "id")
|
||||
|
||||
keyID, err := id.ParseAPIKeyID(keyIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid API key ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogAPIKeyRevoke(r.Context(), ac, keyID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
148
internal/api/handlers_audit.go
Normal file
148
internal/api/handlers_audit.go
Normal file
@ -0,0 +1,148 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type auditHandler struct {
|
||||
svc *service.AuditService
|
||||
}
|
||||
|
||||
func newAuditHandler(svc *service.AuditService) *auditHandler {
|
||||
return &auditHandler{svc: svc}
|
||||
}
|
||||
|
||||
type auditLogResponse struct {
|
||||
ID string `json:"id"`
|
||||
ActorType string `json:"actor_type"`
|
||||
ActorID string `json:"actor_id,omitempty"`
|
||||
ActorName string `json:"actor_name,omitempty"`
|
||||
ResourceType string `json:"resource_type"`
|
||||
ResourceID string `json:"resource_id,omitempty"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
Status string `json:"status"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// List handles GET /v1/audit-logs.
|
||||
// Query params:
|
||||
// - before: RFC3339 timestamp cursor (exclusive); omit to start from latest
|
||||
// - limit: page size, default 50, max 200
|
||||
// - resource_type: filter by resource type (sandbox, snapshot, team, api_key, member, host)
|
||||
// - action: filter by action verb
|
||||
//
|
||||
// Members see only team-scoped events; admins/owners see all.
|
||||
func (h *auditHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
// Parse ?before cursor.
|
||||
var before time.Time
|
||||
if s := r.URL.Query().Get("before"); s != "" {
|
||||
var err error
|
||||
before, err = time.Parse(time.RFC3339, s)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "before must be an RFC3339 timestamp")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Parse ?limit.
|
||||
limit := 50
|
||||
if s := r.URL.Query().Get("limit"); s != "" {
|
||||
n, err := strconv.Atoi(s)
|
||||
if err != nil || n < 1 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "limit must be a positive integer")
|
||||
return
|
||||
}
|
||||
limit = n
|
||||
}
|
||||
|
||||
// Parse ?before_id cursor (UUID).
|
||||
var beforeID pgtype.UUID
|
||||
if s := r.URL.Query().Get("before_id"); s != "" {
|
||||
parsed, err := id.ParseAuditLogID(s)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "before_id must be a valid audit log ID")
|
||||
return
|
||||
}
|
||||
beforeID = parsed
|
||||
}
|
||||
|
||||
entries, err := h.svc.List(r.Context(), service.AuditListParams{
|
||||
TeamID: ac.TeamID,
|
||||
AdminScoped: ac.Role == "owner" || ac.Role == "admin",
|
||||
ResourceTypes: parseMultiParam(r.URL.Query()["resource_type"]),
|
||||
Actions: parseMultiParam(r.URL.Query()["action"]),
|
||||
Before: before,
|
||||
BeforeID: beforeID,
|
||||
Limit: limit,
|
||||
})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list audit logs")
|
||||
return
|
||||
}
|
||||
|
||||
items := make([]auditLogResponse, len(entries))
|
||||
for i, e := range entries {
|
||||
items[i] = auditLogResponse{
|
||||
ID: e.ID,
|
||||
ActorType: e.ActorType,
|
||||
ActorID: e.ActorID,
|
||||
ActorName: e.ActorName,
|
||||
ResourceType: e.ResourceType,
|
||||
ResourceID: e.ResourceID,
|
||||
Action: e.Action,
|
||||
Scope: e.Scope,
|
||||
Status: e.Status,
|
||||
Metadata: e.Metadata,
|
||||
CreatedAt: e.CreatedAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
resp := map[string]any{"items": items}
|
||||
if len(items) > 0 {
|
||||
last := entries[len(entries)-1]
|
||||
resp["next_before"] = last.CreatedAt.UTC().Format(time.RFC3339)
|
||||
resp["next_before_id"] = last.ID
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// parseMultiParam flattens repeated params and comma-separated values into a
|
||||
// single deduplicated slice. Empty strings are dropped. Returns nil (no filter)
|
||||
// when no values are present.
|
||||
//
|
||||
// Both ?resource_type=sandbox&resource_type=snapshot
|
||||
// and ?resource_type=sandbox,snapshot are accepted.
|
||||
func parseMultiParam(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
var out []string
|
||||
for _, v := range values {
|
||||
for _, part := range strings.Split(v, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[part]; !ok {
|
||||
seen[part] = struct{}{}
|
||||
out = append(out, part)
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@ -1,7 +1,9 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@ -15,6 +17,45 @@ import (
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// loginTeam returns the team and role to stamp into a login JWT.
|
||||
// It prefers the user's default team; if none is flagged as default it falls
|
||||
// back to the earliest-joined team. Returns pgx.ErrNoRows when the user has
|
||||
// no team memberships at all.
|
||||
func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team, string, error) {
|
||||
team, err := q.GetDefaultTeamForUser(ctx, userID)
|
||||
if err == nil {
|
||||
membership, err := q.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: team.ID})
|
||||
if err != nil {
|
||||
return db.Team{}, "", err
|
||||
}
|
||||
return team, membership.Role, nil
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Team{}, "", err
|
||||
}
|
||||
// No default set — fall back to earliest-joined team.
|
||||
rows, err := q.GetTeamsForUser(ctx, userID)
|
||||
if err != nil {
|
||||
return db.Team{}, "", err
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return db.Team{}, "", pgx.ErrNoRows
|
||||
}
|
||||
first := rows[0]
|
||||
return db.Team{
|
||||
ID: first.ID,
|
||||
Name: first.Name,
|
||||
Slug: first.Slug,
|
||||
IsByoc: first.IsByoc,
|
||||
CreatedAt: first.CreatedAt,
|
||||
DeletedAt: first.DeletedAt,
|
||||
}, first.Role, nil
|
||||
}
|
||||
|
||||
type switchTeamRequest struct {
|
||||
TeamID string `json:"team_id"`
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
@ -28,6 +69,7 @@ func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authH
|
||||
type signupRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
@ -40,6 +82,7 @@ type authResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Signup handles POST /v1/auth/signup.
|
||||
@ -51,6 +94,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
||||
req.Name = strings.TrimSpace(req.Name)
|
||||
if !strings.Contains(req.Email, "@") || len(req.Email) < 3 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid email address")
|
||||
return
|
||||
@ -59,6 +103,10 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
if req.Name == "" || len(req.Name) > 100 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "name must be between 1 and 100 characters")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
@ -83,6 +131,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
ID: userID,
|
||||
Email: req.Email,
|
||||
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
|
||||
Name: req.Name,
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
@ -98,7 +147,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
teamID := id.NewTeamID()
|
||||
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: req.Email + "'s Team",
|
||||
Name: req.Name + "'s Team",
|
||||
Slug: id.NewTeamSlug(),
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to create team")
|
||||
return
|
||||
@ -119,7 +169,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner", false)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
@ -127,9 +177,10 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
writeJSON(w, http.StatusCreated, authResponse{
|
||||
Token: token,
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
UserID: id.FormatUserID(userID),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
})
|
||||
}
|
||||
|
||||
@ -152,6 +203,7 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := h.db.GetUserByEmail(ctx, req.Email)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("login failed: unknown email", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
@ -160,21 +212,27 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if !user.PasswordHash.Valid {
|
||||
slog.Warn("login failed: no password set", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
if err := auth.CheckPassword(user.PasswordHash.String, req.Password); err != nil {
|
||||
slog.Warn("login failed: wrong password", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusForbidden, "no_team", "user is not a member of any team")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
@ -182,8 +240,85 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: user.ID,
|
||||
TeamID: team.ID,
|
||||
UserID: id.FormatUserID(user.ID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// SwitchTeam handles POST /v1/auth/switch-team.
|
||||
// Verifies from DB that the user is a member of the target team, then re-issues
|
||||
// a JWT scoped to that team. The JWT's team_id is used as a pre-filter on all
|
||||
// subsequent team-scoped requests; DB is the source of truth for actual permissions.
|
||||
func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
var req switchTeamRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.TeamID == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "team_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(req.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Verify team exists and is not deleted.
|
||||
team, err := h.db.GetTeam(ctx, teamID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusNotFound, "not_found", "team not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
|
||||
return
|
||||
}
|
||||
if team.DeletedAt.Valid {
|
||||
writeError(w, http.StatusNotFound, "not_found", "team not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify membership from DB — JWT role is not trusted here.
|
||||
membership, err := h.db.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: ac.UserID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "not a member of this team")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up membership")
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch current name from DB — JWT name is not trusted here (may be stale or empty for old tokens).
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Email: ac.Email,
|
||||
Name: user.Name,
|
||||
})
|
||||
}
|
||||
|
||||
276
internal/api/handlers_builds.go
Normal file
276
internal/api/handlers_builds.go
Normal file
@ -0,0 +1,276 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/layout"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type buildHandler struct {
|
||||
svc *service.BuildService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newBuildHandler(svc *service.BuildService, db *db.Queries, pool *lifecycle.HostClientPool) *buildHandler {
|
||||
return &buildHandler{svc: svc, db: db, pool: pool}
|
||||
}
|
||||
|
||||
type createBuildRequest struct {
|
||||
Name string `json:"name"`
|
||||
BaseTemplate string `json:"base_template"`
|
||||
Recipe []string `json:"recipe"`
|
||||
Healthcheck string `json:"healthcheck"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
SkipPrePost bool `json:"skip_pre_post"`
|
||||
}
|
||||
|
||||
type buildResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseTemplate string `json:"base_template"`
|
||||
Recipe json.RawMessage `json:"recipe"`
|
||||
Healthcheck *string `json:"healthcheck,omitempty"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
Status string `json:"status"`
|
||||
CurrentStep int32 `json:"current_step"`
|
||||
TotalSteps int32 `json:"total_steps"`
|
||||
Logs json.RawMessage `json:"logs"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
SandboxID *string `json:"sandbox_id,omitempty"`
|
||||
HostID *string `json:"host_id,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
}
|
||||
|
||||
func buildToResponse(b db.TemplateBuild) buildResponse {
|
||||
resp := buildResponse{
|
||||
ID: id.FormatBuildID(b.ID),
|
||||
Name: b.Name,
|
||||
BaseTemplate: b.BaseTemplate,
|
||||
Recipe: b.Recipe,
|
||||
VCPUs: b.Vcpus,
|
||||
MemoryMB: b.MemoryMb,
|
||||
Status: b.Status,
|
||||
CurrentStep: b.CurrentStep,
|
||||
TotalSteps: b.TotalSteps,
|
||||
Logs: b.Logs,
|
||||
}
|
||||
if b.Healthcheck != "" {
|
||||
resp.Healthcheck = &b.Healthcheck
|
||||
}
|
||||
if b.Error != "" {
|
||||
resp.Error = &b.Error
|
||||
}
|
||||
if b.SandboxID.Valid {
|
||||
s := id.FormatSandboxID(b.SandboxID)
|
||||
resp.SandboxID = &s
|
||||
}
|
||||
if b.HostID.Valid {
|
||||
s := id.FormatHostID(b.HostID)
|
||||
resp.HostID = &s
|
||||
}
|
||||
if b.CreatedAt.Valid {
|
||||
resp.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if b.StartedAt.Valid {
|
||||
s := b.StartedAt.Time.Format(time.RFC3339)
|
||||
resp.StartedAt = &s
|
||||
}
|
||||
if b.CompletedAt.Valid {
|
||||
s := b.CompletedAt.Time.Format(time.RFC3339)
|
||||
resp.CompletedAt = &s
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/admin/builds.
|
||||
func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createBuildRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "name is required")
|
||||
return
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
|
||||
return
|
||||
}
|
||||
if len(req.Recipe) == 0 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "recipe must contain at least one command")
|
||||
return
|
||||
}
|
||||
|
||||
build, err := h.svc.Create(r.Context(), service.BuildCreateParams{
|
||||
Name: req.Name,
|
||||
BaseTemplate: req.BaseTemplate,
|
||||
Recipe: req.Recipe,
|
||||
Healthcheck: req.Healthcheck,
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
SkipPrePost: req.SkipPrePost,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to create build", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "build_error", "failed to create build")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, buildToResponse(build))
|
||||
}
|
||||
|
||||
// List handles GET /v1/admin/builds.
|
||||
func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
builds, err := h.svc.List(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list builds")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]buildResponse, len(builds))
|
||||
for i, b := range builds {
|
||||
resp[i] = buildToResponse(b)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/admin/builds/{id}.
|
||||
func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
buildIDStr := chi.URLParam(r, "id")
|
||||
|
||||
buildID, err := id.ParseBuildID(buildIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
|
||||
return
|
||||
}
|
||||
|
||||
build, err := h.svc.Get(r.Context(), buildID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "build not found")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, buildToResponse(build))
|
||||
}
|
||||
|
||||
// ListTemplates handles GET /v1/admin/templates — returns all templates across all teams.
|
||||
func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
templates, err := h.db.ListTemplates(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates")
|
||||
return
|
||||
}
|
||||
|
||||
type templateResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TeamID string `json:"team_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
resp := make([]templateResponse, len(templates))
|
||||
for i, t := range templates {
|
||||
resp[i] = templateResponse{
|
||||
Name: t.Name,
|
||||
Type: t.Type,
|
||||
VCPUs: t.Vcpus,
|
||||
MemoryMB: t.MemoryMb,
|
||||
SizeBytes: t.SizeBytes,
|
||||
TeamID: id.FormatTeamID(t.TeamID),
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// DeleteTemplate handles DELETE /v1/admin/templates/{name}.
|
||||
func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
name := chi.URLParam(r, "name")
|
||||
if err := validate.SafeName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
|
||||
tmpl, err := h.db.GetPlatformTemplateByName(ctx, name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "template not found")
|
||||
return
|
||||
}
|
||||
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
|
||||
return
|
||||
}
|
||||
|
||||
// Broadcast delete to all online hosts.
|
||||
hosts, _ := h.db.ListActiveHosts(ctx)
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
}
|
||||
agent, err := h.pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: formatUUIDForRPC(tmpl.TeamID),
|
||||
TemplateId: formatUUIDForRPC(tmpl.ID),
|
||||
})); err != nil {
|
||||
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Cancel handles POST /v1/admin/builds/{id}/cancel.
|
||||
func (h *buildHandler) Cancel(w http.ResponseWriter, r *http.Request) {
|
||||
buildIDStr := chi.URLParam(r, "id")
|
||||
|
||||
buildID, err := id.ParseBuildID(buildIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Cancel(r.Context(), buildID); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
242
internal/api/handlers_channels.go
Normal file
242
internal/api/handlers_channels.go
Normal file
@ -0,0 +1,242 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/channels"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
type channelHandler struct {
|
||||
svc *channels.Service
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newChannelHandler(svc *channels.Service, al *audit.AuditLogger) *channelHandler {
|
||||
return &channelHandler{svc: svc, audit: al}
|
||||
}
|
||||
|
||||
type createChannelRequest struct {
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config map[string]string `json:"config"`
|
||||
Events []string `json:"events"`
|
||||
}
|
||||
|
||||
type updateChannelRequest struct {
|
||||
Name string `json:"name"`
|
||||
Events []string `json:"events"`
|
||||
}
|
||||
|
||||
type rotateConfigRequest struct {
|
||||
Config map[string]string `json:"config"`
|
||||
}
|
||||
|
||||
type testChannelRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
Config map[string]string `json:"config"`
|
||||
}
|
||||
|
||||
type channelResponse struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Events []string `json:"events"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
Secret *string `json:"secret,omitempty"`
|
||||
}
|
||||
|
||||
func channelToResponse(ch db.Channel) channelResponse {
|
||||
resp := channelResponse{
|
||||
ID: id.FormatChannelID(ch.ID),
|
||||
TeamID: id.FormatTeamID(ch.TeamID),
|
||||
Name: ch.Name,
|
||||
Provider: ch.Provider,
|
||||
Events: ch.EventTypes,
|
||||
}
|
||||
if ch.CreatedAt.Valid {
|
||||
resp.CreatedAt = ch.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if ch.UpdatedAt.Valid {
|
||||
resp.UpdatedAt = ch.UpdatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/channels.
|
||||
func (h *channelHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
var req createChannelRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.Create(r.Context(), channels.CreateParams{
|
||||
TeamID: ac.TeamID,
|
||||
Name: req.Name,
|
||||
Provider: req.Provider,
|
||||
Config: req.Config,
|
||||
Events: req.Events,
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogChannelCreate(r.Context(), ac, result.Channel.ID, result.Channel.Name, result.Channel.Provider)
|
||||
|
||||
resp := channelToResponse(result.Channel)
|
||||
if result.PlaintextSecret != "" {
|
||||
resp.Secret = &result.PlaintextSecret
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, resp)
|
||||
}
|
||||
|
||||
// List handles GET /v1/channels.
|
||||
func (h *channelHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
chs, err := h.svc.List(r.Context(), ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list channels")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]channelResponse, len(chs))
|
||||
for i, ch := range chs {
|
||||
resp[i] = channelToResponse(ch)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/channels/{id}.
|
||||
func (h *channelHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
channelIDStr := chi.URLParam(r, "id")
|
||||
|
||||
channelID, err := id.ParseChannelID(channelIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
ch, err := h.svc.Get(r.Context(), channelID, ac.TeamID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusNotFound, "not_found", "channel not found")
|
||||
} else {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get channel")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, channelToResponse(ch))
|
||||
}
|
||||
|
||||
// Update handles PATCH /v1/channels/{id}.
|
||||
func (h *channelHandler) Update(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
channelIDStr := chi.URLParam(r, "id")
|
||||
|
||||
channelID, err := id.ParseChannelID(channelIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req updateChannelRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ch, err := h.svc.Update(r.Context(), channelID, ac.TeamID, req.Name, req.Events)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogChannelUpdate(r.Context(), ac, channelID)
|
||||
writeJSON(w, http.StatusOK, channelToResponse(ch))
|
||||
}
|
||||
|
||||
// Test handles POST /v1/channels/test.
|
||||
func (h *channelHandler) Test(w http.ResponseWriter, r *http.Request) {
|
||||
var req testChannelRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Test(r.Context(), req.Provider, req.Config); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// RotateConfig handles PUT /v1/channels/{id}/config.
|
||||
func (h *channelHandler) RotateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
channelIDStr := chi.URLParam(r, "id")
|
||||
|
||||
channelID, err := id.ParseChannelID(channelIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req rotateConfigRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ch, err := h.svc.RotateConfig(r.Context(), channelID, ac.TeamID, req.Config)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogChannelRotateConfig(r.Context(), ac, channelID)
|
||||
writeJSON(w, http.StatusOK, channelToResponse(ch))
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/channels/{id}.
|
||||
func (h *channelHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
channelIDStr := chi.URLParam(r, "id")
|
||||
|
||||
channelID, err := id.ParseChannelID(channelIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Delete(r.Context(), channelID, ac.TeamID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete channel")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogChannelDelete(r.Context(), ac, channelID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@ -14,17 +14,18 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type execHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler {
|
||||
return &execHandler{db: db, agent: agent}
|
||||
func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler {
|
||||
return &execHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
type execRequest struct {
|
||||
@ -46,10 +47,16 @@ type execResponse struct {
|
||||
|
||||
// Exec handles POST /v1/sandboxes/{id}/exec.
|
||||
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -73,8 +80,14 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
start := time.Now()
|
||||
|
||||
resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Args: req.Args,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
@ -95,7 +108,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
|
||||
// Use base64 encoding if output contains non-UTF-8 bytes.
|
||||
@ -106,7 +119,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
|
||||
encoding = "base64"
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxID,
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: base64.StdEncoding.EncodeToString(stdout),
|
||||
Stderr: base64.StdEncoding.EncodeToString(stderr),
|
||||
@ -118,7 +131,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxID,
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: string(stdout),
|
||||
Stderr: string(stderr),
|
||||
|
||||
@ -14,17 +14,18 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, agent: agent}
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@ -48,10 +49,16 @@ type wsOutMsg struct {
|
||||
|
||||
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
|
||||
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -80,12 +87,18 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
sendWSError(conn, "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Open streaming exec to host agent.
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: startMsg.Cmd,
|
||||
Args: startMsg.Args,
|
||||
}))
|
||||
@ -151,7 +164,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
|
||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -11,17 +11,18 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type filesHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler {
|
||||
return &filesHandler{db: db, agent: agent}
|
||||
func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandler {
|
||||
return &filesHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// Upload handles POST /v1/sandboxes/{id}/files/write.
|
||||
@ -29,10 +30,16 @@ func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceCl
|
||||
// - "path" text field: absolute destination path inside the sandbox
|
||||
// - "file" file field: binary content to write
|
||||
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -75,8 +82,14 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: filePath,
|
||||
Content: content,
|
||||
})); err != nil {
|
||||
@ -95,10 +108,16 @@ type readFileRequest struct {
|
||||
// Download handles POST /v1/sandboxes/{id}/files/read.
|
||||
// Accepts JSON body with path, returns raw file content with Content-Disposition.
|
||||
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -120,8 +139,14 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
|
||||
@ -12,27 +12,34 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type filesStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
|
||||
return &filesStreamHandler{db: db, agent: agent}
|
||||
func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesStreamHandler {
|
||||
return &filesStreamHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -88,14 +95,20 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
defer filePart.Close()
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Open client-streaming RPC to host agent.
|
||||
stream := h.agent.WriteFileStream(ctx)
|
||||
stream := agent.WriteFileStream(ctx)
|
||||
|
||||
// Send metadata first.
|
||||
if err := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Meta{
|
||||
Meta: &pb.WriteFileStreamMeta{
|
||||
SandboxId: sandboxID,
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: filePath,
|
||||
},
|
||||
},
|
||||
@ -140,10 +153,16 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
|
||||
// Accepts JSON body with path, streams file content back without buffering.
|
||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -164,9 +183,15 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Open server-streaming RPC to host agent.
|
||||
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
|
||||
@ -1,23 +1,30 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type hostHandler struct {
|
||||
svc *service.HostService
|
||||
queries *db.Queries
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries}
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries, audit: al}
|
||||
}
|
||||
|
||||
// Request/response types.
|
||||
@ -34,6 +41,24 @@ type createHostResponse struct {
|
||||
RegistrationToken string `json:"registration_token"`
|
||||
}
|
||||
|
||||
type refreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type refreshTokenResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
CertPEM string `json:"cert_pem,omitempty"`
|
||||
KeyPEM string `json:"key_pem,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
}
|
||||
|
||||
type deletePreviewResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
SandboxIDs []string `json:"sandbox_ids"`
|
||||
}
|
||||
|
||||
type registerHostRequest struct {
|
||||
Token string `json:"token"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
@ -44,8 +69,12 @@ type registerHostRequest struct {
|
||||
}
|
||||
|
||||
type registerHostResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
Token string `json:"token"`
|
||||
Host hostResponse `json:"host"`
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
CertPEM string `json:"cert_pem,omitempty"`
|
||||
KeyPEM string `json:"key_pem,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
}
|
||||
|
||||
type addTagRequest struct {
|
||||
@ -56,6 +85,7 @@ type hostResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID *string `json:"team_id,omitempty"`
|
||||
TeamName *string `json:"team_name,omitempty"`
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
AvailabilityZone *string `json:"availability_zone,omitempty"`
|
||||
Arch *string `json:"arch,omitempty"`
|
||||
@ -72,34 +102,35 @@ type hostResponse struct {
|
||||
|
||||
func hostToResponse(h db.Host) hostResponse {
|
||||
resp := hostResponse{
|
||||
ID: h.ID,
|
||||
ID: id.FormatHostID(h.ID),
|
||||
Type: h.Type,
|
||||
Status: h.Status,
|
||||
CreatedBy: h.CreatedBy,
|
||||
CreatedBy: id.FormatUserID(h.CreatedBy),
|
||||
}
|
||||
if h.TeamID.Valid {
|
||||
resp.TeamID = &h.TeamID.String
|
||||
s := id.FormatTeamID(h.TeamID)
|
||||
resp.TeamID = &s
|
||||
}
|
||||
if h.Provider.Valid {
|
||||
resp.Provider = &h.Provider.String
|
||||
if h.Provider != "" {
|
||||
resp.Provider = &h.Provider
|
||||
}
|
||||
if h.AvailabilityZone.Valid {
|
||||
resp.AvailabilityZone = &h.AvailabilityZone.String
|
||||
if h.AvailabilityZone != "" {
|
||||
resp.AvailabilityZone = &h.AvailabilityZone
|
||||
}
|
||||
if h.Arch.Valid {
|
||||
resp.Arch = &h.Arch.String
|
||||
if h.Arch != "" {
|
||||
resp.Arch = &h.Arch
|
||||
}
|
||||
if h.CpuCores.Valid {
|
||||
resp.CPUCores = &h.CpuCores.Int32
|
||||
if h.CpuCores != 0 {
|
||||
resp.CPUCores = &h.CpuCores
|
||||
}
|
||||
if h.MemoryMb.Valid {
|
||||
resp.MemoryMB = &h.MemoryMb.Int32
|
||||
if h.MemoryMb != 0 {
|
||||
resp.MemoryMB = &h.MemoryMb
|
||||
}
|
||||
if h.DiskGb.Valid {
|
||||
resp.DiskGB = &h.DiskGb.Int32
|
||||
if h.DiskGb != 0 {
|
||||
resp.DiskGB = &h.DiskGb
|
||||
}
|
||||
if h.Address.Valid {
|
||||
resp.Address = &h.Address.String
|
||||
if h.Address != "" {
|
||||
resp.Address = &h.Address
|
||||
}
|
||||
if h.LastHeartbeatAt.Valid {
|
||||
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
|
||||
@ -112,7 +143,7 @@ func hostToResponse(h db.Host) hostResponse {
|
||||
}
|
||||
|
||||
// isAdmin fetches the user record and returns whether they are an admin.
|
||||
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
|
||||
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
|
||||
user, err := h.queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
return false
|
||||
@ -130,20 +161,32 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
|
||||
Type: req.Type,
|
||||
TeamID: req.TeamID,
|
||||
Provider: req.Provider,
|
||||
AvailabilityZone: req.AvailabilityZone,
|
||||
RequestingUserID: ac.UserID,
|
||||
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
|
||||
})
|
||||
// Parse optional team ID from request body.
|
||||
var params service.HostCreateParams
|
||||
params.Type = req.Type
|
||||
params.Provider = req.Provider
|
||||
params.AvailabilityZone = req.AvailabilityZone
|
||||
params.RequestingUserID = ac.UserID
|
||||
params.IsRequestorAdmin = h.isAdmin(r, ac.UserID)
|
||||
if req.TeamID != "" {
|
||||
teamID, err := id.ParseTeamID(req.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
|
||||
return
|
||||
}
|
||||
params.TeamID = teamID
|
||||
}
|
||||
|
||||
result, err := h.svc.Create(r.Context(), params)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
// Log audit for the owning team (BYOC hosts have a team; shared hosts use caller's team).
|
||||
h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, result.Host.TeamID)
|
||||
|
||||
writeJSON(w, http.StatusCreated, createHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
RegistrationToken: result.RegistrationToken,
|
||||
@ -153,16 +196,50 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
// List handles GET /v1/hosts.
|
||||
func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
admin := h.isAdmin(r, ac.UserID)
|
||||
|
||||
hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
hosts, err := h.svc.List(r.Context(), ac.TeamID, admin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique team IDs so we can fetch team names in one pass.
|
||||
var teamNames map[string]string
|
||||
if admin {
|
||||
seen := make(map[string]struct{})
|
||||
for _, host := range hosts {
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(seen) > 0 {
|
||||
teamNames = make(map[string]string, len(seen))
|
||||
for _, host := range hosts {
|
||||
if !host.TeamID.Valid {
|
||||
continue
|
||||
}
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if _, ok := teamNames[key]; ok {
|
||||
continue
|
||||
}
|
||||
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
|
||||
teamNames[key] = team.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponse(host)
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if name, ok := teamNames[key]; ok {
|
||||
resp[i].TeamName = &name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
@ -170,9 +247,15 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Get handles GET /v1/hosts/{id}.
|
||||
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
@ -183,25 +266,86 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, hostToResponse(host))
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/hosts/{id}.
|
||||
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
// DeletePreview handles GET /v1/hosts/{id}/delete-preview.
|
||||
// Returns what would be affected without making changes, for confirmation UI.
|
||||
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
writeJSON(w, http.StatusOK, deletePreviewResponse{
|
||||
Host: hostToResponse(preview.Host),
|
||||
SandboxIDs: preview.SandboxIDs,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/hosts/{id}.
|
||||
// Without ?force=true: returns 409 with affected sandbox IDs if any are active.
|
||||
// With ?force=true: gracefully stops all sandboxes then deletes the host.
|
||||
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
force := r.URL.Query().Get("force") == "true"
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch host before deletion to capture team_id for audit.
|
||||
deletedHost, hostErr := h.queries.GetHost(r.Context(), hostID)
|
||||
if hostErr != nil {
|
||||
slog.Warn("audit: could not fetch host before delete", "host_id", hostIDStr, "error", hostErr)
|
||||
}
|
||||
|
||||
err = h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
|
||||
if err == nil {
|
||||
h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it's a "has running sandboxes" error and return a structured 409.
|
||||
var hasSandboxes *service.HostHasSandboxesError
|
||||
if errors.As(err, &hasSandboxes) {
|
||||
writeJSON(w, http.StatusConflict, map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": "has_active_sandboxes",
|
||||
"message": "host has active sandboxes; use ?force=true to destroy them and delete the host",
|
||||
"sandbox_ids": hasSandboxes.SandboxIDs,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
}
|
||||
|
||||
// RegenerateToken handles POST /v1/hosts/{id}/token.
|
||||
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
@ -247,36 +391,61 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, registerHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
Token: result.JWT,
|
||||
Host: hostToResponse(result.Host),
|
||||
Token: result.JWT,
|
||||
RefreshToken: result.RefreshToken,
|
||||
CertPEM: result.CertPEM,
|
||||
KeyPEM: result.KeyPEM,
|
||||
CACertPEM: result.CACertPEM,
|
||||
})
|
||||
}
|
||||
|
||||
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
|
||||
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
hc := auth.MustHostFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Prevent a host from heartbeating for a different host.
|
||||
if hostID != hc.HostID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
// Capture pre-heartbeat status to detect unreachable → online transition.
|
||||
prevHost, _ := h.queries.GetHost(r.Context(), hc.HostID)
|
||||
|
||||
if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat")
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
// Log marked_up if the host just recovered from unreachable.
|
||||
if prevHost.Status == "unreachable" {
|
||||
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// AddTag handles POST /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
admin := h.isAdmin(r, ac.UserID)
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req addTagRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
@ -298,10 +467,16 @@ func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
|
||||
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
tag := chi.URLParam(r, "tag")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
@ -311,11 +486,47 @@ func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RefreshToken handles POST /v1/hosts/auth/refresh (unauthenticated).
|
||||
// The host agent sends its refresh token to receive a new JWT and rotated refresh token.
|
||||
func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||
var req refreshTokenRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.RefreshToken == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "refresh_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.Refresh(r.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, refreshTokenResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
Token: result.JWT,
|
||||
RefreshToken: result.RefreshToken,
|
||||
CertPEM: result.CertPEM,
|
||||
KeyPEM: result.KeyPEM,
|
||||
CACertPEM: result.CACertPEM,
|
||||
})
|
||||
}
|
||||
|
||||
// ListTags handles GET /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
|
||||
156
internal/api/handlers_metrics.go
Normal file
156
internal/api/handlers_metrics.go
Normal file
@ -0,0 +1,156 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type sandboxMetricsHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newSandboxMetricsHandler(db *db.Queries, pool *lifecycle.HostClientPool) *sandboxMetricsHandler {
|
||||
return &sandboxMetricsHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
type metricPointResponse struct {
|
||||
TimestampUnix int64 `json:"timestamp_unix"`
|
||||
CPUPct float64 `json:"cpu_pct"`
|
||||
MemBytes int64 `json:"mem_bytes"`
|
||||
DiskBytes int64 `json:"disk_bytes"`
|
||||
}
|
||||
|
||||
type metricsResponse struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Range string `json:"range"`
|
||||
Points []metricPointResponse `json:"points"`
|
||||
}
|
||||
|
||||
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h.
|
||||
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
rangeTier := r.URL.Query().Get("range")
|
||||
if rangeTier == "" {
|
||||
rangeTier = "10m"
|
||||
}
|
||||
validRanges := map[string]bool{"5m": true, "10m": true, "1h": true, "2h": true, "6h": true, "12h": true, "24h": true}
|
||||
if !validRanges[rangeTier] {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "range must be one of: 5m, 10m, 1h, 2h, 6h, 12h, 24h")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
|
||||
switch sb.Status {
|
||||
case "running":
|
||||
h.getFromAgent(w, r, sandboxIDStr, rangeTier, sb.HostID)
|
||||
case "paused":
|
||||
h.getFromDB(ctx, w, sandboxIDStr, sandboxID, rangeTier)
|
||||
default:
|
||||
writeError(w, http.StatusNotFound, "not_found", "metrics not available for sandbox in state: "+sb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxIDStr, rangeTier string, hostID pgtype.UUID) {
|
||||
ctx := r.Context()
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.GetSandboxMetrics(ctx, connect.NewRequest(&pb.GetSandboxMetricsRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Range: rangeTier,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
points := make([]metricPointResponse, len(resp.Msg.Points))
|
||||
for i, p := range resp.Msg.Points {
|
||||
points[i] = metricPointResponse{
|
||||
TimestampUnix: p.TimestampUnix,
|
||||
CPUPct: p.CpuPct,
|
||||
MemBytes: p.MemBytes,
|
||||
DiskBytes: p.DiskBytes,
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, metricsResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Range: rangeTier,
|
||||
Points: points,
|
||||
})
|
||||
}
|
||||
|
||||
// rangeToDB maps a user-facing range filter to the DB tier and cutoff duration.
|
||||
var rangeToDB = map[string]struct {
|
||||
tier string
|
||||
cutoff time.Duration
|
||||
}{
|
||||
"5m": {"10m", 5 * time.Minute},
|
||||
"10m": {"10m", 10 * time.Minute},
|
||||
"1h": {"2h", 1 * time.Hour},
|
||||
"2h": {"2h", 2 * time.Hour},
|
||||
"6h": {"24h", 6 * time.Hour},
|
||||
"12h": {"24h", 12 * time.Hour},
|
||||
"24h": {"24h", 24 * time.Hour},
|
||||
}
|
||||
|
||||
func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxIDStr string, sandboxID pgtype.UUID, rangeTier string) {
|
||||
mapping := rangeToDB[rangeTier]
|
||||
rows, err := h.db.GetSandboxMetricPoints(ctx, db.GetSandboxMetricPointsParams{
|
||||
SandboxID: sandboxID,
|
||||
Tier: mapping.tier,
|
||||
Ts: time.Now().Add(-mapping.cutoff).Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to read metrics")
|
||||
return
|
||||
}
|
||||
|
||||
points := make([]metricPointResponse, len(rows))
|
||||
for i, row := range rows {
|
||||
points[i] = metricPointResponse{
|
||||
TimestampUnix: row.Ts,
|
||||
CPUPct: row.CpuPct,
|
||||
MemBytes: row.MemBytes,
|
||||
DiskBytes: row.DiskBytes,
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, metricsResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Range: rangeTier,
|
||||
Points: points,
|
||||
})
|
||||
}
|
||||
@ -150,19 +150,19 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
@ -199,6 +199,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
|
||||
ID: userID,
|
||||
Email: email,
|
||||
Name: profile.Name,
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
@ -219,6 +220,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: teamName,
|
||||
Slug: id.NewTeamSlug(),
|
||||
}); err != nil {
|
||||
slog.Error("oauth: failed to create team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
@ -253,14 +255,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", false)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
|
||||
redirectWithToken(w, r, redirectBase, token, userID, teamID, email)
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
|
||||
}
|
||||
|
||||
// retryAsLogin handles the race where a concurrent request already created the user.
|
||||
@ -282,29 +284,39 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
|
||||
}
|
||||
|
||||
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) {
|
||||
u := base + "?" + url.Values{
|
||||
"token": {token},
|
||||
"user_id": {userID},
|
||||
"team_id": {teamID},
|
||||
"email": {email},
|
||||
}.Encode()
|
||||
http.Redirect(w, r, u, http.StatusFound)
|
||||
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {
|
||||
// Set auth data as short-lived cookies instead of URL query parameters.
|
||||
// This prevents token leakage via server access logs, Referer headers, and browser history.
|
||||
for _, c := range []http.Cookie{
|
||||
{Name: "wrenn_oauth_token", Value: token},
|
||||
{Name: "wrenn_oauth_user_id", Value: userID},
|
||||
{Name: "wrenn_oauth_team_id", Value: teamID},
|
||||
{Name: "wrenn_oauth_email", Value: email},
|
||||
{Name: "wrenn_oauth_name", Value: name},
|
||||
} {
|
||||
c.Path = "/auth/"
|
||||
c.MaxAge = 60
|
||||
c.HttpOnly = false // frontend JS must read these
|
||||
c.SameSite = http.SameSiteLaxMode
|
||||
c.Secure = isSecure(r)
|
||||
http.SetCookie(w, &c)
|
||||
}
|
||||
http.Redirect(w, r, base, http.StatusFound)
|
||||
}
|
||||
|
||||
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
|
||||
|
||||
@ -7,17 +7,20 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type sandboxHandler struct {
|
||||
svc *service.SandboxService
|
||||
svc *service.SandboxService
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newSandboxHandler(svc *service.SandboxService) *sandboxHandler {
|
||||
return &sandboxHandler{svc: svc}
|
||||
func newSandboxHandler(svc *service.SandboxService, al *audit.AuditLogger) *sandboxHandler {
|
||||
return &sandboxHandler{svc: svc, audit: al}
|
||||
}
|
||||
|
||||
type createSandboxRequest struct {
|
||||
@ -44,7 +47,7 @@ type sandboxResponse struct {
|
||||
|
||||
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
resp := sandboxResponse{
|
||||
ID: sb.ID,
|
||||
ID: id.FormatSandboxID(sb.ID),
|
||||
Status: sb.Status,
|
||||
Template: sb.Template,
|
||||
VCPUs: sb.Vcpus,
|
||||
@ -79,6 +82,10 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
if !ac.TeamID.Valid {
|
||||
writeError(w, http.StatusForbidden, "no_team", "no active team context; re-authenticate")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Create(r.Context(), service.SandboxCreateParams{
|
||||
TeamID: ac.TeamID,
|
||||
@ -93,6 +100,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
@ -115,9 +123,15 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Get handles GET /v1/sandboxes/{id}.
|
||||
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -129,9 +143,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Pause handles POST /v1/sandboxes/{id}/pause.
|
||||
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
@ -139,14 +159,21 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxPause(r.Context(), ac, sandboxID)
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Resume handles POST /v1/sandboxes/{id}/resume.
|
||||
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
@ -154,14 +181,21 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxResume(r.Context(), ac, sandboxID)
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Ping handles POST /v1/sandboxes/{id}/ping.
|
||||
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
@ -173,14 +207,21 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Destroy handles DELETE /v1/sandboxes/{id}.
|
||||
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@ -9,25 +10,57 @@ import (
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/layout"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type snapshotHandler struct {
|
||||
svc *service.TemplateService
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, db: db, agent: agent}
|
||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
|
||||
}
|
||||
|
||||
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
|
||||
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
|
||||
// and ignore NotFound errors.
|
||||
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error {
|
||||
hosts, err := h.db.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
}
|
||||
agent, err := h.pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: formatUUIDForRPC(teamID),
|
||||
TemplateId: formatUUIDForRPC(templateID),
|
||||
})); err != nil {
|
||||
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type createSnapshotRequest struct {
|
||||
@ -42,6 +75,7 @@ type snapshotResponse struct {
|
||||
MemoryMB *int32 `json:"memory_mb,omitempty"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Platform bool `json:"platform"`
|
||||
}
|
||||
|
||||
func templateToResponse(t db.Template) snapshotResponse {
|
||||
@ -49,12 +83,13 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
Name: t.Name,
|
||||
Type: t.Type,
|
||||
SizeBytes: t.SizeBytes,
|
||||
Platform: t.TeamID == id.PlatformTeamID,
|
||||
}
|
||||
if t.Vcpus.Valid {
|
||||
resp.VCPUs = &t.Vcpus.Int32
|
||||
if t.Vcpus != 0 {
|
||||
resp.VCPUs = &t.Vcpus
|
||||
}
|
||||
if t.MemoryMb.Valid {
|
||||
resp.MemoryMB = &t.MemoryMb.Int32
|
||||
if t.MemoryMb != 0 {
|
||||
resp.MemoryMB = &t.MemoryMb
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
@ -75,6 +110,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(req.SandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
@ -87,16 +128,21 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(ctx)
|
||||
overwrite := r.URL.Query().Get("overwrite") == "true"
|
||||
|
||||
// Check for global name collision.
|
||||
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name already exists for this team.
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
if !overwrite {
|
||||
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
|
||||
return
|
||||
}
|
||||
// Delete old files from the agent before removing the DB record.
|
||||
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, "failed to delete existing snapshot files: "+msg)
|
||||
// Delete old snapshot files from all hosts before removing the DB record.
|
||||
if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
|
||||
@ -106,7 +152,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Verify sandbox exists, belongs to team, and is running or paused.
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID})
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
@ -116,30 +162,59 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: req.SandboxID,
|
||||
Name: req.Name,
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
|
||||
// The host agent's CreateSnapshot removes the sandbox from its in-memory
|
||||
// map immediately; if the reconciler fires during the flatten window and
|
||||
// the DB still says "running", it will mark the sandbox "stopped".
|
||||
if sb.Status == "running" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a detached context with a generous timeout so the snapshot completes
|
||||
// even if the client disconnects (the flatten step can take 10-20s).
|
||||
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer snapCancel()
|
||||
|
||||
// Generate the new template ID upfront so the host agent knows where to store files.
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: req.SandboxID,
|
||||
Name: req.Name,
|
||||
TeamId: formatUUIDForRPC(ac.TeamID),
|
||||
TemplateId: formatUUIDForRPC(newTemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
// Snapshot failed — revert status back to what it was.
|
||||
if sb.Status == "running" {
|
||||
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
|
||||
}
|
||||
}
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark sandbox as paused (if it was running, it got paused by the snapshot).
|
||||
if sb.Status != "paused" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: req.SandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true},
|
||||
MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true},
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
})
|
||||
@ -149,6 +224,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
}
|
||||
|
||||
@ -181,16 +262,23 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
|
||||
tmpl, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "template not found")
|
||||
return
|
||||
}
|
||||
// Platform templates can only be deleted by admins via /v1/admin/templates.
|
||||
if tmpl.TeamID == id.PlatformTeamID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
|
||||
return
|
||||
}
|
||||
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
Name: name,
|
||||
})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, "failed to delete snapshot files: "+msg)
|
||||
if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
|
||||
return
|
||||
}
|
||||
|
||||
@ -199,5 +287,6 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
95
internal/api/handlers_stats.go
Normal file
95
internal/api/handlers_stats.go
Normal file
@ -0,0 +1,95 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type statsHandler struct {
|
||||
svc *service.StatsService
|
||||
}
|
||||
|
||||
func newStatsHandler(svc *service.StatsService) *statsHandler {
|
||||
return &statsHandler{svc: svc}
|
||||
}
|
||||
|
||||
type statsCurrentResponse struct {
|
||||
RunningCount int32 `json:"running_count"`
|
||||
VCPUsReserved int32 `json:"vcpus_reserved"`
|
||||
MemoryMBReserved int32 `json:"memory_mb_reserved"`
|
||||
}
|
||||
|
||||
type statsPeaksResponse struct {
|
||||
RunningCount int32 `json:"running_count"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
}
|
||||
|
||||
type statsSeriesResponse struct {
|
||||
Labels []string `json:"labels"`
|
||||
Running []int32 `json:"running"`
|
||||
VCPUs []int32 `json:"vcpus"`
|
||||
MemoryMB []int32 `json:"memory_mb"`
|
||||
}
|
||||
|
||||
type statsResponse struct {
|
||||
Range string `json:"range"`
|
||||
Current statsCurrentResponse `json:"current"`
|
||||
Peaks statsPeaksResponse `json:"peaks"`
|
||||
Series statsSeriesResponse `json:"series"`
|
||||
}
|
||||
|
||||
// GetStats handles GET /v1/sandboxes/stats?range=5m|1h|6h|24h|30d
|
||||
func (h *statsHandler) GetStats(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
rangeParam := r.URL.Query().Get("range")
|
||||
if rangeParam == "" {
|
||||
rangeParam = string(service.Range1h)
|
||||
}
|
||||
tr := service.TimeRange(rangeParam)
|
||||
if !service.ValidRange(tr) {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "range must be one of: 5m, 1h, 6h, 24h, 30d")
|
||||
return
|
||||
}
|
||||
|
||||
current, peaks, series, err := h.svc.GetStats(r.Context(), ac.TeamID, tr)
|
||||
if err != nil {
|
||||
slog.Error("stats handler: get stats failed", "team_id", ac.TeamID, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to retrieve stats")
|
||||
return
|
||||
}
|
||||
|
||||
resp := statsResponse{
|
||||
Range: rangeParam,
|
||||
Current: statsCurrentResponse{
|
||||
RunningCount: current.RunningCount,
|
||||
VCPUsReserved: current.VCPUsReserved,
|
||||
MemoryMBReserved: current.MemoryMBReserved,
|
||||
},
|
||||
Peaks: statsPeaksResponse{
|
||||
RunningCount: peaks.RunningCount,
|
||||
VCPUs: peaks.VCPUs,
|
||||
MemoryMB: peaks.MemoryMB,
|
||||
},
|
||||
Series: statsSeriesResponse{
|
||||
Labels: make([]string, len(series)),
|
||||
Running: make([]int32, len(series)),
|
||||
VCPUs: make([]int32, len(series)),
|
||||
MemoryMB: make([]int32, len(series)),
|
||||
},
|
||||
}
|
||||
|
||||
for i, pt := range series {
|
||||
resp.Series.Labels[i] = pt.Bucket.UTC().Format(time.RFC3339)
|
||||
resp.Series.Running[i] = pt.RunningCount
|
||||
resp.Series.VCPUs[i] = pt.VCPUsReserved
|
||||
resp.Series.MemoryMB[i] = pt.MemoryMBReserved
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
390
internal/api/handlers_team.go
Normal file
390
internal/api/handlers_team.go
Normal file
@ -0,0 +1,390 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type teamHandler struct {
|
||||
svc *service.TeamService
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger) *teamHandler {
|
||||
return &teamHandler{svc: svc, audit: al}
|
||||
}
|
||||
|
||||
// teamResponse is the JSON shape for a team.
|
||||
type teamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
IsByoc bool `json:"is_byoc"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// teamWithRoleResponse includes the calling user's role.
|
||||
type teamWithRoleResponse struct {
|
||||
teamResponse
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type memberResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
JoinedAt string `json:"joined_at,omitempty"`
|
||||
}
|
||||
|
||||
func teamToResponse(t db.Team) teamResponse {
|
||||
resp := teamResponse{
|
||||
ID: id.FormatTeamID(t.ID),
|
||||
Name: t.Name,
|
||||
Slug: t.Slug,
|
||||
IsByoc: t.IsByoc,
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func memberInfoToResponse(m service.MemberInfo) memberResponse {
|
||||
return memberResponse{
|
||||
UserID: m.UserID,
|
||||
Name: m.Name,
|
||||
Email: m.Email,
|
||||
Role: m.Role,
|
||||
JoinedAt: m.JoinedAt.Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// requireTeamAccess is an inline check used by every team-scoped handler:
|
||||
// the JWT team_id must match the URL {id} before any DB call is made.
|
||||
// Returns false and writes 403 if they don't match.
|
||||
func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (pgtype.UUID, bool) {
|
||||
teamIDStr := chi.URLParam(r, "id")
|
||||
teamID, err := id.ParseTeamID(teamIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
|
||||
return pgtype.UUID{}, false
|
||||
}
|
||||
if ac.TeamID != teamID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "JWT team does not match requested team; use switch-team first")
|
||||
return pgtype.UUID{}, false
|
||||
}
|
||||
return teamID, true
|
||||
}
|
||||
|
||||
// List handles GET /v1/teams
|
||||
// Returns all teams the authenticated user belongs to.
|
||||
func (h *teamHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
teams, err := h.svc.ListTeamsForUser(r.Context(), ac.UserID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]teamWithRoleResponse, len(teams))
|
||||
for i, t := range teams {
|
||||
resp[i] = teamWithRoleResponse{
|
||||
teamResponse: teamToResponse(t.Team),
|
||||
Role: t.Role,
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Create handles POST /v1/teams
|
||||
// Creates a new team owned by the authenticated user.
|
||||
func (h *teamHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
req.Name = strings.TrimSpace(req.Name)
|
||||
|
||||
team, err := h.svc.CreateTeam(r.Context(), ac.UserID, req.Name)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, teamWithRoleResponse{
|
||||
teamResponse: teamToResponse(team.Team),
|
||||
Role: team.Role,
|
||||
})
|
||||
}
|
||||
|
||||
// Get handles GET /v1/teams/{id}
|
||||
// Returns team info and member list.
|
||||
func (h *teamHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
team, err := h.svc.GetTeam(r.Context(), teamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
members, err := h.svc.GetMembers(r.Context(), teamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
memberResp := make([]memberResponse, len(members))
|
||||
for i, m := range members {
|
||||
memberResp[i] = memberInfoToResponse(m)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"team": teamToResponse(team),
|
||||
"members": memberResp,
|
||||
})
|
||||
}
|
||||
|
||||
// Rename handles PATCH /v1/teams/{id}
|
||||
// Renames the team. Requires admin or owner role (verified from DB).
|
||||
func (h *teamHandler) Rename(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
req.Name = strings.TrimSpace(req.Name)
|
||||
|
||||
// Fetch old name for audit log before renaming.
|
||||
oldTeam, err := h.svc.GetTeam(r.Context(), teamID)
|
||||
if err != nil {
|
||||
slog.Warn("audit: could not fetch old team name for rename log", "team_id", id.FormatTeamID(teamID), "error", err)
|
||||
}
|
||||
|
||||
if err := h.svc.RenameTeam(r.Context(), teamID, ac.UserID, req.Name); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogTeamRename(r.Context(), ac, teamID, oldTeam.Name, req.Name)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/teams/{id}
|
||||
// Soft-deletes the team and destroys active sandboxes. Owner only.
|
||||
func (h *teamHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.DeleteTeam(r.Context(), teamID, ac.UserID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ListMembers handles GET /v1/teams/{id}/members
|
||||
func (h *teamHandler) ListMembers(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
members, err := h.svc.GetMembers(r.Context(), teamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]memberResponse, len(members))
|
||||
for i, m := range members {
|
||||
resp[i] = memberInfoToResponse(m)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// AddMember handles POST /v1/teams/{id}/members
|
||||
// Adds a user by email. Requires admin or owner (verified from DB).
|
||||
func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
||||
if req.Email == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "email is required")
|
||||
return
|
||||
}
|
||||
|
||||
member, err := h.svc.AddMember(r.Context(), teamID, ac.UserID, req.Email)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
// member.UserID is already formatted with prefix; parse it back for the audit logger.
|
||||
targetUserID, parseErr := id.ParseUserID(member.UserID)
|
||||
if parseErr == nil {
|
||||
h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
|
||||
}
|
||||
|
||||
// RemoveMember handles DELETE /v1/teams/{id}/members/{uid}
|
||||
// Removes a member. Requires admin or owner (verified from DB). Owner cannot be removed.
|
||||
func (h *teamHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
targetUserIDStr := chi.URLParam(r, "uid")
|
||||
|
||||
targetUserID, err := id.ParseUserID(targetUserIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.RemoveMember(r.Context(), teamID, ac.UserID, targetUserID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogMemberRemove(r.Context(), ac, targetUserID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// UpdateMemberRole handles PATCH /v1/teams/{id}/members/{uid}
|
||||
// Changes a member's role (admin or member). Owner's role cannot be changed.
|
||||
func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
targetUserIDStr := chi.URLParam(r, "uid")
|
||||
|
||||
targetUserID, err := id.ParseUserID(targetUserIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.UpdateMemberRole(r.Context(), teamID, ac.UserID, targetUserID, req.Role); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogMemberRoleUpdate(r.Context(), ac, targetUserID, req.Role)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Leave handles POST /v1/teams/{id}/leave
|
||||
// Removes the calling user from the team. Owner cannot leave.
|
||||
func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.LeaveTeam(r.Context(), teamID, ac.UserID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogMemberLeave(r.Context(), ac)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only).
|
||||
// Enables or disables the BYOC feature flag for a team.
|
||||
func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
|
||||
teamIDStr := chi.URLParam(r, "id")
|
||||
|
||||
teamID, err := id.ParseTeamID(teamIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.SetBYOC(r.Context(), teamID, req.Enabled); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
52
internal/api/handlers_users.go
Normal file
52
internal/api/handlers_users.go
Normal file
@ -0,0 +1,52 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
type usersHandler struct {
|
||||
db *db.Queries
|
||||
}
|
||||
|
||||
func newUsersHandler(db *db.Queries) *usersHandler {
|
||||
return &usersHandler{db: db}
|
||||
}
|
||||
|
||||
// Search handles GET /v1/users/search?email=<prefix>
|
||||
// Returns up to 10 users whose email starts with the given prefix.
|
||||
// The prefix must be at least 3 characters long and contain "@".
|
||||
func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
|
||||
auth.MustFromContext(r.Context()) // ensure authenticated
|
||||
|
||||
prefix := strings.TrimSpace(r.URL.Query().Get("email"))
|
||||
if len(prefix) < 3 || !strings.Contains(prefix, "@") {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "email prefix must be at least 3 characters and contain '@'")
|
||||
return
|
||||
}
|
||||
|
||||
// Escape LIKE metacharacters to prevent pattern injection.
|
||||
escaped := strings.NewReplacer("%", "\\%", "_", "\\_").Replace(prefix)
|
||||
|
||||
results, err := h.db.SearchUsersByEmailPrefix(r.Context(), pgtype.Text{String: escaped, Valid: true})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal", "search failed")
|
||||
return
|
||||
}
|
||||
|
||||
type userResult struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
resp := make([]userResult, len(results))
|
||||
for i, u := range results {
|
||||
resp[i] = userResult{UserID: id.FormatUserID(u.ID), Email: u.Email}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
216
internal/api/host_monitor.go
Normal file
216
internal/api/host_monitor.go
Normal file
@ -0,0 +1,216 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// unreachableThreshold is how long a host can go without a heartbeat before
|
||||
// it is considered unreachable (3 missed 30-second heartbeats).
|
||||
const unreachableThreshold = 90 * time.Second
|
||||
|
||||
// HostMonitor runs on a fixed interval and performs two duties:
|
||||
//
|
||||
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
|
||||
// "unreachable" and marks their active sandboxes as "missing".
|
||||
//
|
||||
// 2. Active reconciliation: for each online host, calls ListSandboxes and
|
||||
// reconciles DB state against live host state — restoring "missing"
|
||||
// sandboxes that are actually alive, and stopping orphaned ones.
|
||||
type HostMonitor struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewHostMonitor creates a HostMonitor.
|
||||
func NewHostMonitor(queries *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger, interval time.Duration) *HostMonitor {
|
||||
return &HostMonitor{
|
||||
db: queries,
|
||||
pool: pool,
|
||||
audit: al,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// Start runs the monitor loop until the context is cancelled.
|
||||
func (m *HostMonitor) Start(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(m.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on startup so the CP doesn't wait one full interval
|
||||
// before reconciling host and sandbox state.
|
||||
m.run(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.run(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *HostMonitor) run(ctx context.Context) {
|
||||
hosts, err := m.db.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list hosts", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
m.checkHost(ctx, host)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
// --- Passive phase: check heartbeat staleness ---
|
||||
|
||||
stale := !host.LastHeartbeatAt.Valid ||
|
||||
time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
|
||||
|
||||
if stale && host.Status != "unreachable" {
|
||||
slog.Info("host monitor: marking host unreachable", "host_id", id.FormatHostID(host.ID),
|
||||
"last_heartbeat", host.LastHeartbeatAt.Time)
|
||||
if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
|
||||
slog.Warn("host monitor: failed to mark host unreachable", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
|
||||
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
m.audit.LogHostMarkedDown(ctx, host.TeamID, host.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// --- Active reconciliation: only for online hosts ---
|
||||
|
||||
if host.Status != "online" {
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := m.pool.GetForHost(host)
|
||||
if err != nil {
|
||||
// Host has no address yet (e.g., just registered) — skip.
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
|
||||
if err != nil {
|
||||
// RPC failure is a transient condition; the passive phase will catch it
|
||||
// if heartbeats stop arriving.
|
||||
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build set of sandbox IDs alive on the host.
|
||||
// The host agent returns sandbox IDs as strings (formatted with prefix).
|
||||
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
|
||||
for _, sb := range resp.Msg.Sandboxes {
|
||||
alive[sb.SandboxId] = struct{}{}
|
||||
}
|
||||
|
||||
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
|
||||
for _, apID := range resp.Msg.AutoPausedSandboxIds {
|
||||
autoPaused[apID] = struct{}{}
|
||||
}
|
||||
|
||||
// --- Restore sandboxes that are "missing" in DB but alive on host ---
|
||||
// This handles the case where CP marked them missing due to a transient
|
||||
// heartbeat gap, but the host was actually fine.
|
||||
|
||||
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
Column2: []string{"missing"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
var toRestore []pgtype.UUID
|
||||
var toStop []pgtype.UUID
|
||||
for _, sb := range missingSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
if _, ok := alive[sbIDStr]; ok {
|
||||
toRestore = append(toRestore, sb.ID)
|
||||
} else {
|
||||
toStop = append(toStop, sb.ID)
|
||||
}
|
||||
}
|
||||
if len(toRestore) > 0 {
|
||||
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
|
||||
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
|
||||
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
}
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Find running sandboxes in DB that are no longer alive on the host ---
|
||||
|
||||
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
Column2: []string{"running"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list running sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
var toPause, toStop []pgtype.UUID
|
||||
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
|
||||
for _, sb := range runningSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
sbTeamID[sb.ID] = sb.TeamID
|
||||
if _, ok := alive[sbIDStr]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := autoPaused[sbIDStr]; ok {
|
||||
toPause = append(toPause, sb.ID)
|
||||
} else {
|
||||
toStop = append(toStop, sb.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toPause) > 0 {
|
||||
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toPause,
|
||||
Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
for _, sbID := range toPause {
|
||||
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
|
||||
}
|
||||
}
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
68
internal/api/metrics_sampler.go
Normal file
68
internal/api/metrics_sampler.go
Normal file
@ -0,0 +1,68 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// MetricsSampler records per-team sandbox resource usage to
|
||||
// sandbox_metrics_snapshots every interval. It also prunes rows older than
|
||||
// 60 days on each tick to keep the table bounded.
|
||||
type MetricsSampler struct {
|
||||
db *db.Queries
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewMetricsSampler creates a MetricsSampler.
|
||||
func NewMetricsSampler(queries *db.Queries, interval time.Duration) *MetricsSampler {
|
||||
return &MetricsSampler{db: queries, interval: interval}
|
||||
}
|
||||
|
||||
// Start runs the sampler loop until the context is cancelled.
|
||||
func (s *MetricsSampler) Start(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Sample immediately on startup.
|
||||
s.run(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.run(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *MetricsSampler) run(ctx context.Context) {
|
||||
s.prune(ctx)
|
||||
if err := s.sample(ctx); err != nil {
|
||||
slog.Warn("metrics sampler: sample failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MetricsSampler) sample(ctx context.Context) error {
|
||||
rows, err := s.db.SampleSandboxMetrics(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, row := range rows {
|
||||
if err := s.db.InsertMetricsSnapshot(ctx, db.InsertMetricsSnapshotParams(row)); err != nil {
|
||||
slog.Warn("metrics sampler: insert snapshot failed", "team_id", row.TeamID, "error", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MetricsSampler) prune(ctx context.Context) {
|
||||
if err := s.db.PruneOldMetrics(ctx); err != nil {
|
||||
slog.Warn("metrics sampler: prune failed", "error", err)
|
||||
}
|
||||
}
|
||||
@ -12,6 +12,9 @@ import (
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
type errorResponse struct {
|
||||
@ -35,6 +38,11 @@ func writeError(w http.ResponseWriter, status int, code, message string) {
|
||||
})
|
||||
}
|
||||
|
||||
// formatUUIDForRPC converts a pgtype.UUID to a hex string for RPC messages.
|
||||
func formatUUIDForRPC(u pgtype.UUID) string {
|
||||
return id.UUIDString(u)
|
||||
}
|
||||
|
||||
// agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message.
|
||||
func agentErrToHTTP(err error) (int, string, string) {
|
||||
switch connect.CodeOf(err) {
|
||||
@ -87,8 +95,12 @@ func serviceErrToHTTP(err error) (int, string, string) {
|
||||
return http.StatusNotFound, "not_found", msg
|
||||
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", msg
|
||||
case strings.Contains(msg, "conflict:"):
|
||||
return http.StatusConflict, "conflict", msg
|
||||
case strings.Contains(msg, "forbidden"):
|
||||
return http.StatusForbidden, "forbidden", msg
|
||||
case strings.Contains(msg, "invalid or expired"):
|
||||
return http.StatusUnauthorized, "unauthorized", msg
|
||||
case strings.Contains(msg, "invalid"):
|
||||
return http.StatusBadRequest, "invalid_request", msg
|
||||
default:
|
||||
|
||||
30
internal/api/middleware_admin.go
Normal file
30
internal/api/middleware_admin.go
Normal file
@ -0,0 +1,30 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// requireAdmin validates that the authenticated user is a platform admin.
|
||||
// Must run after requireJWT (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the JWT is_admin claim is for UI only;
|
||||
// the DB is the source of truth for admin access.
|
||||
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
|
||||
return
|
||||
}
|
||||
user, err := queries.GetUserByID(r.Context(), ac.UserID)
|
||||
if err != nil || !user.IsAdmin {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "admin access required")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,38 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// requireAPIKey validates the X-API-Key header, looks up the SHA-256 hash in DB,
|
||||
// and stamps TeamID into the request context.
|
||||
func requireAPIKey(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == "" {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key header required")
|
||||
return
|
||||
}
|
||||
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Best-effort update of last_used timestamp.
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
|
||||
@ -19,15 +20,20 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
|
||||
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: row.TeamID,
|
||||
APIKeyID: row.ID,
|
||||
APIKeyName: row.Name,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
@ -37,14 +43,28 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
|
||||
return
|
||||
}
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// requireHostToken validates the X-Host-Token header containing a host JWT,
|
||||
@ -23,7 +24,13 @@ func requireHostToken(secret []byte) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
|
||||
hostID, err := id.ParseHostID(claims.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid host ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: hostID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
|
||||
@ -25,11 +26,26 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
|
||||
return
|
||||
}
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
Email: claims.Email,
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
IsAdmin: claims.IsAdmin,
|
||||
})
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,126 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// Reconciler periodically compares the host agent's sandbox list with the DB
|
||||
// and marks sandboxes that no longer exist on the host as stopped.
|
||||
type Reconciler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
hostID string
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewReconciler creates a new reconciler.
|
||||
func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler {
|
||||
return &Reconciler{
|
||||
db: db,
|
||||
agent: agent,
|
||||
hostID: hostID,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// Start runs the reconciliation loop until the context is cancelled.
|
||||
func (rc *Reconciler) Start(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(rc.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rc.reconcile(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (rc *Reconciler) reconcile(ctx context.Context) {
|
||||
// Single RPC returns both the running sandbox list and any IDs that
|
||||
// were auto-paused by the TTL reaper since the last call.
|
||||
resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
|
||||
if err != nil {
|
||||
slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build a set of sandbox IDs that are alive on the host.
|
||||
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
|
||||
for _, sb := range resp.Msg.Sandboxes {
|
||||
alive[sb.SandboxId] = struct{}{}
|
||||
}
|
||||
|
||||
// Build auto-paused set from the same response.
|
||||
autoPausedSet := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
|
||||
for _, id := range resp.Msg.AutoPausedSandboxIds {
|
||||
autoPausedSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Get all DB sandboxes for this host that are running.
|
||||
// Paused sandboxes are excluded: they are expected to not exist on the
|
||||
// host agent because pause = snapshot + destroy resources.
|
||||
dbSandboxes, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: rc.hostID,
|
||||
Column2: []string{"running"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("reconciler: failed to list DB sandboxes", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Find sandboxes in DB that are no longer on the host.
|
||||
var stale []string
|
||||
for _, sb := range dbSandboxes {
|
||||
if _, ok := alive[sb.ID]; !ok {
|
||||
stale = append(stale, sb.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(stale) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Split stale sandboxes into those auto-paused by the TTL reaper vs
|
||||
// those that crashed/were orphaned.
|
||||
var toPause, toStop []string
|
||||
for _, id := range stale {
|
||||
if _, ok := autoPausedSet[id]; ok {
|
||||
toPause = append(toPause, id)
|
||||
} else {
|
||||
toStop = append(toStop, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toPause) > 0 {
|
||||
slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause)
|
||||
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toPause,
|
||||
Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop)
|
||||
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("reconciler: failed to update stale sandboxes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -9,10 +9,14 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/channels"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
//go:embed openapi.yaml
|
||||
@ -20,30 +24,54 @@ var openapiYAML []byte
|
||||
|
||||
// Server is the control plane HTTP server.
|
||||
type Server struct {
|
||||
router chi.Router
|
||||
router chi.Router
|
||||
BuildSvc *service.BuildService
|
||||
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
|
||||
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
|
||||
func New(
|
||||
queries *db.Queries,
|
||||
pool *lifecycle.HostClientPool,
|
||||
sched scheduler.HostScheduler,
|
||||
pgPool *pgxpool.Pool,
|
||||
rdb *redis.Client,
|
||||
jwtSecret []byte,
|
||||
oauthRegistry *oauth.Registry,
|
||||
oauthRedirectURL string,
|
||||
ca *auth.CA,
|
||||
al *audit.AuditLogger,
|
||||
channelSvc *channels.Service,
|
||||
) *Server {
|
||||
r := chi.NewRouter()
|
||||
r.Use(requestLogger())
|
||||
|
||||
// Shared service layer.
|
||||
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
|
||||
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
|
||||
apiKeySvc := &service.APIKeyService{DB: queries}
|
||||
templateSvc := &service.TemplateService{DB: queries}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
||||
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
||||
auditSvc := &service.AuditService{DB: queries}
|
||||
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
||||
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
|
||||
|
||||
sandbox := newSandboxHandler(sandboxSvc)
|
||||
exec := newExecHandler(queries, agent)
|
||||
execStream := newExecStreamHandler(queries, agent)
|
||||
files := newFilesHandler(queries, agent)
|
||||
filesStream := newFilesStreamHandler(queries, agent)
|
||||
snapshots := newSnapshotHandler(templateSvc, queries, agent)
|
||||
authH := newAuthHandler(queries, pool, jwtSecret)
|
||||
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc)
|
||||
hostH := newHostHandler(hostSvc, queries)
|
||||
sandbox := newSandboxHandler(sandboxSvc, al)
|
||||
exec := newExecHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool)
|
||||
files := newFilesHandler(queries, pool)
|
||||
filesStream := newFilesStreamHandler(queries, pool)
|
||||
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
|
||||
authH := newAuthHandler(queries, pgPool, jwtSecret)
|
||||
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc, al)
|
||||
hostH := newHostHandler(hostSvc, queries, al)
|
||||
teamH := newTeamHandler(teamSvc, al)
|
||||
usersH := newUsersHandler(queries)
|
||||
auditH := newAuditHandler(auditSvc)
|
||||
statsH := newStatsHandler(statsSvc)
|
||||
metricsH := newSandboxMetricsHandler(queries, pool)
|
||||
buildH := newBuildHandler(buildSvc, queries, pool)
|
||||
channelH := newChannelHandler(channelSvc, al)
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
@ -55,6 +83,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
r.Get("/auth/oauth/{provider}", oauthH.Redirect)
|
||||
r.Get("/auth/oauth/{provider}/callback", oauthH.Callback)
|
||||
|
||||
// JWT-authenticated: switch active team.
|
||||
r.With(requireJWT(jwtSecret)).Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
|
||||
// JWT-authenticated: API key management.
|
||||
r.Route("/v1/api-keys", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
@ -63,11 +94,32 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
r.Delete("/{id}", apiKeys.Delete)
|
||||
})
|
||||
|
||||
// JWT-authenticated: team management.
|
||||
r.Route("/v1/teams", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Get("/", teamH.List)
|
||||
r.Post("/", teamH.Create)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", teamH.Get)
|
||||
r.Patch("/", teamH.Rename)
|
||||
r.Delete("/", teamH.Delete)
|
||||
r.Get("/members", teamH.ListMembers)
|
||||
r.Post("/members", teamH.AddMember)
|
||||
r.Patch("/members/{uid}", teamH.UpdateMemberRole)
|
||||
r.Delete("/members/{uid}", teamH.RemoveMember)
|
||||
r.Post("/leave", teamH.Leave)
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: user search (for add-member UI).
|
||||
r.With(requireJWT(jwtSecret)).Get("/v1/users/search", usersH.Search)
|
||||
|
||||
// Sandbox lifecycle: accepts API key or JWT bearer token.
|
||||
r.Route("/v1/sandboxes", func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Post("/", sandbox.Create)
|
||||
r.Get("/", sandbox.List)
|
||||
r.Get("/stats", statsH.GetStats)
|
||||
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", sandbox.Get)
|
||||
@ -81,6 +133,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
r.Post("/files/read", files.Download)
|
||||
r.Post("/files/stream/write", filesStream.StreamUpload)
|
||||
r.Post("/files/stream/read", filesStream.StreamDownload)
|
||||
r.Get("/metrics", metricsH.GetMetrics)
|
||||
})
|
||||
})
|
||||
|
||||
@ -97,6 +150,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
// Unauthenticated: one-time registration token.
|
||||
r.Post("/register", hostH.Register)
|
||||
|
||||
// Unauthenticated: refresh token exchange.
|
||||
r.Post("/auth/refresh", hostH.RefreshToken)
|
||||
|
||||
// Host-token-authenticated: heartbeat.
|
||||
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
|
||||
|
||||
@ -108,6 +164,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", hostH.Get)
|
||||
r.Delete("/", hostH.Delete)
|
||||
r.Get("/delete-preview", hostH.DeletePreview)
|
||||
r.Post("/token", hostH.RegenerateToken)
|
||||
r.Get("/tags", hostH.ListTags)
|
||||
r.Post("/tags", hostH.AddTag)
|
||||
@ -116,7 +173,37 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
})
|
||||
})
|
||||
|
||||
return &Server{router: r}
|
||||
// JWT-authenticated: notification channels.
|
||||
r.Route("/v1/channels", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Post("/", channelH.Create)
|
||||
r.Get("/", channelH.List)
|
||||
r.Post("/test", channelH.Test)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", channelH.Get)
|
||||
r.Patch("/", channelH.Update)
|
||||
r.Delete("/", channelH.Delete)
|
||||
r.Put("/config", channelH.RotateConfig)
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: audit log.
|
||||
r.With(requireJWT(jwtSecret)).Get("/v1/audit-logs", auditH.List)
|
||||
|
||||
// Platform admin routes — require JWT + DB-validated admin status.
|
||||
r.Route("/v1/admin", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
|
||||
r.Get("/templates", buildH.ListTemplates)
|
||||
r.Delete("/templates/{name}", buildH.DeleteTemplate)
|
||||
r.Post("/builds", buildH.Create)
|
||||
r.Get("/builds", buildH.List)
|
||||
r.Get("/builds/{id}", buildH.Get)
|
||||
r.Post("/builds/{id}/cancel", buildH.Cancel)
|
||||
})
|
||||
|
||||
return &Server{router: r, BuildSvc: buildSvc}
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler.
|
||||
@ -137,7 +224,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrenn Sandbox API</title>
|
||||
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
|
||||
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui.css" integrity="sha384-rcbEi6xgdPk0iWkAQzT2F3FeBJXdG+ydrawGlfHAFIZG7wU6aKbQaRewysYpmrlW" crossorigin="anonymous">
|
||||
<style>
|
||||
body { margin: 0; background: #fafafa; }
|
||||
.swagger-ui .topbar { display: none; }
|
||||
@ -145,7 +232,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
|
||||
</head>
|
||||
<body>
|
||||
<div id="swagger-ui"></div>
|
||||
<script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
|
||||
<script src="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui-bundle.js" integrity="sha384-NXtFPpN61oWCuN4D42K6Zd5Rt2+uxeIT36R7kpXBuY9tLnZorzrJ4ykpqwJfgjpZ" crossorigin="anonymous"></script>
|
||||
<script>
|
||||
SwaggerUIBundle({
|
||||
url: "/openapi.yaml",
|
||||
|
||||
Reference in New Issue
Block a user