forked from wrenn/wrenn
v0.0.1 (#8)
Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com> Reviewed-on: wrenn/sandbox#8
This commit is contained in:
22
internal/api/agent_helper.go
Normal file
22
internal/api/agent_helper.go
Normal file
@ -0,0 +1,22 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// agentForHost looks up the host record and returns a Connect RPC client for it.
|
||||
// Returns an error if the host is not found or has no address.
|
||||
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) {
|
||||
host, err := queries.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
return pool.GetForHost(host)
|
||||
}
|
||||
230
internal/api/handler_sandbox_proxy.go
Normal file
230
internal/api/handler_sandbox_proxy.go
Normal file
@ -0,0 +1,230 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"path"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by proxyTarget, used to map to HTTP status codes
// without relying on error message text (compared with errors.Is in ServeHTTP).
var (
	errProxySandboxNotFound = errors.New("sandbox not found")
	errProxyNoHostAddress   = errors.New("host agent has no address")
)

// proxyCacheTTL bounds how long a resolved agent URL is served from the cache
// before the database is consulted again.
const proxyCacheTTL = 120 * time.Second

// sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or
// "49999-cl-abcd1234.example.com". Captures: port, sandbox ID.
var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`)
|
||||
|
||||
// errProxySandboxNotRunning carries the sandbox status so callers can include
// it in the HTTP response without parsing error strings (matched via errors.As).
type errProxySandboxNotRunning struct{ status string }

// Error implements the error interface.
func (e errProxySandboxNotRunning) Error() string {
	msg := fmt.Sprintf("sandbox is not running (status: %s)", e.status)
	return msg
}
|
||||
|
||||
// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair.
// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure
// can capture the correct port without the cache key needing to include it.
type proxyCacheEntry struct {
	agentURL  *url.URL  // base URL of the owning host agent
	expiresAt time.Time // entry is stale once time.Now() passes this
}
|
||||
|
||||
// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation.
|
||||
type proxyCacheKey [32]byte
|
||||
|
||||
func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey {
|
||||
var k proxyCacheKey
|
||||
copy(k[:16], sandboxID.Bytes[:])
|
||||
copy(k[16:], teamID.Bytes[:])
|
||||
return k
|
||||
}
|
||||
|
||||
// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
// requests are reverse-proxied through the host agent that owns the sandbox.
// All other requests are passed through to the inner handler.
//
// Authentication is via X-API-Key header only (no JWT). The API key's team
// must own the sandbox.
type SandboxProxyWrapper struct {
	inner     http.Handler              // fallback handler for non-sandbox hostnames
	db        *db.Queries               // sandbox/team/API-key lookups
	pool      *lifecycle.HostClientPool // resolves host addresses to agent URLs
	transport http.RoundTripper         // taken from the pool so connections are shared

	cacheMu sync.Mutex // guards cache
	cache   map[proxyCacheKey]proxyCacheEntry
}
|
||||
|
||||
// NewSandboxProxyWrapper creates a new proxy wrapper.
|
||||
func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifecycle.HostClientPool) *SandboxProxyWrapper {
|
||||
return &SandboxProxyWrapper{
|
||||
inner: inner,
|
||||
db: queries,
|
||||
pool: pool,
|
||||
transport: pool.Transport(),
|
||||
cache: make(map[proxyCacheKey]proxyCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// proxyTarget looks up the cached agent URL for (sandboxID, teamID).
|
||||
// On a miss it queries the DB, resolves the address, and populates the cache.
|
||||
// The *httputil.ReverseProxy is built by the caller so the Director closure
|
||||
// captures the correct port without the cache key needing to include it.
|
||||
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) {
|
||||
cacheKey := makeProxyCacheKey(sandboxID, teamID)
|
||||
|
||||
h.cacheMu.Lock()
|
||||
entry, ok := h.cache[cacheKey]
|
||||
h.cacheMu.Unlock()
|
||||
|
||||
if ok && time.Now().Before(entry.expiresAt) {
|
||||
return entry.agentURL, nil
|
||||
}
|
||||
|
||||
// Cache miss or expired — query DB.
|
||||
target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{
|
||||
ID: sandboxID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errProxySandboxNotFound
|
||||
}
|
||||
if target.Status != "running" {
|
||||
return nil, errProxySandboxNotRunning{status: target.Status}
|
||||
}
|
||||
if target.HostAddress == "" {
|
||||
return nil, errProxyNoHostAddress
|
||||
}
|
||||
|
||||
agentURL, err := url.Parse(h.pool.ResolveAddr(target.HostAddress))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid host agent address: %w", err)
|
||||
}
|
||||
|
||||
h.cacheMu.Lock()
|
||||
h.cache[cacheKey] = proxyCacheEntry{
|
||||
agentURL: agentURL,
|
||||
expiresAt: time.Now().Add(proxyCacheTTL),
|
||||
}
|
||||
h.cacheMu.Unlock()
|
||||
|
||||
return agentURL, nil
|
||||
}
|
||||
|
||||
// evictProxyCache removes the cached entry for a (sandbox, team) pair.
|
||||
// Called on 502 so a stopped/moved sandbox is re-resolved on the next request.
|
||||
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) {
|
||||
h.cacheMu.Lock()
|
||||
delete(h.cache, makeProxyCacheKey(sandboxID, teamID))
|
||||
h.cacheMu.Unlock()
|
||||
}
|
||||
|
||||
// ServeHTTP intercepts requests whose Host header matches the
// {port}-{sandbox_id}.{domain} pattern and reverse-proxies them through the
// owning host agent; all other requests fall through to the inner handler.
func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	host := r.Host
	// Strip port from Host header (e.g. "49999-cl-abcd1234.localhost:8000" → "49999-cl-abcd1234.localhost")
	if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 {
		host = host[:colonIdx]
	}

	matches := sandboxHostPattern.FindStringSubmatch(host)
	if matches == nil {
		// Not a sandbox hostname — ordinary traffic for the wrapped handler.
		h.inner.ServeHTTP(w, r)
		return
	}

	port := matches[1]
	sandboxIDStr := matches[2]

	// Validate port.
	portNum, err := strconv.Atoi(port)
	if err != nil || portNum < 1 || portNum > 65535 {
		http.Error(w, "invalid port", http.StatusBadRequest)
		return
	}

	// Authenticate and extract the team ID. NOTE(review): a previous comment
	// here said "API key or JWT", but authenticateRequest accepts X-API-Key only.
	teamID, err := h.authenticateRequest(r)
	if err != nil {
		writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
		return
	}

	sandboxID, err := id.ParseSandboxID(sandboxIDStr)
	if err != nil {
		http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
		return
	}

	agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID)
	if err != nil {
		// Map sentinel/typed errors to status codes without string matching.
		switch {
		case errors.Is(err, errProxySandboxNotFound):
			http.Error(w, err.Error(), http.StatusNotFound)
		case errors.As(err, new(errProxySandboxNotRunning)):
			http.Error(w, err.Error(), http.StatusConflict)
		default:
			http.Error(w, err.Error(), http.StatusServiceUnavailable)
		}
		return
	}

	proxy := &httputil.ReverseProxy{
		Transport: h.transport,
		Director: func(req *http.Request) {
			// Rewrite the request onto the agent's /proxy/{sandbox}/{port}/... route.
			req.URL.Scheme = agentURL.Scheme
			req.URL.Host = agentURL.Host
			req.URL.Path = path.Join("/proxy", sandboxIDStr, port, path.Clean("/"+req.URL.Path))
			req.Host = agentURL.Host
		},
		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
			slog.Debug("sandbox proxy error",
				"sandbox_id", sandboxIDStr,
				"port", port,
				"error", err,
			)
			// Evict so a stopped/moved sandbox is re-resolved on the next request.
			h.evictProxyCache(sandboxID, teamID)
			http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
		},
	}
	proxy.ServeHTTP(w, r)
}
|
||||
|
||||
// authenticateRequest validates the request's API key and returns the team ID.
|
||||
// Only API key authentication is supported for sandbox proxy requests (not JWT).
|
||||
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("X-API-Key header required")
|
||||
}
|
||||
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := h.db.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid API key")
|
||||
}
|
||||
return row.TeamID, nil
|
||||
}
|
||||
@ -6,17 +6,20 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type apiKeyHandler struct {
|
||||
svc *service.APIKeyService
|
||||
svc *service.APIKeyService
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newAPIKeyHandler(svc *service.APIKeyService) *apiKeyHandler {
|
||||
return &apiKeyHandler{svc: svc}
|
||||
func newAPIKeyHandler(svc *service.APIKeyService, al *audit.AuditLogger) *apiKeyHandler {
|
||||
return &apiKeyHandler{svc: svc, audit: al}
|
||||
}
|
||||
|
||||
type createAPIKeyRequest struct {
|
||||
@ -37,11 +40,11 @@ type apiKeyResponse struct {
|
||||
|
||||
func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
|
||||
resp := apiKeyResponse{
|
||||
ID: k.ID,
|
||||
TeamID: k.TeamID,
|
||||
ID: id.FormatAPIKeyID(k.ID),
|
||||
TeamID: id.FormatTeamID(k.TeamID),
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: k.CreatedBy,
|
||||
CreatedBy: id.FormatUserID(k.CreatedBy),
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
|
||||
@ -55,11 +58,11 @@ func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
|
||||
|
||||
func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse {
|
||||
resp := apiKeyResponse{
|
||||
ID: k.ID,
|
||||
TeamID: k.TeamID,
|
||||
ID: id.FormatAPIKeyID(k.ID),
|
||||
TeamID: id.FormatTeamID(k.TeamID),
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: k.CreatedBy,
|
||||
CreatedBy: id.FormatUserID(k.CreatedBy),
|
||||
CreatorEmail: k.CreatorEmail,
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
@ -91,6 +94,7 @@ func (h *apiKeyHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
resp := apiKeyToResponse(result.Row)
|
||||
resp.Key = &result.Plaintext
|
||||
|
||||
h.audit.LogAPIKeyCreate(r.Context(), ac, result.Row.ID, result.Row.Name)
|
||||
writeJSON(w, http.StatusCreated, resp)
|
||||
}
|
||||
|
||||
@ -115,12 +119,19 @@ func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
// Delete handles DELETE /v1/api-keys/{id}.
|
||||
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
keyID := chi.URLParam(r, "id")
|
||||
keyIDStr := chi.URLParam(r, "id")
|
||||
|
||||
keyID, err := id.ParseAPIKeyID(keyIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid API key ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogAPIKeyRevoke(r.Context(), ac, keyID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
148
internal/api/handlers_audit.go
Normal file
148
internal/api/handlers_audit.go
Normal file
@ -0,0 +1,148 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type auditHandler struct {
|
||||
svc *service.AuditService
|
||||
}
|
||||
|
||||
func newAuditHandler(svc *service.AuditService) *auditHandler {
|
||||
return &auditHandler{svc: svc}
|
||||
}
|
||||
|
||||
// auditLogResponse is the JSON shape of one audit-log entry as returned by
// GET /v1/audit-logs.
type auditLogResponse struct {
	ID           string         `json:"id"`
	ActorType    string         `json:"actor_type"`
	ActorID      string         `json:"actor_id,omitempty"`
	ActorName    string         `json:"actor_name,omitempty"`
	ResourceType string         `json:"resource_type"`
	ResourceID   string         `json:"resource_id,omitempty"`
	Action       string         `json:"action"`
	Scope        string         `json:"scope"`
	Status       string         `json:"status"`
	Metadata     map[string]any `json:"metadata,omitempty"`
	CreatedAt    string         `json:"created_at"` // RFC3339, UTC (see List)
}
|
||||
|
||||
// List handles GET /v1/audit-logs.
|
||||
// Query params:
|
||||
// - before: RFC3339 timestamp cursor (exclusive); omit to start from latest
|
||||
// - limit: page size, default 50, max 200
|
||||
// - resource_type: filter by resource type (sandbox, snapshot, team, api_key, member, host)
|
||||
// - action: filter by action verb
|
||||
//
|
||||
// Members see only team-scoped events; admins/owners see all.
|
||||
func (h *auditHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
// Parse ?before cursor.
|
||||
var before time.Time
|
||||
if s := r.URL.Query().Get("before"); s != "" {
|
||||
var err error
|
||||
before, err = time.Parse(time.RFC3339, s)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "before must be an RFC3339 timestamp")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Parse ?limit.
|
||||
limit := 50
|
||||
if s := r.URL.Query().Get("limit"); s != "" {
|
||||
n, err := strconv.Atoi(s)
|
||||
if err != nil || n < 1 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "limit must be a positive integer")
|
||||
return
|
||||
}
|
||||
limit = n
|
||||
}
|
||||
|
||||
// Parse ?before_id cursor (UUID).
|
||||
var beforeID pgtype.UUID
|
||||
if s := r.URL.Query().Get("before_id"); s != "" {
|
||||
parsed, err := id.ParseAuditLogID(s)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "before_id must be a valid audit log ID")
|
||||
return
|
||||
}
|
||||
beforeID = parsed
|
||||
}
|
||||
|
||||
entries, err := h.svc.List(r.Context(), service.AuditListParams{
|
||||
TeamID: ac.TeamID,
|
||||
AdminScoped: ac.Role == "owner" || ac.Role == "admin",
|
||||
ResourceTypes: parseMultiParam(r.URL.Query()["resource_type"]),
|
||||
Actions: parseMultiParam(r.URL.Query()["action"]),
|
||||
Before: before,
|
||||
BeforeID: beforeID,
|
||||
Limit: limit,
|
||||
})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list audit logs")
|
||||
return
|
||||
}
|
||||
|
||||
items := make([]auditLogResponse, len(entries))
|
||||
for i, e := range entries {
|
||||
items[i] = auditLogResponse{
|
||||
ID: e.ID,
|
||||
ActorType: e.ActorType,
|
||||
ActorID: e.ActorID,
|
||||
ActorName: e.ActorName,
|
||||
ResourceType: e.ResourceType,
|
||||
ResourceID: e.ResourceID,
|
||||
Action: e.Action,
|
||||
Scope: e.Scope,
|
||||
Status: e.Status,
|
||||
Metadata: e.Metadata,
|
||||
CreatedAt: e.CreatedAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
resp := map[string]any{"items": items}
|
||||
if len(items) > 0 {
|
||||
last := entries[len(entries)-1]
|
||||
resp["next_before"] = last.CreatedAt.UTC().Format(time.RFC3339)
|
||||
resp["next_before_id"] = last.ID
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// parseMultiParam flattens repeated params and comma-separated values into a
// single deduplicated slice. Empty strings are dropped. Returns nil (no filter)
// when no values are present.
//
// Both ?resource_type=sandbox&resource_type=snapshot
// and ?resource_type=sandbox,snapshot are accepted.
func parseMultiParam(values []string) []string {
	if len(values) == 0 {
		return nil
	}
	var (
		out  []string
		seen = make(map[string]struct{})
	)
	for _, raw := range values {
		for _, piece := range strings.Split(raw, ",") {
			trimmed := strings.TrimSpace(piece)
			if trimmed == "" {
				continue
			}
			if _, dup := seen[trimmed]; dup {
				continue
			}
			seen[trimmed] = struct{}{}
			out = append(out, trimmed)
		}
	}
	return out
}
|
||||
@ -1,7 +1,9 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@ -15,6 +17,45 @@ import (
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// loginTeam returns the team and role to stamp into a login JWT.
// It prefers the user's default team; if none is flagged as default it falls
// back to the earliest-joined team. Returns pgx.ErrNoRows when the user has
// no team memberships at all.
func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team, string, error) {
	team, err := q.GetDefaultTeamForUser(ctx, userID)
	if err == nil {
		// Default team found — fetch the user's role in it (the inner err
		// deliberately shadows the outer one; we only need it in this branch).
		membership, err := q.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: team.ID})
		if err != nil {
			return db.Team{}, "", err
		}
		return team, membership.Role, nil
	}
	if !errors.Is(err, pgx.ErrNoRows) {
		// Real lookup failure, not merely "no default flagged".
		return db.Team{}, "", err
	}
	// No default set — fall back to earliest-joined team.
	rows, err := q.GetTeamsForUser(ctx, userID)
	if err != nil {
		return db.Team{}, "", err
	}
	if len(rows) == 0 {
		// Reuse pgx.ErrNoRows so callers have a single "no team" sentinel.
		return db.Team{}, "", pgx.ErrNoRows
	}
	first := rows[0]
	// Repack the row type (team columns + Role) into db.Team for a uniform
	// return shape matching the default-team branch.
	return db.Team{
		ID:        first.ID,
		Name:      first.Name,
		Slug:      first.Slug,
		IsByoc:    first.IsByoc,
		CreatedAt: first.CreatedAt,
		DeletedAt: first.DeletedAt,
	}, first.Role, nil
}
|
||||
|
||||
// switchTeamRequest is the JSON body for POST /v1/auth/switch-team.
type switchTeamRequest struct {
	TeamID string `json:"team_id"` // target team, parsed via id.ParseTeamID
}
|
||||
|
||||
type authHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
@ -28,6 +69,7 @@ func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authH
|
||||
// signupRequest is the JSON body for POST /v1/auth/signup.
type signupRequest struct {
	Email    string `json:"email"`    // lowercased and trimmed by the handler
	Password string `json:"password"` // minimum 8 characters (validated in Signup)
	Name     string `json:"name"`     // 1–100 characters (validated in Signup)
}
|
||||
|
||||
type loginRequest struct {
|
||||
@ -40,6 +82,7 @@ type authResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Signup handles POST /v1/auth/signup.
|
||||
@ -51,6 +94,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
||||
req.Name = strings.TrimSpace(req.Name)
|
||||
if !strings.Contains(req.Email, "@") || len(req.Email) < 3 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid email address")
|
||||
return
|
||||
@ -59,6 +103,10 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
if req.Name == "" || len(req.Name) > 100 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "name must be between 1 and 100 characters")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
@ -83,6 +131,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
ID: userID,
|
||||
Email: req.Email,
|
||||
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
|
||||
Name: req.Name,
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
@ -98,7 +147,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
teamID := id.NewTeamID()
|
||||
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: req.Email + "'s Team",
|
||||
Name: req.Name + "'s Team",
|
||||
Slug: id.NewTeamSlug(),
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to create team")
|
||||
return
|
||||
@ -119,7 +169,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner", false)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
@ -127,9 +177,10 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
writeJSON(w, http.StatusCreated, authResponse{
|
||||
Token: token,
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
UserID: id.FormatUserID(userID),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
})
|
||||
}
|
||||
|
||||
@ -152,6 +203,7 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := h.db.GetUserByEmail(ctx, req.Email)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("login failed: unknown email", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
@ -160,21 +212,27 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if !user.PasswordHash.Valid {
|
||||
slog.Warn("login failed: no password set", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
if err := auth.CheckPassword(user.PasswordHash.String, req.Password); err != nil {
|
||||
slog.Warn("login failed: wrong password", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusForbidden, "no_team", "user is not a member of any team")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
@ -182,8 +240,85 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: user.ID,
|
||||
TeamID: team.ID,
|
||||
UserID: id.FormatUserID(user.ID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// SwitchTeam handles POST /v1/auth/switch-team.
// Verifies from DB that the user is a member of the target team, then re-issues
// a JWT scoped to that team. The JWT's team_id is used as a pre-filter on all
// subsequent team-scoped requests; DB is the source of truth for actual permissions.
func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
	ac := auth.MustFromContext(r.Context())

	var req switchTeamRequest
	if err := decodeJSON(r, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}
	if req.TeamID == "" {
		writeError(w, http.StatusBadRequest, "invalid_request", "team_id is required")
		return
	}

	teamID, err := id.ParseTeamID(req.TeamID)
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
		return
	}

	ctx := r.Context()

	// Verify team exists and is not deleted.
	team, err := h.db.GetTeam(ctx, teamID)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			writeError(w, http.StatusNotFound, "not_found", "team not found")
			return
		}
		writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
		return
	}
	// Soft-deleted teams are reported identically to missing ones.
	if team.DeletedAt.Valid {
		writeError(w, http.StatusNotFound, "not_found", "team not found")
		return
	}

	// Verify membership from DB — JWT role is not trusted here.
	membership, err := h.db.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: ac.UserID,
		TeamID: teamID,
	})
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			writeError(w, http.StatusForbidden, "forbidden", "not a member of this team")
			return
		}
		writeError(w, http.StatusInternalServerError, "db_error", "failed to look up membership")
		return
	}

	// Fetch current name from DB — JWT name is not trusted here (may be stale or empty for old tokens).
	user, err := h.db.GetUserByID(ctx, ac.UserID)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user")
		return
	}

	token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
		return
	}

	writeJSON(w, http.StatusOK, authResponse{
		Token:  token,
		UserID: id.FormatUserID(ac.UserID),
		TeamID: id.FormatTeamID(teamID),
		Email:  ac.Email,
		Name:   user.Name,
	})
}
|
||||
|
||||
276
internal/api/handlers_builds.go
Normal file
276
internal/api/handlers_builds.go
Normal file
@ -0,0 +1,276 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/layout"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type buildHandler struct {
|
||||
svc *service.BuildService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newBuildHandler(svc *service.BuildService, db *db.Queries, pool *lifecycle.HostClientPool) *buildHandler {
|
||||
return &buildHandler{svc: svc, db: db, pool: pool}
|
||||
}
|
||||
|
||||
// createBuildRequest is the JSON body for POST /v1/admin/builds.
type createBuildRequest struct {
	Name         string   `json:"name"`          // required; validated by validate.SafeName
	BaseTemplate string   `json:"base_template"` // template to build on top of
	Recipe       []string `json:"recipe"`        // required; at least one command
	Healthcheck  string   `json:"healthcheck"`
	VCPUs        int32    `json:"vcpus"`
	MemoryMB     int32    `json:"memory_mb"`
	SkipPrePost  bool     `json:"skip_pre_post"`
}
|
||||
|
||||
// buildResponse is the JSON shape of a template build; pointer fields are
// omitted when the underlying DB column is empty/NULL (see buildToResponse).
type buildResponse struct {
	ID           string          `json:"id"`
	Name         string          `json:"name"`
	BaseTemplate string          `json:"base_template"`
	Recipe       json.RawMessage `json:"recipe"`
	Healthcheck  *string         `json:"healthcheck,omitempty"`
	VCPUs        int32           `json:"vcpus"`
	MemoryMB     int32           `json:"memory_mb"`
	Status       string          `json:"status"`
	CurrentStep  int32           `json:"current_step"`
	TotalSteps   int32           `json:"total_steps"`
	Logs         json.RawMessage `json:"logs"`
	Error        *string         `json:"error,omitempty"`
	SandboxID    *string         `json:"sandbox_id,omitempty"`
	HostID       *string         `json:"host_id,omitempty"`
	CreatedAt    string          `json:"created_at"` // RFC3339
	StartedAt    *string         `json:"started_at,omitempty"`
	CompletedAt  *string         `json:"completed_at,omitempty"`
}
|
||||
|
||||
// buildToResponse converts a db.TemplateBuild row into its API response shape,
// formatting IDs with the public id.Format* helpers and timestamps as RFC3339.
// Optional fields become nil/omitted when the row value is empty or invalid.
func buildToResponse(b db.TemplateBuild) buildResponse {
	resp := buildResponse{
		ID:           id.FormatBuildID(b.ID),
		Name:         b.Name,
		BaseTemplate: b.BaseTemplate,
		Recipe:       b.Recipe,
		VCPUs:        b.Vcpus,
		MemoryMB:     b.MemoryMb,
		Status:       b.Status,
		CurrentStep:  b.CurrentStep,
		TotalSteps:   b.TotalSteps,
		Logs:         b.Logs,
	}
	// Empty strings map to omitted fields rather than "".
	if b.Healthcheck != "" {
		resp.Healthcheck = &b.Healthcheck
	}
	if b.Error != "" {
		resp.Error = &b.Error
	}
	if b.SandboxID.Valid {
		s := id.FormatSandboxID(b.SandboxID)
		resp.SandboxID = &s
	}
	if b.HostID.Valid {
		s := id.FormatHostID(b.HostID)
		resp.HostID = &s
	}
	if b.CreatedAt.Valid {
		resp.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339)
	}
	if b.StartedAt.Valid {
		s := b.StartedAt.Time.Format(time.RFC3339)
		resp.StartedAt = &s
	}
	if b.CompletedAt.Valid {
		s := b.CompletedAt.Time.Format(time.RFC3339)
		resp.CompletedAt = &s
	}
	return resp
}
|
||||
|
||||
// Create handles POST /v1/admin/builds.
|
||||
func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createBuildRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "name is required")
|
||||
return
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
|
||||
return
|
||||
}
|
||||
if len(req.Recipe) == 0 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "recipe must contain at least one command")
|
||||
return
|
||||
}
|
||||
|
||||
build, err := h.svc.Create(r.Context(), service.BuildCreateParams{
|
||||
Name: req.Name,
|
||||
BaseTemplate: req.BaseTemplate,
|
||||
Recipe: req.Recipe,
|
||||
Healthcheck: req.Healthcheck,
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
SkipPrePost: req.SkipPrePost,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to create build", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "build_error", "failed to create build")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, buildToResponse(build))
|
||||
}
|
||||
|
||||
// List handles GET /v1/admin/builds.
|
||||
func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
builds, err := h.svc.List(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list builds")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]buildResponse, len(builds))
|
||||
for i, b := range builds {
|
||||
resp[i] = buildToResponse(b)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/admin/builds/{id}.
|
||||
func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
buildIDStr := chi.URLParam(r, "id")
|
||||
|
||||
buildID, err := id.ParseBuildID(buildIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
|
||||
return
|
||||
}
|
||||
|
||||
build, err := h.svc.Get(r.Context(), buildID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "build not found")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, buildToResponse(build))
|
||||
}
|
||||
|
||||
// ListTemplates handles GET /v1/admin/templates — returns all templates across all teams.
|
||||
func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
templates, err := h.db.ListTemplates(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates")
|
||||
return
|
||||
}
|
||||
|
||||
type templateResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TeamID string `json:"team_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
resp := make([]templateResponse, len(templates))
|
||||
for i, t := range templates {
|
||||
resp[i] = templateResponse{
|
||||
Name: t.Name,
|
||||
Type: t.Type,
|
||||
VCPUs: t.Vcpus,
|
||||
MemoryMB: t.MemoryMb,
|
||||
SizeBytes: t.SizeBytes,
|
||||
TeamID: id.FormatTeamID(t.TeamID),
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// DeleteTemplate handles DELETE /v1/admin/templates/{name}.
|
||||
func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
name := chi.URLParam(r, "name")
|
||||
if err := validate.SafeName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
|
||||
tmpl, err := h.db.GetPlatformTemplateByName(ctx, name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "template not found")
|
||||
return
|
||||
}
|
||||
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
|
||||
return
|
||||
}
|
||||
|
||||
// Broadcast delete to all online hosts.
|
||||
hosts, _ := h.db.ListActiveHosts(ctx)
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
}
|
||||
agent, err := h.pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: formatUUIDForRPC(tmpl.TeamID),
|
||||
TemplateId: formatUUIDForRPC(tmpl.ID),
|
||||
})); err != nil {
|
||||
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Cancel handles POST /v1/admin/builds/{id}/cancel.
|
||||
func (h *buildHandler) Cancel(w http.ResponseWriter, r *http.Request) {
|
||||
buildIDStr := chi.URLParam(r, "id")
|
||||
|
||||
buildID, err := id.ParseBuildID(buildIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Cancel(r.Context(), buildID); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
242
internal/api/handlers_channels.go
Normal file
242
internal/api/handlers_channels.go
Normal file
@ -0,0 +1,242 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/channels"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
type channelHandler struct {
|
||||
svc *channels.Service
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newChannelHandler(svc *channels.Service, al *audit.AuditLogger) *channelHandler {
|
||||
return &channelHandler{svc: svc, audit: al}
|
||||
}
|
||||
|
||||
type createChannelRequest struct {
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config map[string]string `json:"config"`
|
||||
Events []string `json:"events"`
|
||||
}
|
||||
|
||||
type updateChannelRequest struct {
|
||||
Name string `json:"name"`
|
||||
Events []string `json:"events"`
|
||||
}
|
||||
|
||||
type rotateConfigRequest struct {
|
||||
Config map[string]string `json:"config"`
|
||||
}
|
||||
|
||||
type testChannelRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
Config map[string]string `json:"config"`
|
||||
}
|
||||
|
||||
type channelResponse struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Events []string `json:"events"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
Secret *string `json:"secret,omitempty"`
|
||||
}
|
||||
|
||||
func channelToResponse(ch db.Channel) channelResponse {
|
||||
resp := channelResponse{
|
||||
ID: id.FormatChannelID(ch.ID),
|
||||
TeamID: id.FormatTeamID(ch.TeamID),
|
||||
Name: ch.Name,
|
||||
Provider: ch.Provider,
|
||||
Events: ch.EventTypes,
|
||||
}
|
||||
if ch.CreatedAt.Valid {
|
||||
resp.CreatedAt = ch.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if ch.UpdatedAt.Valid {
|
||||
resp.UpdatedAt = ch.UpdatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/channels.
|
||||
func (h *channelHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
var req createChannelRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.Create(r.Context(), channels.CreateParams{
|
||||
TeamID: ac.TeamID,
|
||||
Name: req.Name,
|
||||
Provider: req.Provider,
|
||||
Config: req.Config,
|
||||
Events: req.Events,
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogChannelCreate(r.Context(), ac, result.Channel.ID, result.Channel.Name, result.Channel.Provider)
|
||||
|
||||
resp := channelToResponse(result.Channel)
|
||||
if result.PlaintextSecret != "" {
|
||||
resp.Secret = &result.PlaintextSecret
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, resp)
|
||||
}
|
||||
|
||||
// List handles GET /v1/channels.
|
||||
func (h *channelHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
chs, err := h.svc.List(r.Context(), ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list channels")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]channelResponse, len(chs))
|
||||
for i, ch := range chs {
|
||||
resp[i] = channelToResponse(ch)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/channels/{id}.
|
||||
func (h *channelHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
channelIDStr := chi.URLParam(r, "id")
|
||||
|
||||
channelID, err := id.ParseChannelID(channelIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
ch, err := h.svc.Get(r.Context(), channelID, ac.TeamID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusNotFound, "not_found", "channel not found")
|
||||
} else {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get channel")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, channelToResponse(ch))
|
||||
}
|
||||
|
||||
// Update handles PATCH /v1/channels/{id}.
|
||||
func (h *channelHandler) Update(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
channelIDStr := chi.URLParam(r, "id")
|
||||
|
||||
channelID, err := id.ParseChannelID(channelIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req updateChannelRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ch, err := h.svc.Update(r.Context(), channelID, ac.TeamID, req.Name, req.Events)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogChannelUpdate(r.Context(), ac, channelID)
|
||||
writeJSON(w, http.StatusOK, channelToResponse(ch))
|
||||
}
|
||||
|
||||
// Test handles POST /v1/channels/test.
|
||||
func (h *channelHandler) Test(w http.ResponseWriter, r *http.Request) {
|
||||
var req testChannelRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Test(r.Context(), req.Provider, req.Config); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// RotateConfig handles PUT /v1/channels/{id}/config.
|
||||
func (h *channelHandler) RotateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
channelIDStr := chi.URLParam(r, "id")
|
||||
|
||||
channelID, err := id.ParseChannelID(channelIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req rotateConfigRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ch, err := h.svc.RotateConfig(r.Context(), channelID, ac.TeamID, req.Config)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogChannelRotateConfig(r.Context(), ac, channelID)
|
||||
writeJSON(w, http.StatusOK, channelToResponse(ch))
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/channels/{id}.
|
||||
func (h *channelHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
channelIDStr := chi.URLParam(r, "id")
|
||||
|
||||
channelID, err := id.ParseChannelID(channelIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Delete(r.Context(), channelID, ac.TeamID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete channel")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogChannelDelete(r.Context(), ac, channelID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@ -14,17 +14,18 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type execHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler {
|
||||
return &execHandler{db: db, agent: agent}
|
||||
func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler {
|
||||
return &execHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
type execRequest struct {
|
||||
@ -46,10 +47,16 @@ type execResponse struct {
|
||||
|
||||
// Exec handles POST /v1/sandboxes/{id}/exec.
|
||||
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -73,8 +80,14 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
start := time.Now()
|
||||
|
||||
resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Args: req.Args,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
@ -95,7 +108,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
|
||||
// Use base64 encoding if output contains non-UTF-8 bytes.
|
||||
@ -106,7 +119,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
|
||||
encoding = "base64"
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxID,
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: base64.StdEncoding.EncodeToString(stdout),
|
||||
Stderr: base64.StdEncoding.EncodeToString(stderr),
|
||||
@ -118,7 +131,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxID,
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: string(stdout),
|
||||
Stderr: string(stderr),
|
||||
|
||||
@ -14,17 +14,18 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, agent: agent}
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@ -48,10 +49,16 @@ type wsOutMsg struct {
|
||||
|
||||
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
|
||||
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -80,12 +87,18 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
sendWSError(conn, "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Open streaming exec to host agent.
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: startMsg.Cmd,
|
||||
Args: startMsg.Args,
|
||||
}))
|
||||
@ -151,7 +164,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
|
||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -11,17 +11,18 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type filesHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler {
|
||||
return &filesHandler{db: db, agent: agent}
|
||||
func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandler {
|
||||
return &filesHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// Upload handles POST /v1/sandboxes/{id}/files/write.
|
||||
@ -29,10 +30,16 @@ func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceCl
|
||||
// - "path" text field: absolute destination path inside the sandbox
|
||||
// - "file" file field: binary content to write
|
||||
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -75,8 +82,14 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: filePath,
|
||||
Content: content,
|
||||
})); err != nil {
|
||||
@ -95,10 +108,16 @@ type readFileRequest struct {
|
||||
// Download handles POST /v1/sandboxes/{id}/files/read.
|
||||
// Accepts JSON body with path, returns raw file content with Content-Disposition.
|
||||
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -120,8 +139,14 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
|
||||
@ -12,27 +12,34 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type filesStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
|
||||
return &filesStreamHandler{db: db, agent: agent}
|
||||
func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesStreamHandler {
|
||||
return &filesStreamHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -88,14 +95,20 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
defer filePart.Close()
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Open client-streaming RPC to host agent.
|
||||
stream := h.agent.WriteFileStream(ctx)
|
||||
stream := agent.WriteFileStream(ctx)
|
||||
|
||||
// Send metadata first.
|
||||
if err := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Meta{
|
||||
Meta: &pb.WriteFileStreamMeta{
|
||||
SandboxId: sandboxID,
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: filePath,
|
||||
},
|
||||
},
|
||||
@ -140,10 +153,16 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
|
||||
// Accepts JSON body with path, streams file content back without buffering.
|
||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -164,9 +183,15 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Open server-streaming RPC to host agent.
|
||||
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
|
||||
@ -1,23 +1,30 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type hostHandler struct {
|
||||
svc *service.HostService
|
||||
queries *db.Queries
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries}
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries, audit: al}
|
||||
}
|
||||
|
||||
// Request/response types.
|
||||
@ -34,6 +41,24 @@ type createHostResponse struct {
|
||||
RegistrationToken string `json:"registration_token"`
|
||||
}
|
||||
|
||||
type refreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type refreshTokenResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
CertPEM string `json:"cert_pem,omitempty"`
|
||||
KeyPEM string `json:"key_pem,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
}
|
||||
|
||||
type deletePreviewResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
SandboxIDs []string `json:"sandbox_ids"`
|
||||
}
|
||||
|
||||
type registerHostRequest struct {
|
||||
Token string `json:"token"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
@ -44,8 +69,12 @@ type registerHostRequest struct {
|
||||
}
|
||||
|
||||
type registerHostResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
Token string `json:"token"`
|
||||
Host hostResponse `json:"host"`
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
CertPEM string `json:"cert_pem,omitempty"`
|
||||
KeyPEM string `json:"key_pem,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
}
|
||||
|
||||
type addTagRequest struct {
|
||||
@ -56,6 +85,7 @@ type hostResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID *string `json:"team_id,omitempty"`
|
||||
TeamName *string `json:"team_name,omitempty"`
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
AvailabilityZone *string `json:"availability_zone,omitempty"`
|
||||
Arch *string `json:"arch,omitempty"`
|
||||
@ -72,34 +102,35 @@ type hostResponse struct {
|
||||
|
||||
func hostToResponse(h db.Host) hostResponse {
|
||||
resp := hostResponse{
|
||||
ID: h.ID,
|
||||
ID: id.FormatHostID(h.ID),
|
||||
Type: h.Type,
|
||||
Status: h.Status,
|
||||
CreatedBy: h.CreatedBy,
|
||||
CreatedBy: id.FormatUserID(h.CreatedBy),
|
||||
}
|
||||
if h.TeamID.Valid {
|
||||
resp.TeamID = &h.TeamID.String
|
||||
s := id.FormatTeamID(h.TeamID)
|
||||
resp.TeamID = &s
|
||||
}
|
||||
if h.Provider.Valid {
|
||||
resp.Provider = &h.Provider.String
|
||||
if h.Provider != "" {
|
||||
resp.Provider = &h.Provider
|
||||
}
|
||||
if h.AvailabilityZone.Valid {
|
||||
resp.AvailabilityZone = &h.AvailabilityZone.String
|
||||
if h.AvailabilityZone != "" {
|
||||
resp.AvailabilityZone = &h.AvailabilityZone
|
||||
}
|
||||
if h.Arch.Valid {
|
||||
resp.Arch = &h.Arch.String
|
||||
if h.Arch != "" {
|
||||
resp.Arch = &h.Arch
|
||||
}
|
||||
if h.CpuCores.Valid {
|
||||
resp.CPUCores = &h.CpuCores.Int32
|
||||
if h.CpuCores != 0 {
|
||||
resp.CPUCores = &h.CpuCores
|
||||
}
|
||||
if h.MemoryMb.Valid {
|
||||
resp.MemoryMB = &h.MemoryMb.Int32
|
||||
if h.MemoryMb != 0 {
|
||||
resp.MemoryMB = &h.MemoryMb
|
||||
}
|
||||
if h.DiskGb.Valid {
|
||||
resp.DiskGB = &h.DiskGb.Int32
|
||||
if h.DiskGb != 0 {
|
||||
resp.DiskGB = &h.DiskGb
|
||||
}
|
||||
if h.Address.Valid {
|
||||
resp.Address = &h.Address.String
|
||||
if h.Address != "" {
|
||||
resp.Address = &h.Address
|
||||
}
|
||||
if h.LastHeartbeatAt.Valid {
|
||||
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
|
||||
@ -112,7 +143,7 @@ func hostToResponse(h db.Host) hostResponse {
|
||||
}
|
||||
|
||||
// isAdmin fetches the user record and returns whether they are an admin.
|
||||
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
|
||||
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
|
||||
user, err := h.queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
return false
|
||||
@ -130,20 +161,32 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
|
||||
Type: req.Type,
|
||||
TeamID: req.TeamID,
|
||||
Provider: req.Provider,
|
||||
AvailabilityZone: req.AvailabilityZone,
|
||||
RequestingUserID: ac.UserID,
|
||||
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
|
||||
})
|
||||
// Parse optional team ID from request body.
|
||||
var params service.HostCreateParams
|
||||
params.Type = req.Type
|
||||
params.Provider = req.Provider
|
||||
params.AvailabilityZone = req.AvailabilityZone
|
||||
params.RequestingUserID = ac.UserID
|
||||
params.IsRequestorAdmin = h.isAdmin(r, ac.UserID)
|
||||
if req.TeamID != "" {
|
||||
teamID, err := id.ParseTeamID(req.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
|
||||
return
|
||||
}
|
||||
params.TeamID = teamID
|
||||
}
|
||||
|
||||
result, err := h.svc.Create(r.Context(), params)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
// Log audit for the owning team (BYOC hosts have a team; shared hosts use caller's team).
|
||||
h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, result.Host.TeamID)
|
||||
|
||||
writeJSON(w, http.StatusCreated, createHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
RegistrationToken: result.RegistrationToken,
|
||||
@ -153,16 +196,50 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
// List handles GET /v1/hosts.
|
||||
func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
admin := h.isAdmin(r, ac.UserID)
|
||||
|
||||
hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
hosts, err := h.svc.List(r.Context(), ac.TeamID, admin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique team IDs so we can fetch team names in one pass.
|
||||
var teamNames map[string]string
|
||||
if admin {
|
||||
seen := make(map[string]struct{})
|
||||
for _, host := range hosts {
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(seen) > 0 {
|
||||
teamNames = make(map[string]string, len(seen))
|
||||
for _, host := range hosts {
|
||||
if !host.TeamID.Valid {
|
||||
continue
|
||||
}
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if _, ok := teamNames[key]; ok {
|
||||
continue
|
||||
}
|
||||
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
|
||||
teamNames[key] = team.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponse(host)
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if name, ok := teamNames[key]; ok {
|
||||
resp[i].TeamName = &name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
@ -170,9 +247,15 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Get handles GET /v1/hosts/{id}.
|
||||
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
@ -183,25 +266,86 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, hostToResponse(host))
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/hosts/{id}.
|
||||
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
// DeletePreview handles GET /v1/hosts/{id}/delete-preview.
|
||||
// Returns what would be affected without making changes, for confirmation UI.
|
||||
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
writeJSON(w, http.StatusOK, deletePreviewResponse{
|
||||
Host: hostToResponse(preview.Host),
|
||||
SandboxIDs: preview.SandboxIDs,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/hosts/{id}.
|
||||
// Without ?force=true: returns 409 with affected sandbox IDs if any are active.
|
||||
// With ?force=true: gracefully stops all sandboxes then deletes the host.
|
||||
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
force := r.URL.Query().Get("force") == "true"
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch host before deletion to capture team_id for audit.
|
||||
deletedHost, hostErr := h.queries.GetHost(r.Context(), hostID)
|
||||
if hostErr != nil {
|
||||
slog.Warn("audit: could not fetch host before delete", "host_id", hostIDStr, "error", hostErr)
|
||||
}
|
||||
|
||||
err = h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
|
||||
if err == nil {
|
||||
h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it's a "has running sandboxes" error and return a structured 409.
|
||||
var hasSandboxes *service.HostHasSandboxesError
|
||||
if errors.As(err, &hasSandboxes) {
|
||||
writeJSON(w, http.StatusConflict, map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": "has_active_sandboxes",
|
||||
"message": "host has active sandboxes; use ?force=true to destroy them and delete the host",
|
||||
"sandbox_ids": hasSandboxes.SandboxIDs,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
}
|
||||
|
||||
// RegenerateToken handles POST /v1/hosts/{id}/token.
|
||||
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
@ -247,36 +391,61 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, registerHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
Token: result.JWT,
|
||||
Host: hostToResponse(result.Host),
|
||||
Token: result.JWT,
|
||||
RefreshToken: result.RefreshToken,
|
||||
CertPEM: result.CertPEM,
|
||||
KeyPEM: result.KeyPEM,
|
||||
CACertPEM: result.CACertPEM,
|
||||
})
|
||||
}
|
||||
|
||||
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
|
||||
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
hc := auth.MustHostFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Prevent a host from heartbeating for a different host.
|
||||
if hostID != hc.HostID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
// Capture pre-heartbeat status to detect unreachable → online transition.
|
||||
prevHost, _ := h.queries.GetHost(r.Context(), hc.HostID)
|
||||
|
||||
if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat")
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
// Log marked_up if the host just recovered from unreachable.
|
||||
if prevHost.Status == "unreachable" {
|
||||
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// AddTag handles POST /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
admin := h.isAdmin(r, ac.UserID)
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req addTagRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
@ -298,10 +467,16 @@ func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
|
||||
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
tag := chi.URLParam(r, "tag")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
@ -311,11 +486,47 @@ func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RefreshToken handles POST /v1/hosts/auth/refresh (unauthenticated).
|
||||
// The host agent sends its refresh token to receive a new JWT and rotated refresh token.
|
||||
func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||
var req refreshTokenRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.RefreshToken == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "refresh_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.Refresh(r.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, refreshTokenResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
Token: result.JWT,
|
||||
RefreshToken: result.RefreshToken,
|
||||
CertPEM: result.CertPEM,
|
||||
KeyPEM: result.KeyPEM,
|
||||
CACertPEM: result.CACertPEM,
|
||||
})
|
||||
}
|
||||
|
||||
// ListTags handles GET /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hostID, err := id.ParseHostID(hostIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
|
||||
return
|
||||
}
|
||||
|
||||
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
|
||||
156
internal/api/handlers_metrics.go
Normal file
156
internal/api/handlers_metrics.go
Normal file
@ -0,0 +1,156 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type sandboxMetricsHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newSandboxMetricsHandler(db *db.Queries, pool *lifecycle.HostClientPool) *sandboxMetricsHandler {
|
||||
return &sandboxMetricsHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
type metricPointResponse struct {
|
||||
TimestampUnix int64 `json:"timestamp_unix"`
|
||||
CPUPct float64 `json:"cpu_pct"`
|
||||
MemBytes int64 `json:"mem_bytes"`
|
||||
DiskBytes int64 `json:"disk_bytes"`
|
||||
}
|
||||
|
||||
type metricsResponse struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Range string `json:"range"`
|
||||
Points []metricPointResponse `json:"points"`
|
||||
}
|
||||
|
||||
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h.
|
||||
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
rangeTier := r.URL.Query().Get("range")
|
||||
if rangeTier == "" {
|
||||
rangeTier = "10m"
|
||||
}
|
||||
validRanges := map[string]bool{"5m": true, "10m": true, "1h": true, "2h": true, "6h": true, "12h": true, "24h": true}
|
||||
if !validRanges[rangeTier] {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "range must be one of: 5m, 10m, 1h, 2h, 6h, 12h, 24h")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
|
||||
switch sb.Status {
|
||||
case "running":
|
||||
h.getFromAgent(w, r, sandboxIDStr, rangeTier, sb.HostID)
|
||||
case "paused":
|
||||
h.getFromDB(ctx, w, sandboxIDStr, sandboxID, rangeTier)
|
||||
default:
|
||||
writeError(w, http.StatusNotFound, "not_found", "metrics not available for sandbox in state: "+sb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxIDStr, rangeTier string, hostID pgtype.UUID) {
|
||||
ctx := r.Context()
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.GetSandboxMetrics(ctx, connect.NewRequest(&pb.GetSandboxMetricsRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Range: rangeTier,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
points := make([]metricPointResponse, len(resp.Msg.Points))
|
||||
for i, p := range resp.Msg.Points {
|
||||
points[i] = metricPointResponse{
|
||||
TimestampUnix: p.TimestampUnix,
|
||||
CPUPct: p.CpuPct,
|
||||
MemBytes: p.MemBytes,
|
||||
DiskBytes: p.DiskBytes,
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, metricsResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Range: rangeTier,
|
||||
Points: points,
|
||||
})
|
||||
}
|
||||
|
||||
// rangeToDB maps a user-facing range filter to the DB tier and cutoff duration.
|
||||
var rangeToDB = map[string]struct {
|
||||
tier string
|
||||
cutoff time.Duration
|
||||
}{
|
||||
"5m": {"10m", 5 * time.Minute},
|
||||
"10m": {"10m", 10 * time.Minute},
|
||||
"1h": {"2h", 1 * time.Hour},
|
||||
"2h": {"2h", 2 * time.Hour},
|
||||
"6h": {"24h", 6 * time.Hour},
|
||||
"12h": {"24h", 12 * time.Hour},
|
||||
"24h": {"24h", 24 * time.Hour},
|
||||
}
|
||||
|
||||
func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxIDStr string, sandboxID pgtype.UUID, rangeTier string) {
|
||||
mapping := rangeToDB[rangeTier]
|
||||
rows, err := h.db.GetSandboxMetricPoints(ctx, db.GetSandboxMetricPointsParams{
|
||||
SandboxID: sandboxID,
|
||||
Tier: mapping.tier,
|
||||
Ts: time.Now().Add(-mapping.cutoff).Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to read metrics")
|
||||
return
|
||||
}
|
||||
|
||||
points := make([]metricPointResponse, len(rows))
|
||||
for i, row := range rows {
|
||||
points[i] = metricPointResponse{
|
||||
TimestampUnix: row.Ts,
|
||||
CPUPct: row.CpuPct,
|
||||
MemBytes: row.MemBytes,
|
||||
DiskBytes: row.DiskBytes,
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, metricsResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Range: rangeTier,
|
||||
Points: points,
|
||||
})
|
||||
}
|
||||
@ -150,19 +150,19 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
@ -199,6 +199,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
|
||||
ID: userID,
|
||||
Email: email,
|
||||
Name: profile.Name,
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
@ -219,6 +220,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: teamName,
|
||||
Slug: id.NewTeamSlug(),
|
||||
}); err != nil {
|
||||
slog.Error("oauth: failed to create team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
@ -253,14 +255,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", false)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
|
||||
redirectWithToken(w, r, redirectBase, token, userID, teamID, email)
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
|
||||
}
|
||||
|
||||
// retryAsLogin handles the race where a concurrent request already created the user.
|
||||
@ -282,29 +284,39 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
|
||||
}
|
||||
|
||||
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) {
|
||||
u := base + "?" + url.Values{
|
||||
"token": {token},
|
||||
"user_id": {userID},
|
||||
"team_id": {teamID},
|
||||
"email": {email},
|
||||
}.Encode()
|
||||
http.Redirect(w, r, u, http.StatusFound)
|
||||
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {
|
||||
// Set auth data as short-lived cookies instead of URL query parameters.
|
||||
// This prevents token leakage via server access logs, Referer headers, and browser history.
|
||||
for _, c := range []http.Cookie{
|
||||
{Name: "wrenn_oauth_token", Value: token},
|
||||
{Name: "wrenn_oauth_user_id", Value: userID},
|
||||
{Name: "wrenn_oauth_team_id", Value: teamID},
|
||||
{Name: "wrenn_oauth_email", Value: email},
|
||||
{Name: "wrenn_oauth_name", Value: name},
|
||||
} {
|
||||
c.Path = "/auth/"
|
||||
c.MaxAge = 60
|
||||
c.HttpOnly = false // frontend JS must read these
|
||||
c.SameSite = http.SameSiteLaxMode
|
||||
c.Secure = isSecure(r)
|
||||
http.SetCookie(w, &c)
|
||||
}
|
||||
http.Redirect(w, r, base, http.StatusFound)
|
||||
}
|
||||
|
||||
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
|
||||
|
||||
@ -7,17 +7,20 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type sandboxHandler struct {
|
||||
svc *service.SandboxService
|
||||
svc *service.SandboxService
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newSandboxHandler(svc *service.SandboxService) *sandboxHandler {
|
||||
return &sandboxHandler{svc: svc}
|
||||
func newSandboxHandler(svc *service.SandboxService, al *audit.AuditLogger) *sandboxHandler {
|
||||
return &sandboxHandler{svc: svc, audit: al}
|
||||
}
|
||||
|
||||
type createSandboxRequest struct {
|
||||
@ -44,7 +47,7 @@ type sandboxResponse struct {
|
||||
|
||||
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
resp := sandboxResponse{
|
||||
ID: sb.ID,
|
||||
ID: id.FormatSandboxID(sb.ID),
|
||||
Status: sb.Status,
|
||||
Template: sb.Template,
|
||||
VCPUs: sb.Vcpus,
|
||||
@ -79,6 +82,10 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
if !ac.TeamID.Valid {
|
||||
writeError(w, http.StatusForbidden, "no_team", "no active team context; re-authenticate")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Create(r.Context(), service.SandboxCreateParams{
|
||||
TeamID: ac.TeamID,
|
||||
@ -93,6 +100,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
@ -115,9 +123,15 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Get handles GET /v1/sandboxes/{id}.
|
||||
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
@ -129,9 +143,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Pause handles POST /v1/sandboxes/{id}/pause.
|
||||
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
@ -139,14 +159,21 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxPause(r.Context(), ac, sandboxID)
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Resume handles POST /v1/sandboxes/{id}/resume.
|
||||
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
@ -154,14 +181,21 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxResume(r.Context(), ac, sandboxID)
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Ping handles POST /v1/sandboxes/{id}/ping.
|
||||
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
@ -173,14 +207,21 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Destroy handles DELETE /v1/sandboxes/{id}.
|
||||
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@ -9,25 +10,57 @@ import (
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/layout"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type snapshotHandler struct {
|
||||
svc *service.TemplateService
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, db: db, agent: agent}
|
||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
|
||||
}
|
||||
|
||||
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
|
||||
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
|
||||
// and ignore NotFound errors.
|
||||
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error {
|
||||
hosts, err := h.db.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
}
|
||||
agent, err := h.pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: formatUUIDForRPC(teamID),
|
||||
TemplateId: formatUUIDForRPC(templateID),
|
||||
})); err != nil {
|
||||
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type createSnapshotRequest struct {
|
||||
@ -42,6 +75,7 @@ type snapshotResponse struct {
|
||||
MemoryMB *int32 `json:"memory_mb,omitempty"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Platform bool `json:"platform"`
|
||||
}
|
||||
|
||||
func templateToResponse(t db.Template) snapshotResponse {
|
||||
@ -49,12 +83,13 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
Name: t.Name,
|
||||
Type: t.Type,
|
||||
SizeBytes: t.SizeBytes,
|
||||
Platform: t.TeamID == id.PlatformTeamID,
|
||||
}
|
||||
if t.Vcpus.Valid {
|
||||
resp.VCPUs = &t.Vcpus.Int32
|
||||
if t.Vcpus != 0 {
|
||||
resp.VCPUs = &t.Vcpus
|
||||
}
|
||||
if t.MemoryMb.Valid {
|
||||
resp.MemoryMB = &t.MemoryMb.Int32
|
||||
if t.MemoryMb != 0 {
|
||||
resp.MemoryMB = &t.MemoryMb
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
@ -75,6 +110,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(req.SandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
@ -87,16 +128,21 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(ctx)
|
||||
overwrite := r.URL.Query().Get("overwrite") == "true"
|
||||
|
||||
// Check for global name collision.
|
||||
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name already exists for this team.
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
if !overwrite {
|
||||
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
|
||||
return
|
||||
}
|
||||
// Delete old files from the agent before removing the DB record.
|
||||
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, "failed to delete existing snapshot files: "+msg)
|
||||
// Delete old snapshot files from all hosts before removing the DB record.
|
||||
if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
|
||||
@ -106,7 +152,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Verify sandbox exists, belongs to team, and is running or paused.
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID})
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
@ -116,30 +162,59 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: req.SandboxID,
|
||||
Name: req.Name,
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
|
||||
// The host agent's CreateSnapshot removes the sandbox from its in-memory
|
||||
// map immediately; if the reconciler fires during the flatten window and
|
||||
// the DB still says "running", it will mark the sandbox "stopped".
|
||||
if sb.Status == "running" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a detached context with a generous timeout so the snapshot completes
|
||||
// even if the client disconnects (the flatten step can take 10-20s).
|
||||
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer snapCancel()
|
||||
|
||||
// Generate the new template ID upfront so the host agent knows where to store files.
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: req.SandboxID,
|
||||
Name: req.Name,
|
||||
TeamId: formatUUIDForRPC(ac.TeamID),
|
||||
TemplateId: formatUUIDForRPC(newTemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
// Snapshot failed — revert status back to what it was.
|
||||
if sb.Status == "running" {
|
||||
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
|
||||
}
|
||||
}
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark sandbox as paused (if it was running, it got paused by the snapshot).
|
||||
if sb.Status != "paused" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: req.SandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true},
|
||||
MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true},
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
})
|
||||
@ -149,6 +224,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
}
|
||||
|
||||
@ -181,16 +262,23 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
|
||||
tmpl, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "template not found")
|
||||
return
|
||||
}
|
||||
// Platform templates can only be deleted by admins via /v1/admin/templates.
|
||||
if tmpl.TeamID == id.PlatformTeamID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
|
||||
return
|
||||
}
|
||||
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
Name: name,
|
||||
})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, "failed to delete snapshot files: "+msg)
|
||||
if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
|
||||
return
|
||||
}
|
||||
|
||||
@ -199,5 +287,6 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
95
internal/api/handlers_stats.go
Normal file
95
internal/api/handlers_stats.go
Normal file
@ -0,0 +1,95 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
// statsHandler serves the sandbox usage statistics endpoint.
type statsHandler struct {
	// svc computes current, peak, and time-series usage for a team.
	svc *service.StatsService
}

// newStatsHandler creates a statsHandler backed by the given stats service.
func newStatsHandler(svc *service.StatsService) *statsHandler {
	return &statsHandler{svc: svc}
}
||||
|
||||
// statsCurrentResponse is the "current" section of the stats payload:
// the team's usage at the time of the request.
type statsCurrentResponse struct {
	RunningCount     int32 `json:"running_count"`
	VCPUsReserved    int32 `json:"vcpus_reserved"`
	MemoryMBReserved int32 `json:"memory_mb_reserved"`
}

// statsPeaksResponse is the "peaks" section: peak values over the
// requested range.
type statsPeaksResponse struct {
	RunningCount int32 `json:"running_count"`
	VCPUs        int32 `json:"vcpus"`
	MemoryMB     int32 `json:"memory_mb"`
}

// statsSeriesResponse holds the time series as parallel arrays: index i of
// each slice belongs to the bucket labelled Labels[i] (RFC 3339, UTC).
type statsSeriesResponse struct {
	Labels   []string `json:"labels"`
	Running  []int32  `json:"running"`
	VCPUs    []int32  `json:"vcpus"`
	MemoryMB []int32  `json:"memory_mb"`
}

// statsResponse is the full JSON body for GET /v1/sandboxes/stats.
type statsResponse struct {
	Range   string               `json:"range"`
	Current statsCurrentResponse `json:"current"`
	Peaks   statsPeaksResponse   `json:"peaks"`
	Series  statsSeriesResponse  `json:"series"`
}
||||
|
||||
// GetStats handles GET /v1/sandboxes/stats?range=5m|1h|6h|24h|30d
|
||||
func (h *statsHandler) GetStats(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
rangeParam := r.URL.Query().Get("range")
|
||||
if rangeParam == "" {
|
||||
rangeParam = string(service.Range1h)
|
||||
}
|
||||
tr := service.TimeRange(rangeParam)
|
||||
if !service.ValidRange(tr) {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "range must be one of: 5m, 1h, 6h, 24h, 30d")
|
||||
return
|
||||
}
|
||||
|
||||
current, peaks, series, err := h.svc.GetStats(r.Context(), ac.TeamID, tr)
|
||||
if err != nil {
|
||||
slog.Error("stats handler: get stats failed", "team_id", ac.TeamID, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to retrieve stats")
|
||||
return
|
||||
}
|
||||
|
||||
resp := statsResponse{
|
||||
Range: rangeParam,
|
||||
Current: statsCurrentResponse{
|
||||
RunningCount: current.RunningCount,
|
||||
VCPUsReserved: current.VCPUsReserved,
|
||||
MemoryMBReserved: current.MemoryMBReserved,
|
||||
},
|
||||
Peaks: statsPeaksResponse{
|
||||
RunningCount: peaks.RunningCount,
|
||||
VCPUs: peaks.VCPUs,
|
||||
MemoryMB: peaks.MemoryMB,
|
||||
},
|
||||
Series: statsSeriesResponse{
|
||||
Labels: make([]string, len(series)),
|
||||
Running: make([]int32, len(series)),
|
||||
VCPUs: make([]int32, len(series)),
|
||||
MemoryMB: make([]int32, len(series)),
|
||||
},
|
||||
}
|
||||
|
||||
for i, pt := range series {
|
||||
resp.Series.Labels[i] = pt.Bucket.UTC().Format(time.RFC3339)
|
||||
resp.Series.Running[i] = pt.RunningCount
|
||||
resp.Series.VCPUs[i] = pt.VCPUsReserved
|
||||
resp.Series.MemoryMB[i] = pt.MemoryMBReserved
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
390
internal/api/handlers_team.go
Normal file
390
internal/api/handlers_team.go
Normal file
@ -0,0 +1,390 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
// teamHandler serves the /v1/teams endpoints.
type teamHandler struct {
	svc   *service.TeamService // team and membership business logic
	audit *audit.AuditLogger   // records mutating actions in the audit trail
}

// newTeamHandler creates a teamHandler with the given service and audit logger.
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger) *teamHandler {
	return &teamHandler{svc: svc, audit: al}
}
|
||||
|
||||
// teamResponse is the JSON shape for a team.
type teamResponse struct {
	ID        string `json:"id"`
	Name      string `json:"name"`
	Slug      string `json:"slug"`
	IsByoc    bool   `json:"is_byoc"`
	CreatedAt string `json:"created_at"` // RFC 3339; empty when the row has no timestamp
}

// teamWithRoleResponse includes the calling user's role.
type teamWithRoleResponse struct {
	teamResponse
	Role string `json:"role"`
}

// memberResponse is the JSON shape for a single team member.
type memberResponse struct {
	UserID   string `json:"user_id"`
	Name     string `json:"name"`
	Email    string `json:"email"`
	Role     string `json:"role"`
	JoinedAt string `json:"joined_at,omitempty"` // RFC 3339
}
|
||||
|
||||
func teamToResponse(t db.Team) teamResponse {
|
||||
resp := teamResponse{
|
||||
ID: id.FormatTeamID(t.ID),
|
||||
Name: t.Name,
|
||||
Slug: t.Slug,
|
||||
IsByoc: t.IsByoc,
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func memberInfoToResponse(m service.MemberInfo) memberResponse {
|
||||
return memberResponse{
|
||||
UserID: m.UserID,
|
||||
Name: m.Name,
|
||||
Email: m.Email,
|
||||
Role: m.Role,
|
||||
JoinedAt: m.JoinedAt.Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// requireTeamAccess is an inline check used by every team-scoped handler:
|
||||
// the JWT team_id must match the URL {id} before any DB call is made.
|
||||
// Returns false and writes 403 if they don't match.
|
||||
func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (pgtype.UUID, bool) {
|
||||
teamIDStr := chi.URLParam(r, "id")
|
||||
teamID, err := id.ParseTeamID(teamIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
|
||||
return pgtype.UUID{}, false
|
||||
}
|
||||
if ac.TeamID != teamID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "JWT team does not match requested team; use switch-team first")
|
||||
return pgtype.UUID{}, false
|
||||
}
|
||||
return teamID, true
|
||||
}
|
||||
|
||||
// List handles GET /v1/teams
|
||||
// Returns all teams the authenticated user belongs to.
|
||||
func (h *teamHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
teams, err := h.svc.ListTeamsForUser(r.Context(), ac.UserID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]teamWithRoleResponse, len(teams))
|
||||
for i, t := range teams {
|
||||
resp[i] = teamWithRoleResponse{
|
||||
teamResponse: teamToResponse(t.Team),
|
||||
Role: t.Role,
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Create handles POST /v1/teams
|
||||
// Creates a new team owned by the authenticated user.
|
||||
func (h *teamHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
req.Name = strings.TrimSpace(req.Name)
|
||||
|
||||
team, err := h.svc.CreateTeam(r.Context(), ac.UserID, req.Name)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, teamWithRoleResponse{
|
||||
teamResponse: teamToResponse(team.Team),
|
||||
Role: team.Role,
|
||||
})
|
||||
}
|
||||
|
||||
// Get handles GET /v1/teams/{id}
|
||||
// Returns team info and member list.
|
||||
func (h *teamHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
team, err := h.svc.GetTeam(r.Context(), teamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
members, err := h.svc.GetMembers(r.Context(), teamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
memberResp := make([]memberResponse, len(members))
|
||||
for i, m := range members {
|
||||
memberResp[i] = memberInfoToResponse(m)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"team": teamToResponse(team),
|
||||
"members": memberResp,
|
||||
})
|
||||
}
|
||||
|
||||
// Rename handles PATCH /v1/teams/{id}
// Renames the team. Requires admin or owner role (verified from DB).
func (h *teamHandler) Rename(w http.ResponseWriter, r *http.Request) {
	ac := auth.MustFromContext(r.Context())
	teamID, ok := requireTeamAccess(w, r, ac)
	if !ok {
		return
	}

	var req struct {
		Name string `json:"name"`
	}
	if err := decodeJSON(r, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}
	req.Name = strings.TrimSpace(req.Name)

	// Fetch old name for audit log before renaming.
	// NOTE(review): if this lookup fails, oldTeam is the zero value and the
	// audit entry below records an empty old name; the rename itself still
	// proceeds.
	oldTeam, err := h.svc.GetTeam(r.Context(), teamID)
	if err != nil {
		slog.Warn("audit: could not fetch old team name for rename log", "team_id", id.FormatTeamID(teamID), "error", err)
	}

	if err := h.svc.RenameTeam(r.Context(), teamID, ac.UserID, req.Name); err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	h.audit.LogTeamRename(r.Context(), ac, teamID, oldTeam.Name, req.Name)
	w.WriteHeader(http.StatusNoContent)
}
|
||||
|
||||
// Delete handles DELETE /v1/teams/{id}
|
||||
// Soft-deletes the team and destroys active sandboxes. Owner only.
|
||||
func (h *teamHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.DeleteTeam(r.Context(), teamID, ac.UserID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ListMembers handles GET /v1/teams/{id}/members
|
||||
func (h *teamHandler) ListMembers(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
members, err := h.svc.GetMembers(r.Context(), teamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]memberResponse, len(members))
|
||||
for i, m := range members {
|
||||
resp[i] = memberInfoToResponse(m)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// AddMember handles POST /v1/teams/{id}/members
|
||||
// Adds a user by email. Requires admin or owner (verified from DB).
|
||||
func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
||||
if req.Email == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "email is required")
|
||||
return
|
||||
}
|
||||
|
||||
member, err := h.svc.AddMember(r.Context(), teamID, ac.UserID, req.Email)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
// member.UserID is already formatted with prefix; parse it back for the audit logger.
|
||||
targetUserID, parseErr := id.ParseUserID(member.UserID)
|
||||
if parseErr == nil {
|
||||
h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
|
||||
}
|
||||
|
||||
// RemoveMember handles DELETE /v1/teams/{id}/members/{uid}
|
||||
// Removes a member. Requires admin or owner (verified from DB). Owner cannot be removed.
|
||||
func (h *teamHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
targetUserIDStr := chi.URLParam(r, "uid")
|
||||
|
||||
targetUserID, err := id.ParseUserID(targetUserIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.RemoveMember(r.Context(), teamID, ac.UserID, targetUserID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogMemberRemove(r.Context(), ac, targetUserID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// UpdateMemberRole handles PATCH /v1/teams/{id}/members/{uid}
|
||||
// Changes a member's role (admin or member). Owner's role cannot be changed.
|
||||
func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
targetUserIDStr := chi.URLParam(r, "uid")
|
||||
|
||||
targetUserID, err := id.ParseUserID(targetUserIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.UpdateMemberRole(r.Context(), teamID, ac.UserID, targetUserID, req.Role); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogMemberRoleUpdate(r.Context(), ac, targetUserID, req.Role)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Leave handles POST /v1/teams/{id}/leave
|
||||
// Removes the calling user from the team. Owner cannot leave.
|
||||
func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
teamID, ok := requireTeamAccess(w, r, ac)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.LeaveTeam(r.Context(), teamID, ac.UserID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogMemberLeave(r.Context(), ac)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only).
|
||||
// Enables or disables the BYOC feature flag for a team.
|
||||
func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
|
||||
teamIDStr := chi.URLParam(r, "id")
|
||||
|
||||
teamID, err := id.ParseTeamID(teamIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.SetBYOC(r.Context(), teamID, req.Enabled); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
52
internal/api/handlers_users.go
Normal file
52
internal/api/handlers_users.go
Normal file
@ -0,0 +1,52 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// usersHandler serves user lookup endpoints.
type usersHandler struct {
	db *db.Queries // direct DB access; read-only search needs no service layer
}

// newUsersHandler creates a usersHandler backed by the given query set.
func newUsersHandler(db *db.Queries) *usersHandler {
	return &usersHandler{db: db}
}
|
||||
|
||||
// Search handles GET /v1/users/search?email=<prefix>
|
||||
// Returns up to 10 users whose email starts with the given prefix.
|
||||
// The prefix must be at least 3 characters long and contain "@".
|
||||
func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
|
||||
auth.MustFromContext(r.Context()) // ensure authenticated
|
||||
|
||||
prefix := strings.TrimSpace(r.URL.Query().Get("email"))
|
||||
if len(prefix) < 3 || !strings.Contains(prefix, "@") {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "email prefix must be at least 3 characters and contain '@'")
|
||||
return
|
||||
}
|
||||
|
||||
// Escape LIKE metacharacters to prevent pattern injection.
|
||||
escaped := strings.NewReplacer("%", "\\%", "_", "\\_").Replace(prefix)
|
||||
|
||||
results, err := h.db.SearchUsersByEmailPrefix(r.Context(), pgtype.Text{String: escaped, Valid: true})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal", "search failed")
|
||||
return
|
||||
}
|
||||
|
||||
type userResult struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
resp := make([]userResult, len(results))
|
||||
for i, u := range results {
|
||||
resp[i] = userResult{UserID: id.FormatUserID(u.ID), Email: u.Email}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
216
internal/api/host_monitor.go
Normal file
216
internal/api/host_monitor.go
Normal file
@ -0,0 +1,216 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// unreachableThreshold is how long a host can go without a heartbeat before
|
||||
// it is considered unreachable (3 missed 30-second heartbeats).
|
||||
const unreachableThreshold = 90 * time.Second
|
||||
|
||||
// HostMonitor runs on a fixed interval and performs two duties:
//
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
//    "unreachable" and marks their active sandboxes as "missing".
//
// 2. Active reconciliation: for each online host, calls ListSandboxes and
//    reconciles DB state against live host state — restoring "missing"
//    sandboxes that are actually alive, and stopping orphaned ones.
type HostMonitor struct {
	db       *db.Queries               // control-plane database queries
	pool     *lifecycle.HostClientPool // per-host agent RPC clients
	audit    *audit.AuditLogger        // records host state transitions
	interval time.Duration             // time between monitor passes
}
|
||||
|
||||
// NewHostMonitor creates a HostMonitor.
//
// interval controls how often the monitor loop runs once Start is called.
func NewHostMonitor(queries *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger, interval time.Duration) *HostMonitor {
	return &HostMonitor{
		db:       queries,
		pool:     pool,
		audit:    al,
		interval: interval,
	}
}
|
||||
|
||||
// Start runs the monitor loop until the context is cancelled.
|
||||
func (m *HostMonitor) Start(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(m.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on startup so the CP doesn't wait one full interval
|
||||
// before reconciling host and sandbox state.
|
||||
m.run(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.run(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *HostMonitor) run(ctx context.Context) {
|
||||
hosts, err := m.db.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list hosts", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
m.checkHost(ctx, host)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
// --- Passive phase: check heartbeat staleness ---
|
||||
|
||||
stale := !host.LastHeartbeatAt.Valid ||
|
||||
time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
|
||||
|
||||
if stale && host.Status != "unreachable" {
|
||||
slog.Info("host monitor: marking host unreachable", "host_id", id.FormatHostID(host.ID),
|
||||
"last_heartbeat", host.LastHeartbeatAt.Time)
|
||||
if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
|
||||
slog.Warn("host monitor: failed to mark host unreachable", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
|
||||
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
m.audit.LogHostMarkedDown(ctx, host.TeamID, host.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// --- Active reconciliation: only for online hosts ---
|
||||
|
||||
if host.Status != "online" {
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := m.pool.GetForHost(host)
|
||||
if err != nil {
|
||||
// Host has no address yet (e.g., just registered) — skip.
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
|
||||
if err != nil {
|
||||
// RPC failure is a transient condition; the passive phase will catch it
|
||||
// if heartbeats stop arriving.
|
||||
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build set of sandbox IDs alive on the host.
|
||||
// The host agent returns sandbox IDs as strings (formatted with prefix).
|
||||
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
|
||||
for _, sb := range resp.Msg.Sandboxes {
|
||||
alive[sb.SandboxId] = struct{}{}
|
||||
}
|
||||
|
||||
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
|
||||
for _, apID := range resp.Msg.AutoPausedSandboxIds {
|
||||
autoPaused[apID] = struct{}{}
|
||||
}
|
||||
|
||||
// --- Restore sandboxes that are "missing" in DB but alive on host ---
|
||||
// This handles the case where CP marked them missing due to a transient
|
||||
// heartbeat gap, but the host was actually fine.
|
||||
|
||||
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
Column2: []string{"missing"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
var toRestore []pgtype.UUID
|
||||
var toStop []pgtype.UUID
|
||||
for _, sb := range missingSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
if _, ok := alive[sbIDStr]; ok {
|
||||
toRestore = append(toRestore, sb.ID)
|
||||
} else {
|
||||
toStop = append(toStop, sb.ID)
|
||||
}
|
||||
}
|
||||
if len(toRestore) > 0 {
|
||||
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
|
||||
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
|
||||
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
}
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Find running sandboxes in DB that are no longer alive on the host ---
|
||||
|
||||
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
Column2: []string{"running"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list running sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
var toPause, toStop []pgtype.UUID
|
||||
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
|
||||
for _, sb := range runningSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
sbTeamID[sb.ID] = sb.TeamID
|
||||
if _, ok := alive[sbIDStr]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := autoPaused[sbIDStr]; ok {
|
||||
toPause = append(toPause, sb.ID)
|
||||
} else {
|
||||
toStop = append(toStop, sb.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toPause) > 0 {
|
||||
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toPause,
|
||||
Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
for _, sbID := range toPause {
|
||||
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
|
||||
}
|
||||
}
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
68
internal/api/metrics_sampler.go
Normal file
68
internal/api/metrics_sampler.go
Normal file
@ -0,0 +1,68 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// MetricsSampler records per-team sandbox resource usage to
|
||||
// sandbox_metrics_snapshots every interval. It also prunes rows older than
|
||||
// 60 days on each tick to keep the table bounded.
|
||||
type MetricsSampler struct {
|
||||
db *db.Queries
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewMetricsSampler creates a MetricsSampler.
|
||||
func NewMetricsSampler(queries *db.Queries, interval time.Duration) *MetricsSampler {
|
||||
return &MetricsSampler{db: queries, interval: interval}
|
||||
}
|
||||
|
||||
// Start runs the sampler loop until the context is cancelled.
|
||||
func (s *MetricsSampler) Start(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Sample immediately on startup.
|
||||
s.run(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.run(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *MetricsSampler) run(ctx context.Context) {
|
||||
s.prune(ctx)
|
||||
if err := s.sample(ctx); err != nil {
|
||||
slog.Warn("metrics sampler: sample failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MetricsSampler) sample(ctx context.Context) error {
|
||||
rows, err := s.db.SampleSandboxMetrics(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, row := range rows {
|
||||
if err := s.db.InsertMetricsSnapshot(ctx, db.InsertMetricsSnapshotParams(row)); err != nil {
|
||||
slog.Warn("metrics sampler: insert snapshot failed", "team_id", row.TeamID, "error", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MetricsSampler) prune(ctx context.Context) {
|
||||
if err := s.db.PruneOldMetrics(ctx); err != nil {
|
||||
slog.Warn("metrics sampler: prune failed", "error", err)
|
||||
}
|
||||
}
|
||||
@ -12,6 +12,9 @@ import (
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
type errorResponse struct {
|
||||
@ -35,6 +38,11 @@ func writeError(w http.ResponseWriter, status int, code, message string) {
|
||||
})
|
||||
}
|
||||
|
||||
// formatUUIDForRPC converts a pgtype.UUID to a hex string for RPC messages.
|
||||
func formatUUIDForRPC(u pgtype.UUID) string {
|
||||
return id.UUIDString(u)
|
||||
}
|
||||
|
||||
// agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message.
|
||||
func agentErrToHTTP(err error) (int, string, string) {
|
||||
switch connect.CodeOf(err) {
|
||||
@ -87,8 +95,12 @@ func serviceErrToHTTP(err error) (int, string, string) {
|
||||
return http.StatusNotFound, "not_found", msg
|
||||
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", msg
|
||||
case strings.Contains(msg, "conflict:"):
|
||||
return http.StatusConflict, "conflict", msg
|
||||
case strings.Contains(msg, "forbidden"):
|
||||
return http.StatusForbidden, "forbidden", msg
|
||||
case strings.Contains(msg, "invalid or expired"):
|
||||
return http.StatusUnauthorized, "unauthorized", msg
|
||||
case strings.Contains(msg, "invalid"):
|
||||
return http.StatusBadRequest, "invalid_request", msg
|
||||
default:
|
||||
|
||||
30
internal/api/middleware_admin.go
Normal file
30
internal/api/middleware_admin.go
Normal file
@ -0,0 +1,30 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// requireAdmin validates that the authenticated user is a platform admin.
|
||||
// Must run after requireJWT (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the JWT is_admin claim is for UI only;
|
||||
// the DB is the source of truth for admin access.
|
||||
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
|
||||
return
|
||||
}
|
||||
user, err := queries.GetUserByID(r.Context(), ac.UserID)
|
||||
if err != nil || !user.IsAdmin {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "admin access required")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,38 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// requireAPIKey validates the X-API-Key header, looks up the SHA-256 hash in DB,
|
||||
// and stamps TeamID into the request context.
|
||||
func requireAPIKey(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == "" {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key header required")
|
||||
return
|
||||
}
|
||||
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Best-effort update of last_used timestamp.
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
|
||||
@ -19,15 +20,20 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
|
||||
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: row.TeamID,
|
||||
APIKeyID: row.ID,
|
||||
APIKeyName: row.Name,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
@ -37,14 +43,28 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
|
||||
return
|
||||
}
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// requireHostToken validates the X-Host-Token header containing a host JWT,
|
||||
@ -23,7 +24,13 @@ func requireHostToken(secret []byte) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
|
||||
hostID, err := id.ParseHostID(claims.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid host ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: hostID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
|
||||
@ -25,11 +26,26 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
|
||||
return
|
||||
}
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
Email: claims.Email,
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
IsAdmin: claims.IsAdmin,
|
||||
})
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,126 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// Reconciler periodically compares the host agent's sandbox list with the DB
|
||||
// and marks sandboxes that no longer exist on the host as stopped.
|
||||
type Reconciler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
hostID string
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewReconciler creates a new reconciler.
|
||||
func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler {
|
||||
return &Reconciler{
|
||||
db: db,
|
||||
agent: agent,
|
||||
hostID: hostID,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// Start runs the reconciliation loop until the context is cancelled.
|
||||
func (rc *Reconciler) Start(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(rc.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rc.reconcile(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (rc *Reconciler) reconcile(ctx context.Context) {
|
||||
// Single RPC returns both the running sandbox list and any IDs that
|
||||
// were auto-paused by the TTL reaper since the last call.
|
||||
resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
|
||||
if err != nil {
|
||||
slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build a set of sandbox IDs that are alive on the host.
|
||||
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
|
||||
for _, sb := range resp.Msg.Sandboxes {
|
||||
alive[sb.SandboxId] = struct{}{}
|
||||
}
|
||||
|
||||
// Build auto-paused set from the same response.
|
||||
autoPausedSet := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
|
||||
for _, id := range resp.Msg.AutoPausedSandboxIds {
|
||||
autoPausedSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Get all DB sandboxes for this host that are running.
|
||||
// Paused sandboxes are excluded: they are expected to not exist on the
|
||||
// host agent because pause = snapshot + destroy resources.
|
||||
dbSandboxes, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: rc.hostID,
|
||||
Column2: []string{"running"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("reconciler: failed to list DB sandboxes", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Find sandboxes in DB that are no longer on the host.
|
||||
var stale []string
|
||||
for _, sb := range dbSandboxes {
|
||||
if _, ok := alive[sb.ID]; !ok {
|
||||
stale = append(stale, sb.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(stale) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Split stale sandboxes into those auto-paused by the TTL reaper vs
|
||||
// those that crashed/were orphaned.
|
||||
var toPause, toStop []string
|
||||
for _, id := range stale {
|
||||
if _, ok := autoPausedSet[id]; ok {
|
||||
toPause = append(toPause, id)
|
||||
} else {
|
||||
toStop = append(toStop, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toPause) > 0 {
|
||||
slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause)
|
||||
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toPause,
|
||||
Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop)
|
||||
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("reconciler: failed to update stale sandboxes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -9,10 +9,14 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/channels"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
//go:embed openapi.yaml
|
||||
@ -20,30 +24,54 @@ var openapiYAML []byte
|
||||
|
||||
// Server is the control plane HTTP server.
|
||||
type Server struct {
|
||||
router chi.Router
|
||||
router chi.Router
|
||||
BuildSvc *service.BuildService
|
||||
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
|
||||
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
|
||||
func New(
|
||||
queries *db.Queries,
|
||||
pool *lifecycle.HostClientPool,
|
||||
sched scheduler.HostScheduler,
|
||||
pgPool *pgxpool.Pool,
|
||||
rdb *redis.Client,
|
||||
jwtSecret []byte,
|
||||
oauthRegistry *oauth.Registry,
|
||||
oauthRedirectURL string,
|
||||
ca *auth.CA,
|
||||
al *audit.AuditLogger,
|
||||
channelSvc *channels.Service,
|
||||
) *Server {
|
||||
r := chi.NewRouter()
|
||||
r.Use(requestLogger())
|
||||
|
||||
// Shared service layer.
|
||||
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
|
||||
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
|
||||
apiKeySvc := &service.APIKeyService{DB: queries}
|
||||
templateSvc := &service.TemplateService{DB: queries}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
||||
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
||||
auditSvc := &service.AuditService{DB: queries}
|
||||
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
||||
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
|
||||
|
||||
sandbox := newSandboxHandler(sandboxSvc)
|
||||
exec := newExecHandler(queries, agent)
|
||||
execStream := newExecStreamHandler(queries, agent)
|
||||
files := newFilesHandler(queries, agent)
|
||||
filesStream := newFilesStreamHandler(queries, agent)
|
||||
snapshots := newSnapshotHandler(templateSvc, queries, agent)
|
||||
authH := newAuthHandler(queries, pool, jwtSecret)
|
||||
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc)
|
||||
hostH := newHostHandler(hostSvc, queries)
|
||||
sandbox := newSandboxHandler(sandboxSvc, al)
|
||||
exec := newExecHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool)
|
||||
files := newFilesHandler(queries, pool)
|
||||
filesStream := newFilesStreamHandler(queries, pool)
|
||||
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
|
||||
authH := newAuthHandler(queries, pgPool, jwtSecret)
|
||||
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc, al)
|
||||
hostH := newHostHandler(hostSvc, queries, al)
|
||||
teamH := newTeamHandler(teamSvc, al)
|
||||
usersH := newUsersHandler(queries)
|
||||
auditH := newAuditHandler(auditSvc)
|
||||
statsH := newStatsHandler(statsSvc)
|
||||
metricsH := newSandboxMetricsHandler(queries, pool)
|
||||
buildH := newBuildHandler(buildSvc, queries, pool)
|
||||
channelH := newChannelHandler(channelSvc, al)
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
@ -55,6 +83,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
r.Get("/auth/oauth/{provider}", oauthH.Redirect)
|
||||
r.Get("/auth/oauth/{provider}/callback", oauthH.Callback)
|
||||
|
||||
// JWT-authenticated: switch active team.
|
||||
r.With(requireJWT(jwtSecret)).Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
|
||||
// JWT-authenticated: API key management.
|
||||
r.Route("/v1/api-keys", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
@ -63,11 +94,32 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
r.Delete("/{id}", apiKeys.Delete)
|
||||
})
|
||||
|
||||
// JWT-authenticated: team management.
|
||||
r.Route("/v1/teams", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Get("/", teamH.List)
|
||||
r.Post("/", teamH.Create)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", teamH.Get)
|
||||
r.Patch("/", teamH.Rename)
|
||||
r.Delete("/", teamH.Delete)
|
||||
r.Get("/members", teamH.ListMembers)
|
||||
r.Post("/members", teamH.AddMember)
|
||||
r.Patch("/members/{uid}", teamH.UpdateMemberRole)
|
||||
r.Delete("/members/{uid}", teamH.RemoveMember)
|
||||
r.Post("/leave", teamH.Leave)
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: user search (for add-member UI).
|
||||
r.With(requireJWT(jwtSecret)).Get("/v1/users/search", usersH.Search)
|
||||
|
||||
// Sandbox lifecycle: accepts API key or JWT bearer token.
|
||||
r.Route("/v1/sandboxes", func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Post("/", sandbox.Create)
|
||||
r.Get("/", sandbox.List)
|
||||
r.Get("/stats", statsH.GetStats)
|
||||
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", sandbox.Get)
|
||||
@ -81,6 +133,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
r.Post("/files/read", files.Download)
|
||||
r.Post("/files/stream/write", filesStream.StreamUpload)
|
||||
r.Post("/files/stream/read", filesStream.StreamDownload)
|
||||
r.Get("/metrics", metricsH.GetMetrics)
|
||||
})
|
||||
})
|
||||
|
||||
@ -97,6 +150,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
// Unauthenticated: one-time registration token.
|
||||
r.Post("/register", hostH.Register)
|
||||
|
||||
// Unauthenticated: refresh token exchange.
|
||||
r.Post("/auth/refresh", hostH.RefreshToken)
|
||||
|
||||
// Host-token-authenticated: heartbeat.
|
||||
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
|
||||
|
||||
@ -108,6 +164,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", hostH.Get)
|
||||
r.Delete("/", hostH.Delete)
|
||||
r.Get("/delete-preview", hostH.DeletePreview)
|
||||
r.Post("/token", hostH.RegenerateToken)
|
||||
r.Get("/tags", hostH.ListTags)
|
||||
r.Post("/tags", hostH.AddTag)
|
||||
@ -116,7 +173,37 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
})
|
||||
})
|
||||
|
||||
return &Server{router: r}
|
||||
// JWT-authenticated: notification channels.
|
||||
r.Route("/v1/channels", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Post("/", channelH.Create)
|
||||
r.Get("/", channelH.List)
|
||||
r.Post("/test", channelH.Test)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", channelH.Get)
|
||||
r.Patch("/", channelH.Update)
|
||||
r.Delete("/", channelH.Delete)
|
||||
r.Put("/config", channelH.RotateConfig)
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: audit log.
|
||||
r.With(requireJWT(jwtSecret)).Get("/v1/audit-logs", auditH.List)
|
||||
|
||||
// Platform admin routes — require JWT + DB-validated admin status.
|
||||
r.Route("/v1/admin", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
|
||||
r.Get("/templates", buildH.ListTemplates)
|
||||
r.Delete("/templates/{name}", buildH.DeleteTemplate)
|
||||
r.Post("/builds", buildH.Create)
|
||||
r.Get("/builds", buildH.List)
|
||||
r.Get("/builds/{id}", buildH.Get)
|
||||
r.Post("/builds/{id}/cancel", buildH.Cancel)
|
||||
})
|
||||
|
||||
return &Server{router: r, BuildSvc: buildSvc}
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler.
|
||||
@ -137,7 +224,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrenn Sandbox API</title>
|
||||
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
|
||||
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui.css" integrity="sha384-rcbEi6xgdPk0iWkAQzT2F3FeBJXdG+ydrawGlfHAFIZG7wU6aKbQaRewysYpmrlW" crossorigin="anonymous">
|
||||
<style>
|
||||
body { margin: 0; background: #fafafa; }
|
||||
.swagger-ui .topbar { display: none; }
|
||||
@ -145,7 +232,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
|
||||
</head>
|
||||
<body>
|
||||
<div id="swagger-ui"></div>
|
||||
<script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
|
||||
<script src="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui-bundle.js" integrity="sha384-NXtFPpN61oWCuN4D42K6Zd5Rt2+uxeIT36R7kpXBuY9tLnZorzrJ4ykpqwJfgjpZ" crossorigin="anonymous"></script>
|
||||
<script>
|
||||
SwaggerUIBundle({
|
||||
url: "/openapi.yaml",
|
||||
|
||||
569
internal/audit/logger.go
Normal file
569
internal/audit/logger.go
Normal file
@ -0,0 +1,569 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/events"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// AuditLogger writes audit log entries for user-initiated and system events.
|
||||
// All methods are fire-and-forget: failures are logged via slog and never
|
||||
// propagated to the caller.
|
||||
type AuditLogger struct {
|
||||
db *db.Queries
|
||||
pub events.EventPublisher // optional — nil disables event publishing
|
||||
}
|
||||
|
||||
// New constructs an AuditLogger without event publishing.
|
||||
func New(queries *db.Queries) *AuditLogger {
|
||||
return &AuditLogger{db: queries}
|
||||
}
|
||||
|
||||
// NewWithPublisher constructs an AuditLogger that also publishes channel events.
|
||||
func NewWithPublisher(queries *db.Queries, pub events.EventPublisher) *AuditLogger {
|
||||
return &AuditLogger{db: queries, pub: pub}
|
||||
}
|
||||
|
||||
// publish sends an event to the notification stream if a publisher is configured.
|
||||
func (l *AuditLogger) publish(ctx context.Context, e events.Event) {
|
||||
if l.pub != nil {
|
||||
l.pub.Publish(ctx, e)
|
||||
}
|
||||
}
|
||||
|
||||
// actorToEvent converts auth context fields to an events.Actor.
|
||||
func actorToEvent(ac auth.AuthContext) events.Actor {
|
||||
at, aid, aname := actorFields(ac)
|
||||
return events.Actor{Type: events.ActorKind(at), ID: aid, Name: aname}
|
||||
}
|
||||
|
||||
// systemActor returns an events.Actor for system-initiated events.
|
||||
func systemActor() events.Actor {
|
||||
return events.Actor{Type: events.ActorSystem}
|
||||
}
|
||||
|
||||
// actorFields extracts actor_type, actor_id, and actor_name from an AuthContext.
|
||||
// actor_id is stored as a prefixed string in the TEXT column.
|
||||
func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) {
|
||||
if ac.UserID.Valid {
|
||||
return "user", id.FormatUserID(ac.UserID), ac.Name
|
||||
}
|
||||
if ac.APIKeyID.Valid {
|
||||
return "api_key", id.FormatAPIKeyID(ac.APIKeyID), ac.APIKeyName
|
||||
}
|
||||
return "system", "", ""
|
||||
}
|
||||
|
||||
func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) {
|
||||
if err := l.db.InsertAuditLog(ctx, p); err != nil {
|
||||
slog.Warn("audit: failed to write log entry",
|
||||
"action", p.Action,
|
||||
"resource_type", p.ResourceType,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// marshalMeta serializes an audit metadata map to JSON. A nil or empty map
// yields the literal "{}" so the metadata column always holds valid JSON.
// A marshal failure also falls back to "{}", but is logged so that lost
// metadata is observable (matching this file's convention of logging every
// best-effort failure) instead of being silently swallowed.
func marshalMeta(meta map[string]any) []byte {
	if len(meta) == 0 {
		return []byte("{}")
	}
	b, err := json.Marshal(meta)
	if err != nil {
		slog.Warn("audit: failed to marshal metadata", "error", err)
		return []byte("{}")
	}
	return b
}
|
||||
|
||||
// optText returns a valid pgtype.Text if s is non-empty, otherwise an invalid (NULL) one.
|
||||
func optText(s string) pgtype.Text {
|
||||
if s == "" {
|
||||
return pgtype.Text{}
|
||||
}
|
||||
return pgtype.Text{String: s, Valid: true}
|
||||
}
|
||||
|
||||
// --- Sandbox events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "create",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: marshalMeta(map[string]any{"template": template}),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleCreated,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "pause",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsulePaused,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
// LogSandboxAutoPause records a system-initiated auto-pause (TTL or host reconciler).
|
||||
func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) {
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: teamID,
|
||||
ActorType: "system",
|
||||
ActorID: pgtype.Text{},
|
||||
ActorName: "",
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "pause",
|
||||
Scope: "team",
|
||||
Status: "info",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsulePaused,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: systemActor(),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "resume",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleRunning,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "destroy",
|
||||
Scope: "team",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleDestroyed,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
// --- Snapshot events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext, name string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "snapshot",
|
||||
ResourceID: optText(name),
|
||||
Action: "create",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.SnapshotCreated,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: name, Type: "snapshot"},
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext, name string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "snapshot",
|
||||
ResourceID: optText(name),
|
||||
Action: "delete",
|
||||
Scope: "team",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.SnapshotDeleted,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: name, Type: "snapshot"},
|
||||
})
|
||||
}
|
||||
|
||||
// --- Team events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, oldName, newName string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "team",
|
||||
ResourceID: optText(id.FormatTeamID(teamID)),
|
||||
Action: "rename",
|
||||
Scope: "team",
|
||||
Status: "info",
|
||||
Metadata: marshalMeta(map[string]any{"old_name": oldName, "new_name": newName}),
|
||||
})
|
||||
}
|
||||
|
||||
// --- Channel events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogChannelCreate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID, name, provider string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "channel",
|
||||
ResourceID: optText(id.FormatChannelID(channelID)),
|
||||
Action: "create",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: marshalMeta(map[string]any{"name": name, "provider": provider}),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogChannelUpdate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "channel",
|
||||
ResourceID: optText(id.FormatChannelID(channelID)),
|
||||
Action: "update",
|
||||
Scope: "team",
|
||||
Status: "info",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogChannelRotateConfig(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "channel",
|
||||
ResourceID: optText(id.FormatChannelID(channelID)),
|
||||
Action: "rotate_config",
|
||||
Scope: "team",
|
||||
Status: "info",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogChannelDelete(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "channel",
|
||||
ResourceID: optText(id.FormatChannelID(channelID)),
|
||||
Action: "delete",
|
||||
Scope: "team",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
// --- API key events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID, keyName string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "api_key",
|
||||
ResourceID: optText(id.FormatAPIKeyID(keyID)),
|
||||
Action: "create",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: marshalMeta(map[string]any{"name": keyName}),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "api_key",
|
||||
ResourceID: optText(id.FormatAPIKeyID(keyID)),
|
||||
Action: "revoke",
|
||||
Scope: "team",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
// --- Member events (scope: admin) ---
|
||||
|
||||
func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, targetEmail, role string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "member",
|
||||
ResourceID: optText(id.FormatUserID(targetUserID)),
|
||||
Action: "add",
|
||||
Scope: "admin",
|
||||
Status: "success",
|
||||
Metadata: marshalMeta(map[string]any{"email": targetEmail, "role": role}),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "member",
|
||||
ResourceID: optText(id.FormatUserID(targetUserID)),
|
||||
Action: "remove",
|
||||
Scope: "admin",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
resourceID := ""
|
||||
if ac.UserID.Valid {
|
||||
resourceID = id.FormatUserID(ac.UserID)
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "member",
|
||||
ResourceID: optText(resourceID),
|
||||
Action: "leave",
|
||||
Scope: "admin",
|
||||
Status: "info",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, newRole string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "member",
|
||||
ResourceID: optText(id.FormatUserID(targetUserID)),
|
||||
Action: "role_update",
|
||||
Scope: "admin",
|
||||
Status: "info",
|
||||
Metadata: marshalMeta(map[string]any{"new_role": newRole}),
|
||||
})
|
||||
}
|
||||
|
||||
// --- Host events (scope: admin) ---
|
||||
|
||||
func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
// For shared hosts with no owning team, use the caller's team.
|
||||
logTeamID := teamID
|
||||
if !logTeamID.Valid {
|
||||
logTeamID = ac.TeamID
|
||||
}
|
||||
if !logTeamID.Valid {
|
||||
return
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: logTeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "host",
|
||||
ResourceID: optText(id.FormatHostID(hostID)),
|
||||
Action: "create",
|
||||
Scope: "admin",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
logTeamID := teamID
|
||||
if !logTeamID.Valid {
|
||||
logTeamID = ac.TeamID
|
||||
}
|
||||
if !logTeamID.Valid {
|
||||
return
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: logTeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "host",
|
||||
ResourceID: optText(id.FormatHostID(hostID)),
|
||||
Action: "delete",
|
||||
Scope: "admin",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
// LogHostMarkedDown records a system-initiated host status transition to unreachable.
|
||||
// Scoped to "team" so BYOC team members can see when their hosts go down.
|
||||
func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID pgtype.UUID) {
|
||||
if !teamID.Valid {
|
||||
return
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: teamID,
|
||||
ActorType: "system",
|
||||
ActorID: pgtype.Text{},
|
||||
ActorName: "",
|
||||
ResourceType: "host",
|
||||
ResourceID: optText(id.FormatHostID(hostID)),
|
||||
Action: "marked_down",
|
||||
Scope: "team",
|
||||
Status: "error",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.HostDown,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: systemActor(),
|
||||
Resource: events.Resource{ID: id.FormatHostID(hostID), Type: "host"},
|
||||
})
|
||||
}
|
||||
|
||||
// LogHostMarkedUp records a system-initiated host status transition back to online.
|
||||
// Scoped to "team" so BYOC team members can see when their hosts recover.
|
||||
func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID pgtype.UUID) {
|
||||
if !teamID.Valid {
|
||||
return
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: teamID,
|
||||
ActorType: "system",
|
||||
ActorID: pgtype.Text{},
|
||||
ActorName: "",
|
||||
ResourceType: "host",
|
||||
ResourceID: optText(id.FormatHostID(hostID)),
|
||||
Action: "marked_up",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.HostUp,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: systemActor(),
|
||||
Resource: events.Resource{ID: id.FormatHostID(hostID), Type: "host"},
|
||||
})
|
||||
}
|
||||
251
internal/auth/cert.go
Normal file
251
internal/auth/cert.go
Normal file
@ -0,0 +1,251 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CPCertRenewInterval is how often the control plane should renew its client
// certificate. It is set to half the cert TTL so there is always a wide safety
// margin before expiry.
const CPCertRenewInterval = cpCertTTL / 2

const (
	// hostCertTTL is the lifetime of server certificates issued to host agents.
	hostCertTTL = 7 * 24 * time.Hour
	// cpCertTTL is the lifetime of the control plane's mTLS client certificate.
	cpCertTTL = 24 * time.Hour
)
|
||||
|
||||
// CA holds a parsed certificate authority ready to issue leaf certificates.
type CA struct {
	Cert *x509.Certificate // parsed issuer certificate
	Key  *ecdsa.PrivateKey // issuer signing key (ECDSA P-256)
	PEM  string            // PEM-encoded certificate for embedding in register/refresh responses
}

// ParseCA parses PEM-encoded CA certificate and private key strings.
// The cert and key are expected to be ECDSA P-256.
func ParseCA(certPEM, keyPEM string) (*CA, error) {
	certBlock, _ := pem.Decode([]byte(certPEM))
	if certBlock == nil {
		return nil, fmt.Errorf("failed to decode CA certificate PEM")
	}
	caCert, certErr := x509.ParseCertificate(certBlock.Bytes)
	if certErr != nil {
		return nil, fmt.Errorf("parse CA certificate: %w", certErr)
	}

	keyBlock, _ := pem.Decode([]byte(keyPEM))
	if keyBlock == nil {
		return nil, fmt.Errorf("failed to decode CA key PEM")
	}
	caKey, keyErr := x509.ParseECPrivateKey(keyBlock.Bytes)
	if keyErr != nil {
		return nil, fmt.Errorf("parse CA private key: %w", keyErr)
	}

	return &CA{
		Cert: caCert,
		Key:  caKey,
		PEM:  certPEM,
	}, nil
}
|
||||
|
||||
// HostCert holds all material returned when issuing a leaf cert for a host agent.
type HostCert struct {
	CertPEM     string          // PEM-encoded leaf certificate, sent to the agent
	KeyPEM      string          // PEM-encoded EC private key, sent to the agent
	Fingerprint string          // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
	ExpiresAt   time.Time       // stored in hosts.cert_expires_at
	TLSCert     tls.Certificate // parsed cert+key pair, ready for immediate use
}
|
||||
|
||||
// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server
|
||||
// certificate for the host agent. hostID becomes the common name; the host's
|
||||
// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS
|
||||
// stack can verify the connection without disabling hostname checking.
|
||||
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("generate host key: %w", err)
|
||||
}
|
||||
|
||||
serial, err := randomSerial()
|
||||
if err != nil {
|
||||
return HostCert{}, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expires := now.Add(hostCertTTL)
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: hostID},
|
||||
NotBefore: now.Add(-time.Minute), // small clock-skew tolerance
|
||||
NotAfter: expires,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
|
||||
// Extract IP from "ip:port" address; fall back to DNS SAN if not parseable.
|
||||
host, _, err := net.SplitHostPort(hostAddr)
|
||||
if err != nil {
|
||||
host = hostAddr
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
tmpl.IPAddresses = []net.IP{ip}
|
||||
} else {
|
||||
tmpl.DNSNames = []string{host}
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("create host certificate: %w", err)
|
||||
}
|
||||
|
||||
certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("marshal host key: %w", err)
|
||||
}
|
||||
keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
|
||||
|
||||
tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
|
||||
}
|
||||
|
||||
fp := fmt.Sprintf("%x", sha256.Sum256(derBytes))
|
||||
|
||||
return HostCert{
|
||||
CertPEM: certPEM,
|
||||
KeyPEM: keyPEM,
|
||||
Fingerprint: fp,
|
||||
ExpiresAt: expires,
|
||||
TLSCert: tlsCert,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for
|
||||
// the control plane to present during mTLS handshakes with host agents.
|
||||
// Called once at CP startup; the result is embedded into the shared HTTP client.
|
||||
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
|
||||
}
|
||||
|
||||
serial, err := randomSerial()
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: "wrenn-cp"},
|
||||
NotBefore: now.Add(-time.Minute),
|
||||
NotAfter: now.Add(cpCertTTL),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
|
||||
return tls.X509KeyPair(certPEM, keyPEM)
|
||||
}
|
||||
|
||||
// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the
// PEM-encoded CA certificate. This is used on the agent side where only the
// CA certificate (not the private key) is available. Returns nil when the
// CA PEM cannot be parsed.
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
	clientCAs := x509.NewCertPool()
	ok := clientCAs.AppendCertsFromPEM([]byte(caCertPEM))
	if !ok {
		return nil
	}
	cfg := &tls.Config{
		ClientAuth:     tls.RequireAndVerifyClientCert,
		ClientCAs:      clientCAs,
		GetCertificate: getCert,
		MinVersion:     tls.VersionTLS13,
	}
	return cfg
}
|
||||
|
||||
// CPCertStore provides lock-free read/write access to the control plane's
|
||||
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
|
||||
// to enable hot-swap without restarting the HTTP client.
|
||||
//
|
||||
// The zero value is not usable; use NewCPCertStore to create one.
|
||||
type CPCertStore struct {
|
||||
ptr atomic.Pointer[tls.Certificate]
|
||||
ca *CA
|
||||
}
|
||||
|
||||
// NewCPCertStore issues an initial CP client certificate from ca and returns a
|
||||
// store that can renew it in place. Returns an error if the initial issuance fails.
|
||||
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
|
||||
s := &CPCertStore{ca: ca}
|
||||
if err := s.Refresh(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Refresh issues a fresh CP client certificate and atomically stores it.
|
||||
// If issuance fails the existing cert is unchanged.
|
||||
func (s *CPCertStore) Refresh() error {
|
||||
cert, err := IssueCPClientCert(s.ca)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renew CP client certificate: %w", err)
|
||||
}
|
||||
s.ptr.Store(&cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called
|
||||
// per-handshake and always returns the most recently stored certificate.
|
||||
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
cert := s.ptr.Load()
|
||||
if cert == nil {
|
||||
return nil, fmt.Errorf("no CP client certificate available")
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client.
|
||||
// It uses certStore.GetClientCertificate so the certificate can be renewed
|
||||
// without replacing the config or transport.
|
||||
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(ca.Cert)
|
||||
return &tls.Config{
|
||||
RootCAs: pool,
|
||||
GetClientCertificate: certStore.GetClientCertificate,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
}
|
||||
|
||||
// randomSerial returns a random 128-bit certificate serial number.
func randomSerial() (*big.Int, error) {
	// Serials are drawn uniformly from [0, 2^128).
	limit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, limit)
	if err != nil {
		return nil, fmt.Errorf("generate serial number: %w", err)
	}
	return serial, nil
}
|
||||
@ -1,6 +1,10 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
@ -8,9 +12,14 @@ const authCtxKey contextKey = 0
|
||||
|
||||
// AuthContext is stamped into request context by auth middleware.
|
||||
type AuthContext struct {
|
||||
TeamID string
|
||||
UserID string // empty when authenticated via API key
|
||||
Email string // empty when authenticated via API key
|
||||
TeamID pgtype.UUID
|
||||
UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key
|
||||
Email string // empty when authenticated via API key
|
||||
Name string // empty when authenticated via API key
|
||||
Role string // owner, admin, or member; empty when authenticated via API key
|
||||
IsAdmin bool // platform-level admin; always false when authenticated via API key
|
||||
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
|
||||
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
|
||||
}
|
||||
|
||||
// WithAuthContext returns a new context with the given AuthContext.
|
||||
@ -38,7 +47,7 @@ const hostCtxKey contextKey = 1
|
||||
|
||||
// HostContext is stamped into request context by host token middleware.
|
||||
type HostContext struct {
|
||||
HostID string
|
||||
HostID pgtype.UUID
|
||||
}
|
||||
|
||||
// WithHostContext returns a new context with the given HostContext.
|
||||
|
||||
@ -5,27 +5,37 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
const jwtExpiry = 6 * time.Hour
|
||||
const hostJWTExpiry = 8760 * time.Hour // 1 year
|
||||
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
|
||||
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
|
||||
|
||||
// Claims are the JWT payload for user tokens.
|
||||
type Claims struct {
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Role string `json:"role"` // owner, admin, or member within TeamID
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignJWT signs a new 6-hour JWT for the given user.
|
||||
func SignJWT(secret []byte, userID, teamID, email string) (string, error) {
|
||||
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
TeamID: teamID,
|
||||
Email: email,
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Role: role,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
Subject: id.FormatUserID(userID),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
|
||||
},
|
||||
@ -63,14 +73,15 @@ type HostClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent.
|
||||
func SignHostJWT(secret []byte, hostID string) (string, error) {
|
||||
// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent.
|
||||
func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) {
|
||||
formatted := id.FormatHostID(hostID)
|
||||
now := time.Now()
|
||||
claims := HostClaims{
|
||||
Type: "host",
|
||||
HostID: hostID,
|
||||
HostID: formatted,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: hostID,
|
||||
Subject: formatted,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
|
||||
},
|
||||
|
||||
63
internal/channels/crypto.go
Normal file
63
internal/channels/crypto.go
Normal file
@ -0,0 +1,63 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// EncryptSecret encrypts plaintext using AES-256-GCM with a random nonce.
|
||||
// Returns base64(nonce || ciphertext).
|
||||
func EncryptSecret(key [32]byte, plaintext string) (string, error) {
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("aes cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gcm: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptSecret decrypts a value produced by EncryptSecret.
|
||||
func DecryptSecret(key [32]byte, encoded string) (string, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("aes cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gcm: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
36
internal/channels/deliver.go
Normal file
36
internal/channels/deliver.go
Normal file
@ -0,0 +1,36 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/containrrr/shoutrrr"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/events"
|
||||
)
|
||||
|
||||
// Deliver sends a notification to a single provider with the given config.
|
||||
// For webhooks it uses HMAC-signed HTTP POST; for all others it uses shoutrrr.
|
||||
func Deliver(ctx context.Context, provider string, config map[string]string, e events.Event) error {
|
||||
payload, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal event: %w", err)
|
||||
}
|
||||
|
||||
if provider == "webhook" {
|
||||
wh := NewWebhookDelivery()
|
||||
return wh.Deliver(ctx, config["url"], config["secret"], payload)
|
||||
}
|
||||
|
||||
shoutrrrURL, err := ShoutrrrURL(provider, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build shoutrrr URL: %w", err)
|
||||
}
|
||||
|
||||
msg := FormatMessage(e)
|
||||
if err := shoutrrr.Send(shoutrrrURL, msg); err != nil {
|
||||
return fmt.Errorf("shoutrrr send: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
183
internal/channels/dispatcher.go
Normal file
183
internal/channels/dispatcher.go
Normal file
@ -0,0 +1,183 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/events"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
const (
	// groupName identifies the Redis consumer group shared by dispatchers;
	// the version suffix allows a clean cut-over if the payload shape changes.
	groupName = "wrenn-channels-v1"
	// consumerName is this process's consumer identity within the group.
	consumerName = "cp-0"
)
|
||||
|
||||
// Dispatcher consumes events from the Redis stream and delivers them
// to matching notification channels.
type Dispatcher struct {
	rdb     *redis.Client    // stream source (XReadGroup/XAck)
	db      *db.Queries      // channel lookups per team/event type
	encKey  [32]byte         // AES-256-GCM key for stored channel secrets
	webhook *WebhookDelivery // reusable HTTP client for webhook channels
}
|
||||
|
||||
// NewDispatcher constructs an event dispatcher.
|
||||
func NewDispatcher(rdb *redis.Client, queries *db.Queries, encKey [32]byte) *Dispatcher {
|
||||
return &Dispatcher{
|
||||
rdb: rdb,
|
||||
db: queries,
|
||||
encKey: encKey,
|
||||
webhook: NewWebhookDelivery(),
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the consumer goroutine and returns immediately; the
// goroutine exits when ctx is cancelled. Call at most once per Dispatcher.
func (d *Dispatcher) Start(ctx context.Context) {
	go d.run(ctx)
}
|
||||
|
||||
// run is the consumer loop: it joins (or creates) the consumer group, then
// repeatedly reads batches from the stream until ctx is cancelled.
func (d *Dispatcher) run(ctx context.Context) {
	// Create consumer group idempotently. "$" means only new messages.
	err := d.rdb.XGroupCreateMkStream(ctx, streamKey, groupName, "$").Err()
	if err != nil && !isGroupExistsError(err) {
		slog.Error("channels: failed to create consumer group", "error", err)
		return
	}

	for {
		// Cheap non-blocking cancellation check between reads.
		select {
		case <-ctx.Done():
			return
		default:
		}

		// ">" asks for entries never delivered to this group before;
		// Block bounds how long each read waits when the stream is idle.
		streams, err := d.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
			Group:    groupName,
			Consumer: consumerName,
			Streams:  []string{streamKey, ">"},
			Count:    10,
			Block:    5 * time.Second,
		}).Result()

		if err != nil {
			// redis.Nil means the blocking read timed out with nothing to
			// do; a non-nil ctx.Err() means shutdown — loop re-checks above.
			if err == redis.Nil || ctx.Err() != nil {
				continue
			}
			slog.Warn("channels: xreadgroup error", "error", err)
			time.Sleep(1 * time.Second) // back off before retrying
			continue
		}

		for _, stream := range streams {
			for _, msg := range stream.Messages {
				d.handleMessage(ctx, msg)
			}
		}
	}
}
|
||||
|
||||
// handleMessage decodes one stream entry, looks up the team's channels
// subscribed to the event type, and dispatches to each. The entry is always
// acknowledged — even on decode failure — so a malformed message cannot
// wedge the consumer group as a poison pill.
func (d *Dispatcher) handleMessage(ctx context.Context, msg redis.XMessage) {
	defer func() {
		if err := d.rdb.XAck(ctx, streamKey, groupName, msg.ID).Err(); err != nil {
			slog.Warn("channels: xack failed", "id", msg.ID, "error", err)
		}
	}()

	// Publisher stores the JSON event under the "payload" field.
	payload, ok := msg.Values["payload"].(string)
	if !ok {
		slog.Warn("channels: message missing payload", "id", msg.ID)
		return
	}

	var event events.Event
	if err := json.Unmarshal([]byte(payload), &event); err != nil {
		slog.Warn("channels: failed to unmarshal event", "id", msg.ID, "error", err)
		return
	}

	teamID, err := id.ParseTeamID(event.TeamID)
	if err != nil {
		slog.Warn("channels: invalid team ID in event", "team_id", event.TeamID, "error", err)
		return
	}

	channels, err := d.db.ListChannelsForEvent(ctx, db.ListChannelsForEventParams{
		TeamID:    teamID,
		EventType: event.Event,
	})
	if err != nil {
		slog.Warn("channels: failed to list channels for event", "event", event.Event, "error", err)
		return
	}

	for _, ch := range channels {
		d.dispatch(ctx, ch, event)
	}
}
|
||||
|
||||
// retryDelays defines the wait durations before each retry attempt
// (two retries total, so a delivery is tried at most three times).
var retryDelays = []time.Duration{10 * time.Second, 30 * time.Second}
|
||||
|
||||
func (d *Dispatcher) dispatch(ctx context.Context, ch db.Channel, e events.Event) {
|
||||
config, err := d.decryptConfig(ch.Config)
|
||||
if err != nil {
|
||||
slog.Warn("channels: failed to decrypt config",
|
||||
"channel_id", id.FormatChannelID(ch.ID), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
chID := id.FormatChannelID(ch.ID)
|
||||
|
||||
if err := Deliver(ctx, ch.Provider, config, e); err != nil {
|
||||
slog.Warn("channels: delivery failed, scheduling retries",
|
||||
"channel_id", chID, "provider", ch.Provider, "error", err)
|
||||
go d.retryDeliver(ctx, ch.Provider, config, e, chID)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) retryDeliver(ctx context.Context, provider string, config map[string]string, e events.Event, chID string) {
|
||||
for i, delay := range retryDelays {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
if err := Deliver(ctx, provider, config, e); err != nil {
|
||||
slog.Warn("channels: retry delivery failed",
|
||||
"channel_id", chID, "provider", provider,
|
||||
"attempt", i+2, "error", err)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
slog.Error("channels: delivery failed after all retries",
|
||||
"channel_id", chID, "provider", provider, "event", e.Event)
|
||||
}
|
||||
|
||||
func (d *Dispatcher) decryptConfig(configJSON []byte) (map[string]string, error) {
|
||||
var encrypted map[string]string
|
||||
if err := json.Unmarshal(configJSON, &encrypted); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decrypted := make(map[string]string, len(encrypted))
|
||||
for k, v := range encrypted {
|
||||
plaintext, err := DecryptSecret(d.encKey, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypted[k] = plaintext
|
||||
}
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
func isGroupExistsError(err error) bool {
|
||||
return err != nil && err.Error() == "BUSYGROUP Consumer Group name already exists"
|
||||
}
|
||||
65
internal/channels/message.go
Normal file
65
internal/channels/message.go
Normal file
@ -0,0 +1,65 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/events"
|
||||
)
|
||||
|
||||
// FormatMessage produces a human-readable notification string containing
|
||||
// the event summary, resource details, actor, and timestamp.
|
||||
func FormatMessage(e events.Event) string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(formatSummary(e))
|
||||
fmt.Fprintf(&b, "\n\nEvent: %s", e.Event)
|
||||
fmt.Fprintf(&b, "\nResource: %s %s", e.Resource.Type, e.Resource.ID)
|
||||
fmt.Fprintf(&b, "\nActor: %s", formatActor(e.Actor))
|
||||
fmt.Fprintf(&b, "\nTeam: %s", e.TeamID)
|
||||
fmt.Fprintf(&b, "\nTime: %s", e.Timestamp)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatSummary(e events.Event) string {
|
||||
switch e.Event {
|
||||
case events.CapsuleCreated:
|
||||
return fmt.Sprintf("Capsule %s created", e.Resource.ID)
|
||||
case events.CapsuleRunning:
|
||||
return fmt.Sprintf("Capsule %s is running", e.Resource.ID)
|
||||
case events.CapsulePaused:
|
||||
return fmt.Sprintf("Capsule %s paused", e.Resource.ID)
|
||||
case events.CapsuleDestroyed:
|
||||
return fmt.Sprintf("Capsule %s destroyed", e.Resource.ID)
|
||||
case events.SnapshotCreated:
|
||||
return fmt.Sprintf("Template snapshot %s created", e.Resource.ID)
|
||||
case events.SnapshotDeleted:
|
||||
return fmt.Sprintf("Template snapshot %s deleted", e.Resource.ID)
|
||||
case events.HostUp:
|
||||
return fmt.Sprintf("Host %s is up", e.Resource.ID)
|
||||
case events.HostDown:
|
||||
return fmt.Sprintf("Host %s is down", e.Resource.ID)
|
||||
default:
|
||||
return fmt.Sprintf("%s %s", e.Resource.Type, e.Resource.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func formatActor(a events.Actor) string {
|
||||
switch a.Type {
|
||||
case events.ActorSystem:
|
||||
return "system"
|
||||
case events.ActorUser:
|
||||
if a.Name != "" {
|
||||
return fmt.Sprintf("%s (%s)", a.Name, a.ID)
|
||||
}
|
||||
return a.ID
|
||||
case events.ActorAPIKey:
|
||||
if a.Name != "" {
|
||||
return fmt.Sprintf("api_key %s (%s)", a.Name, a.ID)
|
||||
}
|
||||
return fmt.Sprintf("api_key %s", a.ID)
|
||||
default:
|
||||
return string(a.Type)
|
||||
}
|
||||
}
|
||||
44
internal/channels/publisher.go
Normal file
44
internal/channels/publisher.go
Normal file
@ -0,0 +1,44 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/events"
|
||||
)
|
||||
|
||||
// streamKey is the Redis stream all control-plane events are published to
// and the dispatcher consumes from.
const streamKey = "wrenn:events"
|
||||
|
||||
// Publisher pushes events onto the Redis stream for the dispatcher to consume.
type Publisher struct {
	rdb *redis.Client // target Redis instance; entries are appended via XADD
}
|
||||
|
||||
// NewPublisher constructs an event publisher.
|
||||
func NewPublisher(rdb *redis.Client) *Publisher {
|
||||
return &Publisher{rdb: rdb}
|
||||
}
|
||||
|
||||
// Publish serializes the event and appends it to the global stream.
|
||||
// Fire-and-forget: failures are logged, never propagated.
|
||||
func (p *Publisher) Publish(ctx context.Context, e events.Event) {
|
||||
payload, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
slog.Warn("channels: failed to marshal event", "event", e.Event, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.rdb.XAdd(ctx, &redis.XAddArgs{
|
||||
Stream: streamKey,
|
||||
MaxLen: 10000,
|
||||
Approx: true,
|
||||
Values: map[string]interface{}{
|
||||
"payload": string(payload),
|
||||
},
|
||||
}).Err(); err != nil {
|
||||
slog.Warn("channels: failed to publish event", "event", e.Event, "error", err)
|
||||
}
|
||||
}
|
||||
298
internal/channels/service.go
Normal file
298
internal/channels/service.go
Normal file
@ -0,0 +1,298 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/events"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
)
|
||||
|
||||
// validProviders is the closed set of notification providers the service
// accepts. Keep in sync with requiredFields and ShoutrrrURL's switch.
var validProviders = map[string]bool{
	"discord":    true,
	"slack":      true,
	"teams":      true,
	"googlechat": true,
	"telegram":   true,
	"matrix":     true,
	"webhook":    true,
}
|
||||
|
||||
// requiredFields lists the config keys that must be non-empty for each
// provider; Create, RotateConfig, and Test all validate against this table.
var requiredFields = map[string][]string{
	"discord":    {"webhook_url"},
	"slack":      {"webhook_url"},
	"teams":      {"webhook_url"},
	"googlechat": {"webhook_url"},
	"telegram":   {"bot_token", "chat_id"},
	"matrix":     {"homeserver_url", "access_token", "room_id"},
	"webhook":    {"url"},
}
|
||||
|
||||
// validEvents maps event type strings to true for validation.
|
||||
var validEvents map[string]bool
|
||||
|
||||
func init() {
|
||||
validEvents = make(map[string]bool, len(events.AllEventTypes))
|
||||
for _, et := range events.AllEventTypes {
|
||||
validEvents[et] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Service handles channel CRUD operations.
type Service struct {
	DB     *db.Queries // persistence layer for channel rows
	EncKey [32]byte    // AES-256-GCM key used to encrypt config values at rest
}
|
||||
|
||||
// CreateParams holds the parameters for creating a channel.
type CreateParams struct {
	TeamID   pgtype.UUID       // owning team
	Name     string            // display name; normalised by cleanName before storage
	Provider string            // one of validProviders
	Config   map[string]string // provider-specific secrets; validated per requiredFields
	Events   []string          // event types to subscribe to; must be non-empty
}
|
||||
|
||||
// CreateResult holds the result of creating a channel.
type CreateResult struct {
	Channel db.Channel
	// PlaintextSecret is the webhook signing secret, returned exactly once
	// so the caller can display it; non-empty only for the webhook provider.
	PlaintextSecret string
}
|
||||
|
||||
// Create creates a new notification channel.
|
||||
func (s *Service) Create(ctx context.Context, p CreateParams) (CreateResult, error) {
|
||||
clean, err := cleanName(p.Name)
|
||||
if err != nil {
|
||||
return CreateResult{}, err
|
||||
}
|
||||
p.Name = clean
|
||||
|
||||
if !validProviders[p.Provider] {
|
||||
return CreateResult{}, fmt.Errorf("invalid: unsupported provider %q", p.Provider)
|
||||
}
|
||||
|
||||
if len(p.Events) == 0 {
|
||||
return CreateResult{}, fmt.Errorf("invalid: at least one event type is required")
|
||||
}
|
||||
for _, et := range p.Events {
|
||||
if !validEvents[et] {
|
||||
return CreateResult{}, fmt.Errorf("invalid: unknown event type %q", et)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required config fields.
|
||||
for _, field := range requiredFields[p.Provider] {
|
||||
if p.Config[field] == "" {
|
||||
return CreateResult{}, fmt.Errorf("invalid: %s is required for %s", field, p.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// For webhooks, auto-generate secret if not provided.
|
||||
var plaintextSecret string
|
||||
if p.Provider == "webhook" {
|
||||
if p.Config["secret"] == "" {
|
||||
secret := generateSecret()
|
||||
p.Config["secret"] = secret
|
||||
plaintextSecret = secret
|
||||
} else {
|
||||
plaintextSecret = p.Config["secret"]
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt config fields.
|
||||
encrypted := make(map[string]string, len(p.Config))
|
||||
for k, v := range p.Config {
|
||||
enc, err := EncryptSecret(s.EncKey, v)
|
||||
if err != nil {
|
||||
return CreateResult{}, fmt.Errorf("encrypt config field %s: %w", k, err)
|
||||
}
|
||||
encrypted[k] = enc
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(encrypted)
|
||||
if err != nil {
|
||||
return CreateResult{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
ch, err := s.DB.InsertChannel(ctx, db.InsertChannelParams{
|
||||
ID: id.NewChannelID(),
|
||||
TeamID: p.TeamID,
|
||||
Name: p.Name,
|
||||
Provider: p.Provider,
|
||||
Config: configJSON,
|
||||
EventTypes: p.Events,
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return CreateResult{}, fmt.Errorf("conflict: channel name %q already exists", p.Name)
|
||||
}
|
||||
return CreateResult{}, fmt.Errorf("insert channel: %w", err)
|
||||
}
|
||||
|
||||
return CreateResult{Channel: ch, PlaintextSecret: plaintextSecret}, nil
|
||||
}
|
||||
|
||||
// List returns all channels belonging to the given team.
|
||||
func (s *Service) List(ctx context.Context, teamID pgtype.UUID) ([]db.Channel, error) {
|
||||
return s.DB.ListChannelsByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// Get returns a single channel by ID, scoped to the given team.
|
||||
func (s *Service) Get(ctx context.Context, channelID, teamID pgtype.UUID) (db.Channel, error) {
|
||||
return s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
|
||||
}
|
||||
|
||||
// Update updates a channel's name and event types.
|
||||
func (s *Service) Update(ctx context.Context, channelID, teamID pgtype.UUID, name string, eventTypes []string) (db.Channel, error) {
|
||||
clean, err := cleanName(name)
|
||||
if err != nil {
|
||||
return db.Channel{}, err
|
||||
}
|
||||
name = clean
|
||||
|
||||
if len(eventTypes) == 0 {
|
||||
return db.Channel{}, fmt.Errorf("invalid: at least one event type is required")
|
||||
}
|
||||
for _, et := range eventTypes {
|
||||
if !validEvents[et] {
|
||||
return db.Channel{}, fmt.Errorf("invalid: unknown event type %q", et)
|
||||
}
|
||||
}
|
||||
|
||||
ch, err := s.DB.UpdateChannel(ctx, db.UpdateChannelParams{
|
||||
ID: channelID,
|
||||
TeamID: teamID,
|
||||
Name: name,
|
||||
EventTypes: eventTypes,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Channel{}, fmt.Errorf("channel not found")
|
||||
}
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return db.Channel{}, fmt.Errorf("conflict: channel name %q already exists", name)
|
||||
}
|
||||
return db.Channel{}, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// RotateConfig replaces a channel's config with new provider secrets.
|
||||
func (s *Service) RotateConfig(ctx context.Context, channelID, teamID pgtype.UUID, config map[string]string) (db.Channel, error) {
|
||||
// Look up the existing channel to get its provider for validation.
|
||||
ch, err := s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Channel{}, fmt.Errorf("channel not found")
|
||||
}
|
||||
return db.Channel{}, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
// Validate required config fields for this provider.
|
||||
for _, field := range requiredFields[ch.Provider] {
|
||||
if config[field] == "" {
|
||||
return db.Channel{}, fmt.Errorf("invalid: %s is required for %s", field, ch.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// For webhooks, auto-generate secret if not provided.
|
||||
if ch.Provider == "webhook" && config["secret"] == "" {
|
||||
config["secret"] = generateSecret()
|
||||
}
|
||||
|
||||
// Encrypt all config fields.
|
||||
encrypted := make(map[string]string, len(config))
|
||||
for k, v := range config {
|
||||
enc, err := EncryptSecret(s.EncKey, v)
|
||||
if err != nil {
|
||||
return db.Channel{}, fmt.Errorf("encrypt config field %s: %w", k, err)
|
||||
}
|
||||
encrypted[k] = enc
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(encrypted)
|
||||
if err != nil {
|
||||
return db.Channel{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
updated, err := s.DB.UpdateChannelConfig(ctx, db.UpdateChannelConfigParams{
|
||||
ID: channelID,
|
||||
TeamID: teamID,
|
||||
Config: configJSON,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Channel{}, fmt.Errorf("channel not found")
|
||||
}
|
||||
return db.Channel{}, fmt.Errorf("update channel config: %w", err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// Test validates config and sends a test notification without persisting anything.
|
||||
func (s *Service) Test(ctx context.Context, provider string, config map[string]string) error {
|
||||
if !validProviders[provider] {
|
||||
return fmt.Errorf("invalid: unsupported provider %q", provider)
|
||||
}
|
||||
|
||||
for _, field := range requiredFields[provider] {
|
||||
if config[field] == "" {
|
||||
return fmt.Errorf("invalid: %s is required for %s", field, provider)
|
||||
}
|
||||
}
|
||||
|
||||
// For webhooks, auto-generate a temporary secret if not provided.
|
||||
if provider == "webhook" && config["secret"] == "" {
|
||||
config["secret"] = generateSecret()
|
||||
}
|
||||
|
||||
testEvent := events.Event{
|
||||
Event: "channel.test",
|
||||
Timestamp: events.Now(),
|
||||
TeamID: "test",
|
||||
Actor: events.Actor{Type: events.ActorSystem},
|
||||
Resource: events.Resource{ID: "test", Type: "channel"},
|
||||
}
|
||||
|
||||
return Deliver(ctx, provider, config, testEvent)
|
||||
}
|
||||
|
||||
// Delete removes a channel by ID, scoped to the given team.
|
||||
func (s *Service) Delete(ctx context.Context, channelID, teamID pgtype.UUID) error {
|
||||
return s.DB.DeleteChannelByTeam(ctx, db.DeleteChannelByTeamParams{ID: channelID, TeamID: teamID})
|
||||
}
|
||||
|
||||
// cleanName normalises a channel name: trim whitespace, lowercase, replace
|
||||
// spaces with hyphens, then validate against SafeName rules.
|
||||
func cleanName(name string) (string, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
name = strings.ToLower(name)
|
||||
name = strings.ReplaceAll(name, " ", "-")
|
||||
if err := validate.SafeName(name); err != nil {
|
||||
return "", fmt.Errorf("invalid: %w", err)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func generateSecret() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
119
internal/channels/shoutrrr.go
Normal file
119
internal/channels/shoutrrr.go
Normal file
@ -0,0 +1,119 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ShoutrrrURL builds a shoutrrr-compatible URL from structured provider config.
|
||||
func ShoutrrrURL(provider string, config map[string]string) (string, error) {
|
||||
switch provider {
|
||||
case "discord":
|
||||
return discordURL(config)
|
||||
case "slack":
|
||||
return slackURL(config)
|
||||
case "teams":
|
||||
return teamsURL(config)
|
||||
case "googlechat":
|
||||
return googlechatURL(config)
|
||||
case "telegram":
|
||||
return telegramURL(config)
|
||||
case "matrix":
|
||||
return matrixURL(config)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported shoutrrr provider: %s", provider)
|
||||
}
|
||||
}
|
||||
|
||||
// discordURL rewrites https://discord.com/api/webhooks/{id}/{token} into
// shoutrrr's discord://{token}@{id} form (splitLines disabled so the
// multi-line message arrives as one post).
func discordURL(config map[string]string) (string, error) {
	parsed, err := url.Parse(config["webhook_url"])
	if err != nil {
		return "", fmt.Errorf("invalid discord webhook URL: %w", err)
	}
	// Expect path /api/webhooks/{id}/{token}.
	segments := strings.Split(strings.TrimPrefix(parsed.Path, "/"), "/")
	if len(segments) < 4 || segments[0] != "api" || segments[1] != "webhooks" {
		return "", fmt.Errorf("unexpected discord webhook URL format")
	}
	return fmt.Sprintf("discord://%s@%s?splitLines=No", segments[3], segments[2]), nil
}
|
||||
|
||||
// slackURL converts https://hooks.slack.com/services/T.../B.../XXX into
// shoutrrr's slack://hook:T...-B...-XXX@webhook form.
func slackURL(config map[string]string) (string, error) {
	parsed, err := url.Parse(config["webhook_url"])
	if err != nil {
		return "", fmt.Errorf("invalid slack webhook URL: %w", err)
	}
	// Expect path /services/TXXXXX/BXXXXX/XXXXXXXX.
	segments := strings.Split(strings.TrimPrefix(parsed.Path, "/"), "/")
	if len(segments) < 4 || segments[0] != "services" {
		return "", fmt.Errorf("unexpected slack webhook URL format")
	}
	tokens := strings.Join(segments[1:4], "-")
	return "slack://hook:" + tokens + "@webhook", nil
}
|
||||
|
||||
// teamsWebhookRe extracts the 4 components from a Teams webhook URL.
|
||||
// Format: https://<host>/<path>/{group}@{tenant}/IncomingWebhook/{altID}/{groupOwner}
|
||||
var teamsWebhookRe = regexp.MustCompile(`([0-9a-f-]{36})@([0-9a-f-]{36})/[^/]+/([0-9a-f]{32})/([0-9a-f-]{36})`)
|
||||
|
||||
// teamsURL converts a Teams webhook URL → teams://Group@Tenant/AltID/GroupOwner
|
||||
func teamsURL(config map[string]string) (string, error) {
|
||||
webhookURL := config["webhook_url"]
|
||||
if webhookURL == "" {
|
||||
return "", fmt.Errorf("teams webhook_url is required")
|
||||
}
|
||||
groups := teamsWebhookRe.FindStringSubmatch(webhookURL)
|
||||
if len(groups) != 5 {
|
||||
return "", fmt.Errorf("unexpected teams webhook URL format")
|
||||
}
|
||||
group, tenant, altID, groupOwner := groups[1], groups[2], groups[3], groups[4]
|
||||
return fmt.Sprintf("teams://%s@%s/%s/%s", group, tenant, altID, groupOwner), nil
|
||||
}
|
||||
|
||||
// googlechatURL converts a Google Chat webhook URL to shoutrrr format by
// swapping the https scheme for googlechat while keeping host, path, and
// query intact. The host is pinned to chat.googleapis.com.
func googlechatURL(config map[string]string) (string, error) {
	raw := config["webhook_url"]
	if raw == "" {
		return "", fmt.Errorf("googlechat webhook_url is required")
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return "", fmt.Errorf("invalid googlechat webhook URL: %w", err)
	}
	if parsed.Host != "chat.googleapis.com" {
		return "", fmt.Errorf("unexpected googlechat webhook URL host: %s", parsed.Host)
	}
	parsed.Scheme = "googlechat"
	return parsed.String(), nil
}
|
||||
|
||||
// telegramURL builds shoutrrr's telegram://token@telegram/?chats=chatID form.
func telegramURL(config map[string]string) (string, error) {
	token, chatID := config["bot_token"], config["chat_id"]
	if token == "" || chatID == "" {
		return "", fmt.Errorf("telegram bot_token and chat_id are required")
	}
	return "telegram://" + token + "@telegram/?chats=" + chatID, nil
}
|
||||
|
||||
// matrixURL builds matrix://user:token@homeserver/room
|
||||
func matrixURL(config map[string]string) (string, error) {
|
||||
homeserver := config["homeserver_url"]
|
||||
token := config["access_token"]
|
||||
roomID := config["room_id"]
|
||||
if homeserver == "" || token == "" || roomID == "" {
|
||||
return "", fmt.Errorf("matrix homeserver_url, access_token, and room_id are required")
|
||||
}
|
||||
// Strip protocol from homeserver URL.
|
||||
host := strings.TrimPrefix(strings.TrimPrefix(homeserver, "https://"), "http://")
|
||||
// Room ID often starts with ! — URL-encode it.
|
||||
return fmt.Sprintf("matrix://:%s@%s/%s", url.PathEscape(token), host, url.PathEscape(roomID)), nil
|
||||
}
|
||||
62
internal/channels/webhook.go
Normal file
62
internal/channels/webhook.go
Normal file
@ -0,0 +1,62 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// WebhookDelivery delivers events to webhook URLs with HMAC signing.
type WebhookDelivery struct {
	client *http.Client // shared client: 10s timeout, redirects not followed
}
|
||||
|
||||
// NewWebhookDelivery constructs a webhook delivery client.
|
||||
func NewWebhookDelivery() *WebhookDelivery {
|
||||
return &WebhookDelivery{
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
CheckRedirect: func(*http.Request, []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver signs and POSTs the event payload to the configured URL.
|
||||
func (d *WebhookDelivery) Deliver(ctx context.Context, targetURL, secret string, payload []byte) error {
|
||||
timestamp := time.Now().UTC().Format(time.RFC3339)
|
||||
deliveryID := uuid.New().String()
|
||||
|
||||
// Compute HMAC-SHA256: sign over "timestamp.body".
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write([]byte(timestamp + "." + string(payload)))
|
||||
signature := "sha256=" + hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, strings.NewReader(string(payload)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-WRENN-SIGNATURE", signature)
|
||||
req.Header.Set("X-Wrenn-Delivery", deliveryID)
|
||||
req.Header.Set("X-Wrenn-Timestamp", timestamp)
|
||||
|
||||
resp, err := d.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http post: %w", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("webhook returned %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -1,24 +1,32 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
// Config holds the control plane configuration.
|
||||
type Config struct {
|
||||
DatabaseURL string
|
||||
RedisURL string
|
||||
ListenAddr string
|
||||
HostAgentAddr string
|
||||
JWTSecret string
|
||||
DatabaseURL string
|
||||
RedisURL string
|
||||
ListenAddr string
|
||||
JWTSecret string
|
||||
|
||||
// mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either
|
||||
// disables cert issuance and leaves agent connections on plain HTTP (dev mode).
|
||||
CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate
|
||||
CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key
|
||||
|
||||
OAuthGitHubClientID string
|
||||
OAuthGitHubClientSecret string
|
||||
OAuthRedirectURL string
|
||||
CPPublicURL string
|
||||
|
||||
// Channels — encryption for channel secrets (AES-256-GCM).
|
||||
EncryptionKeyHex string // WRENN_ENCRYPTION_KEY raw hex string (for validation)
|
||||
EncryptionKey [32]byte // parsed 32-byte key
|
||||
}
|
||||
|
||||
// Load reads configuration from a .env file (if present) and environment variables.
|
||||
@ -28,21 +36,27 @@ func Load() Config {
|
||||
_ = godotenv.Load()
|
||||
|
||||
cfg := Config{
|
||||
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
|
||||
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
|
||||
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
|
||||
HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"),
|
||||
JWTSecret: os.Getenv("JWT_SECRET"),
|
||||
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
|
||||
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
|
||||
ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
|
||||
JWTSecret: os.Getenv("JWT_SECRET"),
|
||||
|
||||
CACert: os.Getenv("WRENN_CA_CERT"),
|
||||
CAKey: os.Getenv("WRENN_CA_KEY"),
|
||||
|
||||
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
|
||||
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
|
||||
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
|
||||
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
|
||||
|
||||
EncryptionKeyHex: os.Getenv("WRENN_ENCRYPTION_KEY"),
|
||||
}
|
||||
|
||||
// Ensure the host agent address has a scheme.
|
||||
if !strings.HasPrefix(cfg.HostAgentAddr, "http://") && !strings.HasPrefix(cfg.HostAgentAddr, "https://") {
|
||||
cfg.HostAgentAddr = "http://" + cfg.HostAgentAddr
|
||||
if cfg.EncryptionKeyHex != "" {
|
||||
b, err := hex.DecodeString(cfg.EncryptionKeyHex)
|
||||
if err == nil && len(b) == 32 {
|
||||
copy(cfg.EncryptionKey[:], b)
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
|
||||
@ -16,8 +16,8 @@ DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type DeleteAPIKeyParams struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error {
|
||||
@ -52,12 +52,12 @@ RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_
|
||||
`
|
||||
|
||||
type InsertAPIKeyParams struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
KeyHash string `json:"key_hash"`
|
||||
KeyPrefix string `json:"key_prefix"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
KeyHash string `json:"key_hash"`
|
||||
KeyPrefix string `json:"key_prefix"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) {
|
||||
@ -87,7 +87,7 @@ const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many
|
||||
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID string) ([]TeamApiKey, error) {
|
||||
func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID pgtype.UUID) ([]TeamApiKey, error) {
|
||||
rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -126,18 +126,18 @@ ORDER BY k.created_at DESC
|
||||
`
|
||||
|
||||
type ListAPIKeysByTeamWithCreatorRow struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
KeyHash string `json:"key_hash"`
|
||||
KeyPrefix string `json:"key_prefix"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
LastUsed pgtype.Timestamptz `json:"last_used"`
|
||||
CreatorEmail string `json:"creator_email"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID string) ([]ListAPIKeysByTeamWithCreatorRow, error) {
|
||||
func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID pgtype.UUID) ([]ListAPIKeysByTeamWithCreatorRow, error) {
|
||||
rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -171,7 +171,7 @@ const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec
|
||||
UPDATE team_api_keys SET last_used = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id string) error {
|
||||
func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id)
|
||||
return err
|
||||
}
|
||||
|
||||
111
internal/db/audit.sql.go
Normal file
111
internal/db/audit.sql.go
Normal file
@ -0,0 +1,111 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: audit.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const insertAuditLog = `-- name: InsertAuditLog :exec
|
||||
INSERT INTO audit_logs (id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
`
|
||||
|
||||
type InsertAuditLogParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
ActorType string `json:"actor_type"`
|
||||
ActorID pgtype.Text `json:"actor_id"`
|
||||
ActorName string `json:"actor_name"`
|
||||
ResourceType string `json:"resource_type"`
|
||||
ResourceID pgtype.Text `json:"resource_id"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
Status string `json:"status"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) error {
|
||||
_, err := q.db.Exec(ctx, insertAuditLog,
|
||||
arg.ID,
|
||||
arg.TeamID,
|
||||
arg.ActorType,
|
||||
arg.ActorID,
|
||||
arg.ActorName,
|
||||
arg.ResourceType,
|
||||
arg.ResourceID,
|
||||
arg.Action,
|
||||
arg.Scope,
|
||||
arg.Status,
|
||||
arg.Metadata,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const listAuditLogs = `-- name: ListAuditLogs :many
|
||||
SELECT id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata, created_at FROM audit_logs
|
||||
WHERE team_id = $1
|
||||
AND scope = ANY($2::text[])
|
||||
AND (cardinality($3::text[]) = 0 OR resource_type = ANY($3::text[]))
|
||||
AND (cardinality($4::text[]) = 0 OR action = ANY($4::text[]))
|
||||
AND ($5::timestamptz IS NULL OR created_at < $5
|
||||
OR (created_at = $5 AND id < $6))
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT $7
|
||||
`
|
||||
|
||||
type ListAuditLogsParams struct {
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Column2 []string `json:"column_2"`
|
||||
Column3 []string `json:"column_3"`
|
||||
Column4 []string `json:"column_4"`
|
||||
Column5 pgtype.Timestamptz `json:"column_5"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Limit int32 `json:"limit"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListAuditLogs(ctx context.Context, arg ListAuditLogsParams) ([]AuditLog, error) {
|
||||
rows, err := q.db.Query(ctx, listAuditLogs,
|
||||
arg.TeamID,
|
||||
arg.Column2,
|
||||
arg.Column3,
|
||||
arg.Column4,
|
||||
arg.Column5,
|
||||
arg.ID,
|
||||
arg.Limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []AuditLog
|
||||
for rows.Next() {
|
||||
var i AuditLog
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.ActorType,
|
||||
&i.ActorID,
|
||||
&i.ActorName,
|
||||
&i.ResourceType,
|
||||
&i.ResourceID,
|
||||
&i.Action,
|
||||
&i.Scope,
|
||||
&i.Status,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
225
internal/db/channels.sql.go
Normal file
225
internal/db/channels.sql.go
Normal file
@ -0,0 +1,225 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: channels.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteChannelByTeam = `-- name: DeleteChannelByTeam :exec
|
||||
DELETE FROM channels WHERE id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type DeleteChannelByTeamParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteChannelByTeam(ctx context.Context, arg DeleteChannelByTeamParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteChannelByTeam, arg.ID, arg.TeamID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getChannelByTeam = `-- name: GetChannelByTeam :one
|
||||
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetChannelByTeamParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetChannelByTeam(ctx context.Context, arg GetChannelByTeamParams) (Channel, error) {
|
||||
row := q.db.QueryRow(ctx, getChannelByTeam, arg.ID, arg.TeamID)
|
||||
var i Channel
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.EventTypes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertChannel = `-- name: InsertChannel :one
|
||||
INSERT INTO channels (id, team_id, name, provider, config, event_types)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
|
||||
`
|
||||
|
||||
type InsertChannelParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config []byte `json:"config"`
|
||||
EventTypes []string `json:"event_types"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertChannel(ctx context.Context, arg InsertChannelParams) (Channel, error) {
|
||||
row := q.db.QueryRow(ctx, insertChannel,
|
||||
arg.ID,
|
||||
arg.TeamID,
|
||||
arg.Name,
|
||||
arg.Provider,
|
||||
arg.Config,
|
||||
arg.EventTypes,
|
||||
)
|
||||
var i Channel
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.EventTypes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listChannelsByTeam = `-- name: ListChannelsByTeam :many
|
||||
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE team_id = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListChannelsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Channel, error) {
|
||||
rows, err := q.db.Query(ctx, listChannelsByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Channel
|
||||
for rows.Next() {
|
||||
var i Channel
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.EventTypes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listChannelsForEvent = `-- name: ListChannelsForEvent :many
|
||||
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels
|
||||
WHERE team_id = $1
|
||||
AND $2::text = ANY(event_types)
|
||||
ORDER BY created_at
|
||||
`
|
||||
|
||||
type ListChannelsForEventParams struct {
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
EventType string `json:"event_type"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListChannelsForEvent(ctx context.Context, arg ListChannelsForEventParams) ([]Channel, error) {
|
||||
rows, err := q.db.Query(ctx, listChannelsForEvent, arg.TeamID, arg.EventType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Channel
|
||||
for rows.Next() {
|
||||
var i Channel
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.EventTypes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateChannel = `-- name: UpdateChannel :one
|
||||
UPDATE channels SET name = $3, event_types = $4, updated_at = NOW()
|
||||
WHERE id = $1 AND team_id = $2
|
||||
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateChannelParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
EventTypes []string `json:"event_types"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateChannel(ctx context.Context, arg UpdateChannelParams) (Channel, error) {
|
||||
row := q.db.QueryRow(ctx, updateChannel,
|
||||
arg.ID,
|
||||
arg.TeamID,
|
||||
arg.Name,
|
||||
arg.EventTypes,
|
||||
)
|
||||
var i Channel
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.EventTypes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChannelConfig = `-- name: UpdateChannelConfig :one
|
||||
UPDATE channels SET config = $3, updated_at = NOW()
|
||||
WHERE id = $1 AND team_id = $2
|
||||
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateChannelConfigParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Config []byte `json:"config"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateChannelConfig(ctx context.Context, arg UpdateChannelConfigParams) (Channel, error) {
|
||||
row := q.db.QueryRow(ctx, updateChannelConfig, arg.ID, arg.TeamID, arg.Config)
|
||||
var i Channel
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.EventTypes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
92
internal/db/host_refresh_tokens.sql.go
Normal file
92
internal/db/host_refresh_tokens.sql.go
Normal file
@ -0,0 +1,92 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: host_refresh_tokens.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteExpiredHostRefreshTokens = `-- name: DeleteExpiredHostRefreshTokens :exec
|
||||
DELETE FROM host_refresh_tokens
|
||||
WHERE expires_at < NOW() OR revoked_at IS NOT NULL
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteExpiredHostRefreshTokens(ctx context.Context) error {
|
||||
_, err := q.db.Exec(ctx, deleteExpiredHostRefreshTokens)
|
||||
return err
|
||||
}
|
||||
|
||||
const getHostRefreshTokenByHash = `-- name: GetHostRefreshTokenByHash :one
|
||||
SELECT id, host_id, token_hash, expires_at, created_at, revoked_at FROM host_refresh_tokens
|
||||
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW()
|
||||
`
|
||||
|
||||
func (q *Queries) GetHostRefreshTokenByHash(ctx context.Context, tokenHash string) (HostRefreshToken, error) {
|
||||
row := q.db.QueryRow(ctx, getHostRefreshTokenByHash, tokenHash)
|
||||
var i HostRefreshToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.TokenHash,
|
||||
&i.ExpiresAt,
|
||||
&i.CreatedAt,
|
||||
&i.RevokedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertHostRefreshToken = `-- name: InsertHostRefreshToken :one
|
||||
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
|
||||
`
|
||||
|
||||
type InsertHostRefreshTokenParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
TokenHash string `json:"token_hash"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertHostRefreshToken(ctx context.Context, arg InsertHostRefreshTokenParams) (HostRefreshToken, error) {
|
||||
row := q.db.QueryRow(ctx, insertHostRefreshToken,
|
||||
arg.ID,
|
||||
arg.HostID,
|
||||
arg.TokenHash,
|
||||
arg.ExpiresAt,
|
||||
)
|
||||
var i HostRefreshToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.TokenHash,
|
||||
&i.ExpiresAt,
|
||||
&i.CreatedAt,
|
||||
&i.RevokedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
|
||||
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, revokeHostRefreshToken, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const revokeHostRefreshTokensByHost = `-- name: RevokeHostRefreshTokensByHost :exec
|
||||
UPDATE host_refresh_tokens SET revoked_at = NOW()
|
||||
WHERE host_id = $1 AND revoked_at IS NULL
|
||||
`
|
||||
|
||||
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID)
|
||||
return err
|
||||
}
|
||||
@ -16,8 +16,8 @@ INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING
|
||||
`
|
||||
|
||||
type AddHostTagParams struct {
|
||||
HostID string `json:"host_id"`
|
||||
Tag string `json:"tag"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error {
|
||||
@ -29,16 +29,16 @@ const deleteHost = `-- name: DeleteHost :exec
|
||||
DELETE FROM hosts WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteHost(ctx context.Context, id string) error {
|
||||
func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteHost, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const getHost = `-- name: GetHost :one
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
|
||||
func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
|
||||
row := q.db.QueryRow(ctx, getHost, id)
|
||||
var i Host
|
||||
err := row.Scan(
|
||||
@ -59,18 +59,18 @@ func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
&i.CertExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getHostByTeam = `-- name: GetHostByTeam :one
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetHostByTeamParams struct {
|
||||
ID string `json:"id"`
|
||||
TeamID pgtype.Text `json:"team_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
|
||||
@ -94,7 +94,7 @@ func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (H
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
&i.CertExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -103,7 +103,7 @@ const getHostTags = `-- name: GetHostTags :many
|
||||
SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag
|
||||
`
|
||||
|
||||
func (q *Queries) GetHostTags(ctx context.Context, hostID string) ([]string, error) {
|
||||
func (q *Queries) GetHostTags(ctx context.Context, hostID pgtype.UUID) ([]string, error) {
|
||||
rows, err := q.db.Query(ctx, getHostTags, hostID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -127,7 +127,7 @@ const getHostTokensByHost = `-- name: GetHostTokensByHost :many
|
||||
SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]HostToken, error) {
|
||||
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ([]HostToken, error) {
|
||||
rows, err := q.db.Query(ctx, getHostTokensByHost, hostID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -157,16 +157,16 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]Hos
|
||||
const insertHost = `-- name: InsertHost :one
|
||||
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled
|
||||
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at
|
||||
`
|
||||
|
||||
type InsertHostParams struct {
|
||||
ID string `json:"id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID pgtype.Text `json:"team_id"`
|
||||
Provider pgtype.Text `json:"provider"`
|
||||
AvailabilityZone pgtype.Text `json:"availability_zone"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Provider string `json:"provider"`
|
||||
AvailabilityZone string `json:"availability_zone"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) {
|
||||
@ -197,7 +197,7 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
&i.CertExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -209,9 +209,9 @@ RETURNING id, host_id, created_by, created_at, expires_at, used_at
|
||||
`
|
||||
|
||||
type InsertHostTokenParams struct {
|
||||
ID string `json:"id"`
|
||||
HostID string `json:"host_id"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
}
|
||||
|
||||
@ -234,8 +234,52 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listActiveHosts = `-- name: ListActiveHosts :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
|
||||
`
|
||||
|
||||
// Returns all hosts that have completed registration (not pending/offline).
|
||||
func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
|
||||
rows, err := q.db.Query(ctx, listActiveHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Host
|
||||
for rows.Next() {
|
||||
var i Host
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listHosts = `-- name: ListHosts :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
@ -265,7 +309,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -278,7 +322,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
}
|
||||
|
||||
const listHostsByStatus = `-- name: ListHostsByStatus :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
|
||||
@ -308,7 +352,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -321,7 +365,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
|
||||
}
|
||||
|
||||
const listHostsByTag = `-- name: ListHostsByTag :many
|
||||
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h
|
||||
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h
|
||||
JOIN host_tags ht ON ht.host_id = h.id
|
||||
WHERE ht.tag = $1
|
||||
ORDER BY h.created_at DESC
|
||||
@ -354,7 +398,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -367,10 +411,10 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
|
||||
}
|
||||
|
||||
const listHostsByTeam = `-- name: ListHostsByTeam :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) {
|
||||
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
|
||||
rows, err := q.db.Query(ctx, listHostsByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -397,7 +441,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -410,7 +454,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
|
||||
}
|
||||
|
||||
const listHostsByType = `-- name: ListHostsByType :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
|
||||
@ -440,7 +484,7 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -456,31 +500,44 @@ const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec
|
||||
UPDATE host_tokens SET used_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
|
||||
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, markHostTokenUsed, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const markHostUnreachable = `-- name: MarkHostUnreachable :exec
|
||||
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, markHostUnreachable, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const registerHost = `-- name: RegisterHost :execrows
|
||||
UPDATE hosts
|
||||
SET arch = $2,
|
||||
cpu_cores = $3,
|
||||
memory_mb = $4,
|
||||
disk_gb = $5,
|
||||
address = $6,
|
||||
status = 'online',
|
||||
SET arch = $2,
|
||||
cpu_cores = $3,
|
||||
memory_mb = $4,
|
||||
disk_gb = $5,
|
||||
address = $6,
|
||||
cert_fingerprint = $7,
|
||||
cert_expires_at = $8,
|
||||
status = 'online',
|
||||
last_heartbeat_at = NOW(),
|
||||
updated_at = NOW()
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND status = 'pending'
|
||||
`
|
||||
|
||||
type RegisterHostParams struct {
|
||||
ID string `json:"id"`
|
||||
Arch pgtype.Text `json:"arch"`
|
||||
CpuCores pgtype.Int4 `json:"cpu_cores"`
|
||||
MemoryMb pgtype.Int4 `json:"memory_mb"`
|
||||
DiskGb pgtype.Int4 `json:"disk_gb"`
|
||||
Address pgtype.Text `json:"address"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Arch string `json:"arch"`
|
||||
CpuCores int32 `json:"cpu_cores"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
DiskGb int32 `json:"disk_gb"`
|
||||
Address string `json:"address"`
|
||||
CertFingerprint string `json:"cert_fingerprint"`
|
||||
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
|
||||
@ -491,6 +548,8 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int
|
||||
arg.MemoryMb,
|
||||
arg.DiskGb,
|
||||
arg.Address,
|
||||
arg.CertFingerprint,
|
||||
arg.CertExpiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@ -503,8 +562,8 @@ DELETE FROM host_tags WHERE host_id = $1 AND tag = $2
|
||||
`
|
||||
|
||||
type RemoveHostTagParams struct {
|
||||
HostID string `json:"host_id"`
|
||||
Tag string `json:"tag"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error {
|
||||
@ -512,22 +571,59 @@ func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) er
|
||||
return err
|
||||
}
|
||||
|
||||
const updateHostCert = `-- name: UpdateHostCert :exec
|
||||
UPDATE hosts
|
||||
SET cert_fingerprint = $2,
|
||||
cert_expires_at = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateHostCertParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
CertFingerprint string `json:"cert_fingerprint"`
|
||||
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error {
|
||||
_, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
|
||||
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error {
|
||||
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, updateHostHeartbeat, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :execrows
|
||||
UPDATE hosts
|
||||
SET last_heartbeat_at = NOW(),
|
||||
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
// Updates last_heartbeat_at and transitions unreachable hosts back to online.
|
||||
// Returns 0 if no host was found (deleted), which the caller treats as 404.
|
||||
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id pgtype.UUID) (int64, error) {
|
||||
result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
|
||||
const updateHostStatus = `-- name: UpdateHostStatus :exec
|
||||
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateHostStatusParams struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error {
|
||||
|
||||
250
internal/db/metrics.sql.go
Normal file
250
internal/db/metrics.sql.go
Normal file
@ -0,0 +1,250 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: metrics.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteSandboxMetricPoints = `-- name: DeleteSandboxMetricPoints :exec
|
||||
DELETE FROM sandbox_metric_points
|
||||
WHERE sandbox_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteSandboxMetricPoints, sandboxID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteSandboxMetricPointsByTier = `-- name: DeleteSandboxMetricPointsByTier :exec
|
||||
DELETE FROM sandbox_metric_points
|
||||
WHERE sandbox_id = $1 AND tier = $2
|
||||
`
|
||||
|
||||
type DeleteSandboxMetricPointsByTierParams struct {
|
||||
SandboxID pgtype.UUID `json:"sandbox_id"`
|
||||
Tier string `json:"tier"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteSandboxMetricPointsByTier(ctx context.Context, arg DeleteSandboxMetricPointsByTierParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteSandboxMetricPointsByTier, arg.SandboxID, arg.Tier)
|
||||
return err
|
||||
}
|
||||
|
||||
const getLiveMetrics = `-- name: GetLiveMetrics :one
|
||||
SELECT
|
||||
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
|
||||
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
|
||||
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
|
||||
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
|
||||
FROM sandboxes
|
||||
WHERE team_id = $1
|
||||
`
|
||||
|
||||
type GetLiveMetricsRow struct {
|
||||
RunningCount int32 `json:"running_count"`
|
||||
VcpusReserved int32 `json:"vcpus_reserved"`
|
||||
MemoryMbReserved int32 `json:"memory_mb_reserved"`
|
||||
}
|
||||
|
||||
// Reads directly from sandboxes for accurate real-time current values.
|
||||
// CPU reserved = running + starting only (paused VMs release CPU).
|
||||
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
|
||||
func (q *Queries) GetLiveMetrics(ctx context.Context, teamID pgtype.UUID) (GetLiveMetricsRow, error) {
|
||||
row := q.db.QueryRow(ctx, getLiveMetrics, teamID)
|
||||
var i GetLiveMetricsRow
|
||||
err := row.Scan(&i.RunningCount, &i.VcpusReserved, &i.MemoryMbReserved)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getPeakMetrics = `-- name: GetPeakMetrics :one
|
||||
SELECT
|
||||
COALESCE(MAX(running_count), 0)::INTEGER AS peak_running_count,
|
||||
COALESCE(MAX(vcpus_reserved), 0)::INTEGER AS peak_vcpus,
|
||||
COALESCE(MAX(memory_mb_reserved), 0)::INTEGER AS peak_memory_mb
|
||||
FROM sandbox_metrics_snapshots
|
||||
WHERE team_id = $1
|
||||
AND sampled_at > NOW() - INTERVAL '30 days'
|
||||
`
|
||||
|
||||
type GetPeakMetricsRow struct {
|
||||
PeakRunningCount int32 `json:"peak_running_count"`
|
||||
PeakVcpus int32 `json:"peak_vcpus"`
|
||||
PeakMemoryMb int32 `json:"peak_memory_mb"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetPeakMetrics(ctx context.Context, teamID pgtype.UUID) (GetPeakMetricsRow, error) {
|
||||
row := q.db.QueryRow(ctx, getPeakMetrics, teamID)
|
||||
var i GetPeakMetricsRow
|
||||
err := row.Scan(&i.PeakRunningCount, &i.PeakVcpus, &i.PeakMemoryMb)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSandboxMetricPoints = `-- name: GetSandboxMetricPoints :many
|
||||
SELECT ts, cpu_pct, mem_bytes, disk_bytes
|
||||
FROM sandbox_metric_points
|
||||
WHERE sandbox_id = $1 AND tier = $2 AND ts >= $3
|
||||
ORDER BY ts ASC
|
||||
`
|
||||
|
||||
type GetSandboxMetricPointsParams struct {
|
||||
SandboxID pgtype.UUID `json:"sandbox_id"`
|
||||
Tier string `json:"tier"`
|
||||
Ts int64 `json:"ts"`
|
||||
}
|
||||
|
||||
type GetSandboxMetricPointsRow struct {
|
||||
Ts int64 `json:"ts"`
|
||||
CpuPct float64 `json:"cpu_pct"`
|
||||
MemBytes int64 `json:"mem_bytes"`
|
||||
DiskBytes int64 `json:"disk_bytes"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSandboxMetricPoints(ctx context.Context, arg GetSandboxMetricPointsParams) ([]GetSandboxMetricPointsRow, error) {
|
||||
rows, err := q.db.Query(ctx, getSandboxMetricPoints, arg.SandboxID, arg.Tier, arg.Ts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetSandboxMetricPointsRow
|
||||
for rows.Next() {
|
||||
var i GetSandboxMetricPointsRow
|
||||
if err := rows.Scan(
|
||||
&i.Ts,
|
||||
&i.CpuPct,
|
||||
&i.MemBytes,
|
||||
&i.DiskBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertMetricsSnapshot = `-- name: InsertMetricsSnapshot :exec
|
||||
INSERT INTO sandbox_metrics_snapshots (team_id, running_count, vcpus_reserved, memory_mb_reserved)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
|
||||
type InsertMetricsSnapshotParams struct {
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
RunningCount int32 `json:"running_count"`
|
||||
VcpusReserved int32 `json:"vcpus_reserved"`
|
||||
MemoryMbReserved int32 `json:"memory_mb_reserved"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertMetricsSnapshot(ctx context.Context, arg InsertMetricsSnapshotParams) error {
|
||||
_, err := q.db.Exec(ctx, insertMetricsSnapshot,
|
||||
arg.TeamID,
|
||||
arg.RunningCount,
|
||||
arg.VcpusReserved,
|
||||
arg.MemoryMbReserved,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const insertSandboxMetricPoint = `-- name: InsertSandboxMetricPoint :exec
|
||||
INSERT INTO sandbox_metric_points (sandbox_id, tier, ts, cpu_pct, mem_bytes, disk_bytes)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (sandbox_id, tier, ts) DO NOTHING
|
||||
`
|
||||
|
||||
type InsertSandboxMetricPointParams struct {
|
||||
SandboxID pgtype.UUID `json:"sandbox_id"`
|
||||
Tier string `json:"tier"`
|
||||
Ts int64 `json:"ts"`
|
||||
CpuPct float64 `json:"cpu_pct"`
|
||||
MemBytes int64 `json:"mem_bytes"`
|
||||
DiskBytes int64 `json:"disk_bytes"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertSandboxMetricPoint(ctx context.Context, arg InsertSandboxMetricPointParams) error {
|
||||
_, err := q.db.Exec(ctx, insertSandboxMetricPoint,
|
||||
arg.SandboxID,
|
||||
arg.Tier,
|
||||
arg.Ts,
|
||||
arg.CpuPct,
|
||||
arg.MemBytes,
|
||||
arg.DiskBytes,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const pruneOldMetrics = `-- name: PruneOldMetrics :exec
|
||||
DELETE FROM sandbox_metrics_snapshots
|
||||
WHERE sampled_at < NOW() - INTERVAL '60 days'
|
||||
`
|
||||
|
||||
func (q *Queries) PruneOldMetrics(ctx context.Context) error {
|
||||
_, err := q.db.Exec(ctx, pruneOldMetrics)
|
||||
return err
|
||||
}
|
||||
|
||||
const pruneSandboxMetricPoints = `-- name: PruneSandboxMetricPoints :exec
|
||||
DELETE FROM sandbox_metric_points
|
||||
WHERE ts < EXTRACT(EPOCH FROM NOW() - INTERVAL '30 days')::BIGINT
|
||||
`
|
||||
|
||||
// Remove metric points older than 30 days for destroyed sandboxes.
|
||||
func (q *Queries) PruneSandboxMetricPoints(ctx context.Context) error {
|
||||
_, err := q.db.Exec(ctx, pruneSandboxMetricPoints)
|
||||
return err
|
||||
}
|
||||
|
||||
const sampleSandboxMetrics = `-- name: SampleSandboxMetrics :many
|
||||
SELECT
|
||||
team_id,
|
||||
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
|
||||
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
|
||||
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
|
||||
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
|
||||
FROM sandboxes
|
||||
GROUP BY team_id
|
||||
`
|
||||
|
||||
type SampleSandboxMetricsRow struct {
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
RunningCount int32 `json:"running_count"`
|
||||
VcpusReserved int32 `json:"vcpus_reserved"`
|
||||
MemoryMbReserved int32 `json:"memory_mb_reserved"`
|
||||
}
|
||||
|
||||
// Aggregates per-team resource usage from the live sandboxes table.
|
||||
// Groups by all teams that have any sandbox row (including stopped) so that
|
||||
// zero-value snapshots are recorded when all capsules are stopped, keeping the
|
||||
// time-series charts continuous rather than trailing off into empty space.
|
||||
// CPU reserved = running + starting only (paused VMs release CPU).
|
||||
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
|
||||
func (q *Queries) SampleSandboxMetrics(ctx context.Context) ([]SampleSandboxMetricsRow, error) {
|
||||
rows, err := q.db.Query(ctx, sampleSandboxMetrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []SampleSandboxMetricsRow
|
||||
for rows.Next() {
|
||||
var i SampleSandboxMetricsRow
|
||||
if err := rows.Scan(
|
||||
&i.TeamID,
|
||||
&i.RunningCount,
|
||||
&i.VcpusReserved,
|
||||
&i.MemoryMbReserved,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
@ -9,42 +9,77 @@ import (
|
||||
)
|
||||
|
||||
type AdminPermission struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
}
|
||||
|
||||
type AuditLog struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
ActorType string `json:"actor_type"`
|
||||
ActorID pgtype.Text `json:"actor_id"`
|
||||
ActorName string `json:"actor_name"`
|
||||
ResourceType string `json:"resource_type"`
|
||||
ResourceID pgtype.Text `json:"resource_id"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
Status string `json:"status"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
}
|
||||
|
||||
type Channel struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config []byte `json:"config"`
|
||||
EventTypes []string `json:"event_types"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Host struct {
|
||||
ID string `json:"id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID pgtype.Text `json:"team_id"`
|
||||
Provider pgtype.Text `json:"provider"`
|
||||
AvailabilityZone pgtype.Text `json:"availability_zone"`
|
||||
Arch pgtype.Text `json:"arch"`
|
||||
CpuCores pgtype.Int4 `json:"cpu_cores"`
|
||||
MemoryMb pgtype.Int4 `json:"memory_mb"`
|
||||
DiskGb pgtype.Int4 `json:"disk_gb"`
|
||||
Address pgtype.Text `json:"address"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Provider string `json:"provider"`
|
||||
AvailabilityZone string `json:"availability_zone"`
|
||||
Arch string `json:"arch"`
|
||||
CpuCores int32 `json:"cpu_cores"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
DiskGb int32 `json:"disk_gb"`
|
||||
Address string `json:"address"`
|
||||
Status string `json:"status"`
|
||||
LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
CertFingerprint pgtype.Text `json:"cert_fingerprint"`
|
||||
MtlsEnabled bool `json:"mtls_enabled"`
|
||||
CertFingerprint string `json:"cert_fingerprint"`
|
||||
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
|
||||
}
|
||||
|
||||
type HostRefreshToken struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
TokenHash string `json:"token_hash"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
RevokedAt pgtype.Timestamptz `json:"revoked_at"`
|
||||
}
|
||||
|
||||
type HostTag struct {
|
||||
HostID string `json:"host_id"`
|
||||
Tag string `json:"tag"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
type HostToken struct {
|
||||
ID string `json:"id"`
|
||||
HostID string `json:"host_id"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
UsedAt pgtype.Timestamptz `json:"used_at"`
|
||||
@ -53,42 +88,65 @@ type HostToken struct {
|
||||
type OauthProvider struct {
|
||||
Provider string `json:"provider"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
UserID string `json:"user_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
}
|
||||
|
||||
type Sandbox struct {
|
||||
ID string `json:"id"`
|
||||
HostID string `json:"host_id"`
|
||||
Template string `json:"template"`
|
||||
Status string `json:"status"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
GuestIp string `json:"guest_ip"`
|
||||
HostIp string `json:"host_ip"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
StartedAt pgtype.Timestamptz `json:"started_at"`
|
||||
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
|
||||
LastUpdated pgtype.Timestamptz `json:"last_updated"`
|
||||
TeamID string `json:"team_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
Template string `json:"template"`
|
||||
Status string `json:"status"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
DiskSizeMb int32 `json:"disk_size_mb"`
|
||||
GuestIp string `json:"guest_ip"`
|
||||
HostIp string `json:"host_ip"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
StartedAt pgtype.Timestamptz `json:"started_at"`
|
||||
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
|
||||
LastUpdated pgtype.Timestamptz `json:"last_updated"`
|
||||
TemplateID pgtype.UUID `json:"template_id"`
|
||||
TemplateTeamID pgtype.UUID `json:"template_team_id"`
|
||||
}
|
||||
|
||||
type SandboxMetricPoint struct {
|
||||
SandboxID pgtype.UUID `json:"sandbox_id"`
|
||||
Tier string `json:"tier"`
|
||||
Ts int64 `json:"ts"`
|
||||
CpuPct float64 `json:"cpu_pct"`
|
||||
MemBytes int64 `json:"mem_bytes"`
|
||||
DiskBytes int64 `json:"disk_bytes"`
|
||||
}
|
||||
|
||||
type SandboxMetricsSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
SampledAt pgtype.Timestamptz `json:"sampled_at"`
|
||||
RunningCount int32 `json:"running_count"`
|
||||
VcpusReserved int32 `json:"vcpus_reserved"`
|
||||
MemoryMbReserved int32 `json:"memory_mb_reserved"`
|
||||
}
|
||||
|
||||
type Team struct {
|
||||
ID string `json:"id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
Slug string `json:"slug"`
|
||||
IsByoc bool `json:"is_byoc"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
|
||||
}
|
||||
|
||||
type TeamApiKey struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
KeyHash string `json:"key_hash"`
|
||||
KeyPrefix string `json:"key_prefix"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
LastUsed pgtype.Timestamptz `json:"last_used"`
|
||||
}
|
||||
@ -96,25 +154,50 @@ type TeamApiKey struct {
|
||||
type Template struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Vcpus pgtype.Int4 `json:"vcpus"`
|
||||
MemoryMb pgtype.Int4 `json:"memory_mb"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
TeamID string `json:"team_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
}
|
||||
|
||||
type TemplateBuild struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseTemplate string `json:"base_template"`
|
||||
Recipe []byte `json:"recipe"`
|
||||
Healthcheck string `json:"healthcheck"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
Status string `json:"status"`
|
||||
CurrentStep int32 `json:"current_step"`
|
||||
TotalSteps int32 `json:"total_steps"`
|
||||
Logs []byte `json:"logs"`
|
||||
Error string `json:"error"`
|
||||
SandboxID pgtype.UUID `json:"sandbox_id"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
StartedAt pgtype.Timestamptz `json:"started_at"`
|
||||
CompletedAt pgtype.Timestamptz `json:"completed_at"`
|
||||
TemplateID pgtype.UUID `json:"template_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
SkipPrePost bool `json:"skip_pre_post"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash pgtype.Text `json:"password_hash"`
|
||||
Name string `json:"name"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
}
|
||||
|
||||
type UsersTeam struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
Role string `json:"role"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
|
||||
@ -7,6 +7,8 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const getOAuthProvider = `-- name: GetOAuthProvider :one
|
||||
@ -38,10 +40,10 @@ VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
|
||||
type InsertOAuthProviderParams struct {
|
||||
Provider string `json:"provider"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Provider string `json:"provider"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error {
|
||||
|
||||
@ -11,16 +11,30 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
|
||||
UPDATE sandboxes
|
||||
SET status = 'running',
|
||||
last_updated = NOW()
|
||||
WHERE id = ANY($1::uuid[]) AND status = 'missing'
|
||||
`
|
||||
|
||||
// Called by the reconciler when a host comes back online and its sandboxes are
|
||||
// confirmed alive. Restores only sandboxes that are in 'missing' state.
|
||||
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
|
||||
return err
|
||||
}
|
||||
|
||||
const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
|
||||
UPDATE sandboxes
|
||||
SET status = $2,
|
||||
last_updated = NOW()
|
||||
WHERE id = ANY($1::text[])
|
||||
WHERE id = ANY($1::uuid[])
|
||||
`
|
||||
|
||||
type BulkUpdateStatusByIDsParams struct {
|
||||
Column1 []string `json:"column_1"`
|
||||
Status string `json:"status"`
|
||||
Column1 []pgtype.UUID `json:"column_1"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error {
|
||||
@ -29,38 +43,41 @@ func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatu
|
||||
}
|
||||
|
||||
const getSandbox = `-- name: GetSandbox :one
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1
|
||||
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) {
|
||||
func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) {
|
||||
row := q.db.QueryRow(ctx, getSandbox, id)
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSandboxByTeam = `-- name: GetSandboxByTeam :one
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 AND team_id = $2
|
||||
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetSandboxByTeamParams struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) {
|
||||
@ -68,38 +85,70 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one
|
||||
SELECT s.status, h.address AS host_address
|
||||
FROM sandboxes s
|
||||
JOIN hosts h ON h.id = s.host_id
|
||||
WHERE s.id = $1 AND s.team_id = $2
|
||||
`
|
||||
|
||||
type GetSandboxProxyTargetParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
type GetSandboxProxyTargetRow struct {
|
||||
Status string `json:"status"`
|
||||
HostAddress string `json:"host_address"`
|
||||
}
|
||||
|
||||
// Returns the sandbox status and its host's address in one query.
|
||||
// Used by SandboxProxyWrapper to avoid two round-trips.
|
||||
func (q *Queries) GetSandboxProxyTarget(ctx context.Context, arg GetSandboxProxyTargetParams) (GetSandboxProxyTargetRow, error) {
|
||||
row := q.db.QueryRow(ctx, getSandboxProxyTarget, arg.ID, arg.TeamID)
|
||||
var i GetSandboxProxyTargetRow
|
||||
err := row.Scan(&i.Status, &i.HostAddress)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertSandbox = `-- name: InsertSandbox :one
|
||||
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
|
||||
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
|
||||
`
|
||||
|
||||
type InsertSandboxParams struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
HostID string `json:"host_id"`
|
||||
Template string `json:"template"`
|
||||
Status string `json:"status"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
Template string `json:"template"`
|
||||
Status string `json:"status"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
DiskSizeMb int32 `json:"disk_size_mb"`
|
||||
TemplateID pgtype.UUID `json:"template_id"`
|
||||
TemplateTeamID pgtype.UUID `json:"template_team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) {
|
||||
@ -112,29 +161,79 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S
|
||||
arg.Vcpus,
|
||||
arg.MemoryMb,
|
||||
arg.TimeoutSec,
|
||||
arg.DiskSizeMb,
|
||||
arg.TemplateID,
|
||||
arg.TemplateTeamID,
|
||||
)
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many
|
||||
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
|
||||
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
|
||||
rows, err := q.db.Query(ctx, listActiveSandboxesByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Sandbox
|
||||
for rows.Next() {
|
||||
var i Sandbox
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listSandboxes = `-- name: ListSandboxes :many
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes ORDER BY created_at DESC
|
||||
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
|
||||
@ -148,19 +247,22 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
|
||||
var i Sandbox
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -173,14 +275,14 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
|
||||
}
|
||||
|
||||
const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
|
||||
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
|
||||
WHERE host_id = $1 AND status = ANY($2::text[])
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
type ListSandboxesByHostAndStatusParams struct {
|
||||
HostID string `json:"host_id"`
|
||||
Column2 []string `json:"column_2"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
Column2 []string `json:"column_2"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) {
|
||||
@ -194,19 +296,22 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand
|
||||
var i Sandbox
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -219,12 +324,12 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand
|
||||
}
|
||||
|
||||
const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
|
||||
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
|
||||
WHERE team_id = $1 AND status NOT IN ('stopped', 'error')
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) {
|
||||
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
|
||||
rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -235,19 +340,22 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San
|
||||
var i Sandbox
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -259,6 +367,21 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec
UPDATE sandboxes
SET status = 'missing',
    last_updated = NOW()
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
`

// Called when the host monitor marks a host unreachable.
// Marks running/starting/pending sandboxes on that host as 'missing' so users see
// the sandbox is not currently reachable, without permanently losing the record.
// Sandboxes already in terminal states (stopped/error) are left untouched by the
// status filter in the query. The affected-row count is intentionally discarded.
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID pgtype.UUID) error {
	_, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID)
	return err
}
|
||||
|
||||
const updateLastActive = `-- name: UpdateLastActive :exec
|
||||
UPDATE sandboxes
|
||||
SET last_active_at = $2,
|
||||
@ -267,7 +390,7 @@ WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateLastActiveParams struct {
|
||||
ID string `json:"id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
|
||||
}
|
||||
|
||||
@ -285,11 +408,11 @@ SET status = 'running',
|
||||
last_active_at = $4,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
|
||||
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
|
||||
`
|
||||
|
||||
type UpdateSandboxRunningParams struct {
|
||||
ID string `json:"id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
HostIp string `json:"host_ip"`
|
||||
GuestIp string `json:"guest_ip"`
|
||||
StartedAt pgtype.Timestamptz `json:"started_at"`
|
||||
@ -305,19 +428,22 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -327,12 +453,12 @@ UPDATE sandboxes
|
||||
SET status = $2,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
|
||||
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
|
||||
`
|
||||
|
||||
type UpdateSandboxStatusParams struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) {
|
||||
@ -340,19 +466,22 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@ -7,10 +7,26 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteTeamMember = `-- name: DeleteTeamMember :exec
DELETE FROM users_teams WHERE team_id = $1 AND user_id = $2
`

// DeleteTeamMemberParams is the composite key of the users_teams row to remove.
type DeleteTeamMemberParams struct {
	TeamID pgtype.UUID `json:"team_id"`
	UserID pgtype.UUID `json:"user_id"`
}

// DeleteTeamMember removes a user's membership in a team.
// Deleting a membership that does not exist is not an error: the Exec
// result (rows affected) is discarded.
func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberParams) error {
	_, err := q.db.Exec(ctx, deleteTeamMember, arg.TeamID, arg.UserID)
	return err
}
|
||||
|
||||
const getBYOCTeams = `-- name: GetBYOCTeams :many
|
||||
SELECT id, name, created_at, is_byoc FROM teams WHERE is_byoc = TRUE ORDER BY created_at
|
||||
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at
|
||||
`
|
||||
|
||||
func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
|
||||
@ -25,8 +41,10 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
&i.Slug,
|
||||
&i.IsByoc,
|
||||
&i.CreatedAt,
|
||||
&i.DeletedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -39,47 +57,111 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
|
||||
}
|
||||
|
||||
const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one
|
||||
SELECT t.id, t.name, t.created_at, t.is_byoc FROM teams t
|
||||
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at FROM teams t
|
||||
JOIN users_teams ut ON ut.team_id = t.id
|
||||
WHERE ut.user_id = $1 AND ut.is_default = TRUE
|
||||
WHERE ut.user_id = $1 AND ut.is_default = TRUE AND t.deleted_at IS NULL
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID string) (Team, error) {
|
||||
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID) (Team, error) {
|
||||
row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID)
|
||||
var i Team
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
&i.Slug,
|
||||
&i.IsByoc,
|
||||
&i.CreatedAt,
|
||||
&i.DeletedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTeam = `-- name: GetTeam :one
|
||||
SELECT id, name, created_at, is_byoc FROM teams WHERE id = $1
|
||||
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) {
|
||||
func (q *Queries) GetTeam(ctx context.Context, id pgtype.UUID) (Team, error) {
|
||||
row := q.db.QueryRow(ctx, getTeam, id)
|
||||
var i Team
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
&i.Slug,
|
||||
&i.IsByoc,
|
||||
&i.CreatedAt,
|
||||
&i.DeletedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTeamBySlug = `-- name: GetTeamBySlug :one
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL
`

// GetTeamBySlug looks up a live (non-soft-deleted) team by its slug.
// If no matching row exists, Scan returns the driver's no-rows error.
func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) {
	row := q.db.QueryRow(ctx, getTeamBySlug, slug)
	var i Team
	// Scan order mirrors the SELECT column list.
	err := row.Scan(
		&i.ID,
		&i.Name,
		&i.Slug,
		&i.IsByoc,
		&i.CreatedAt,
		&i.DeletedAt,
	)
	return i, err
}
|
||||
|
||||
const getTeamMembers = `-- name: GetTeamMembers :many
SELECT u.id, u.name, u.email, ut.role, ut.created_at AS joined_at
FROM users_teams ut
JOIN users u ON u.id = ut.user_id
WHERE ut.team_id = $1
ORDER BY ut.created_at
`

// GetTeamMembersRow is one member of a team: the user's identity plus the
// membership's role and join time (users_teams.created_at aliased as joined_at).
type GetTeamMembersRow struct {
	ID       pgtype.UUID        `json:"id"`
	Name     string             `json:"name"`
	Email    string             `json:"email"`
	Role     string             `json:"role"`
	JoinedAt pgtype.Timestamptz `json:"joined_at"`
}

// GetTeamMembers lists every member of the given team, ordered by when
// they joined (oldest membership first). A team with no members yields
// a nil slice and a nil error.
func (q *Queries) GetTeamMembers(ctx context.Context, teamID pgtype.UUID) ([]GetTeamMembersRow, error) {
	rows, err := q.db.Query(ctx, getTeamMembers, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []GetTeamMembersRow
	for rows.Next() {
		var i GetTeamMembersRow
		if err := rows.Scan(
			&i.ID,
			&i.Name,
			&i.Email,
			&i.Role,
			&i.JoinedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// rows.Err surfaces any error hit during iteration (e.g. a dropped
	// connection) that rows.Next silently turned into loop termination.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
const getTeamMembership = `-- name: GetTeamMembership :one
|
||||
SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE user_id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetTeamMembershipParams struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
|
||||
@ -95,25 +177,74 @@ func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipPa
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTeamsForUser = `-- name: GetTeamsForUser :many
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at, ut.role
FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND t.deleted_at IS NULL
ORDER BY ut.created_at
`

// GetTeamsForUserRow is a team the user belongs to, joined with the user's
// role in that team from users_teams.
type GetTeamsForUserRow struct {
	ID        pgtype.UUID        `json:"id"`
	Name      string             `json:"name"`
	Slug      string             `json:"slug"`
	IsByoc    bool               `json:"is_byoc"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	DeletedAt pgtype.Timestamptz `json:"deleted_at"`
	Role      string             `json:"role"`
}

// GetTeamsForUser lists all live (non-soft-deleted) teams the user is a
// member of, ordered by when the user joined each team. Returns a nil
// slice (and nil error) when the user has no memberships.
func (q *Queries) GetTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]GetTeamsForUserRow, error) {
	rows, err := q.db.Query(ctx, getTeamsForUser, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []GetTeamsForUserRow
	for rows.Next() {
		var i GetTeamsForUserRow
		if err := rows.Scan(
			&i.ID,
			&i.Name,
			&i.Slug,
			&i.IsByoc,
			&i.CreatedAt,
			&i.DeletedAt,
			&i.Role,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Check for iteration errors that rows.Next converted into loop exit.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
const insertTeam = `-- name: InsertTeam :one
|
||||
INSERT INTO teams (id, name)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id, name, created_at, is_byoc
|
||||
INSERT INTO teams (id, name, slug)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, name, slug, is_byoc, created_at, deleted_at
|
||||
`
|
||||
|
||||
type InsertTeamParams struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) {
|
||||
row := q.db.QueryRow(ctx, insertTeam, arg.ID, arg.Name)
|
||||
row := q.db.QueryRow(ctx, insertTeam, arg.ID, arg.Name, arg.Slug)
|
||||
var i Team
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
&i.Slug,
|
||||
&i.IsByoc,
|
||||
&i.CreatedAt,
|
||||
&i.DeletedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -124,10 +255,10 @@ VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
|
||||
type InsertTeamMemberParams struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
Role string `json:"role"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error {
|
||||
@ -145,11 +276,49 @@ UPDATE teams SET is_byoc = $2 WHERE id = $1
|
||||
`
|
||||
|
||||
type SetTeamBYOCParams struct {
|
||||
ID string `json:"id"`
|
||||
IsByoc bool `json:"is_byoc"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
IsByoc bool `json:"is_byoc"`
|
||||
}
|
||||
|
||||
func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error {
|
||||
_, err := q.db.Exec(ctx, setTeamBYOC, arg.ID, arg.IsByoc)
|
||||
return err
|
||||
}
|
||||
|
||||
const softDeleteTeam = `-- name: SoftDeleteTeam :exec
UPDATE teams SET deleted_at = NOW() WHERE id = $1
`

// SoftDeleteTeam marks a team as deleted by stamping deleted_at; the row
// itself is retained. Queries that filter on "deleted_at IS NULL" will no
// longer return it. Soft-deleting an unknown id is a silent no-op.
func (q *Queries) SoftDeleteTeam(ctx context.Context, id pgtype.UUID) error {
	_, err := q.db.Exec(ctx, softDeleteTeam, id)
	return err
}
|
||||
|
||||
const updateMemberRole = `-- name: UpdateMemberRole :exec
UPDATE users_teams SET role = $3 WHERE team_id = $1 AND user_id = $2
`

// UpdateMemberRoleParams identifies a membership row (team_id + user_id)
// and the new role string to store on it.
type UpdateMemberRoleParams struct {
	TeamID pgtype.UUID `json:"team_id"`
	UserID pgtype.UUID `json:"user_id"`
	Role   string      `json:"role"`
}

// UpdateMemberRole changes the role on an existing team membership.
// Updating a membership that does not exist is a silent no-op (the
// rows-affected count is discarded).
func (q *Queries) UpdateMemberRole(ctx context.Context, arg UpdateMemberRoleParams) error {
	_, err := q.db.Exec(ctx, updateMemberRole, arg.TeamID, arg.UserID, arg.Role)
	return err
}
|
||||
|
||||
const updateTeamName = `-- name: UpdateTeamName :exec
UPDATE teams SET name = $2 WHERE id = $1 AND deleted_at IS NULL
`

// UpdateTeamNameParams selects the team by id and supplies its new name.
type UpdateTeamNameParams struct {
	ID   pgtype.UUID `json:"id"`
	Name string      `json:"name"`
}

// UpdateTeamName renames a team. Soft-deleted teams are excluded by the
// "deleted_at IS NULL" filter, so renaming one is a silent no-op.
func (q *Queries) UpdateTeamName(ctx context.Context, arg UpdateTeamNameParams) error {
	_, err := q.db.Exec(ctx, updateTeamName, arg.ID, arg.Name)
	return err
}
|
||||
|
||||
241
internal/db/template_builds.sql.go
Normal file
241
internal/db/template_builds.sql.go
Normal file
@ -0,0 +1,241 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: template_builds.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const getTemplateBuild = `-- name: GetTemplateBuild :one
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds WHERE id = $1
`

// GetTemplateBuild fetches a single template build record by id.
// If no build exists, Scan returns the driver's no-rows error.
func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) {
	row := q.db.QueryRow(ctx, getTemplateBuild, id)
	var i TemplateBuild
	// Scan order mirrors the SELECT column list exactly.
	err := row.Scan(
		&i.ID,
		&i.Name,
		&i.BaseTemplate,
		&i.Recipe,
		&i.Healthcheck,
		&i.Vcpus,
		&i.MemoryMb,
		&i.Status,
		&i.CurrentStep,
		&i.TotalSteps,
		&i.Logs,
		&i.Error,
		&i.SandboxID,
		&i.HostID,
		&i.CreatedAt,
		&i.StartedAt,
		&i.CompletedAt,
		&i.TemplateID,
		&i.TeamID,
		&i.SkipPrePost,
	)
	return i, err
}
|
||||
|
||||
const insertTemplateBuild = `-- name: InsertTemplateBuild :one
INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post)
VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11)
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
`

// InsertTemplateBuildParams carries the caller-supplied columns of a new
// build. Status is not a parameter: the query hard-codes 'pending'.
// Recipe is the raw (serialized) build recipe; sandbox/host assignment and
// progress fields are filled in later by the build lifecycle.
type InsertTemplateBuildParams struct {
	ID           pgtype.UUID `json:"id"`
	Name         string      `json:"name"`
	BaseTemplate string      `json:"base_template"`
	Recipe       []byte      `json:"recipe"`
	Healthcheck  string      `json:"healthcheck"`
	Vcpus        int32       `json:"vcpus"`
	MemoryMb     int32       `json:"memory_mb"`
	TotalSteps   int32       `json:"total_steps"`
	TemplateID   pgtype.UUID `json:"template_id"`
	TeamID       pgtype.UUID `json:"team_id"`
	SkipPrePost  bool        `json:"skip_pre_post"`
}

// InsertTemplateBuild creates a new build row in the 'pending' state and
// returns the full inserted record (including defaulted columns).
func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) {
	row := q.db.QueryRow(ctx, insertTemplateBuild,
		arg.ID,
		arg.Name,
		arg.BaseTemplate,
		arg.Recipe,
		arg.Healthcheck,
		arg.Vcpus,
		arg.MemoryMb,
		arg.TotalSteps,
		arg.TemplateID,
		arg.TeamID,
		arg.SkipPrePost,
	)
	var i TemplateBuild
	// Scan order mirrors the RETURNING column list exactly.
	err := row.Scan(
		&i.ID,
		&i.Name,
		&i.BaseTemplate,
		&i.Recipe,
		&i.Healthcheck,
		&i.Vcpus,
		&i.MemoryMb,
		&i.Status,
		&i.CurrentStep,
		&i.TotalSteps,
		&i.Logs,
		&i.Error,
		&i.SandboxID,
		&i.HostID,
		&i.CreatedAt,
		&i.StartedAt,
		&i.CompletedAt,
		&i.TemplateID,
		&i.TeamID,
		&i.SkipPrePost,
	)
	return i, err
}
|
||||
|
||||
const listTemplateBuilds = `-- name: ListTemplateBuilds :many
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds ORDER BY created_at DESC
`

// ListTemplateBuilds returns every template build, newest first.
// NOTE(review): there is no team filter or pagination here — presumably
// an admin/operator listing; confirm against callers.
func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) {
	rows, err := q.db.Query(ctx, listTemplateBuilds)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []TemplateBuild
	for rows.Next() {
		var i TemplateBuild
		// Scan order mirrors the SELECT column list exactly.
		if err := rows.Scan(
			&i.ID,
			&i.Name,
			&i.BaseTemplate,
			&i.Recipe,
			&i.Healthcheck,
			&i.Vcpus,
			&i.MemoryMb,
			&i.Status,
			&i.CurrentStep,
			&i.TotalSteps,
			&i.Logs,
			&i.Error,
			&i.SandboxID,
			&i.HostID,
			&i.CreatedAt,
			&i.StartedAt,
			&i.CompletedAt,
			&i.TemplateID,
			&i.TeamID,
			&i.SkipPrePost,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Surface iteration errors that rows.Next converted into loop exit.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
const updateBuildError = `-- name: UpdateBuildError :exec
UPDATE template_builds
SET error = $2, status = 'failed', completed_at = NOW()
WHERE id = $1
`

// UpdateBuildErrorParams identifies the build and carries the error text
// to record on it.
type UpdateBuildErrorParams struct {
	ID    pgtype.UUID `json:"id"`
	Error string      `json:"error"`
}

// UpdateBuildError records a failure: stores the error message, forces the
// status to 'failed', and stamps completed_at in a single update.
func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error {
	_, err := q.db.Exec(ctx, updateBuildError, arg.ID, arg.Error)
	return err
}
|
||||
|
||||
const updateBuildProgress = `-- name: UpdateBuildProgress :exec
UPDATE template_builds
SET current_step = $2, logs = $3
WHERE id = $1
`

// UpdateBuildProgressParams carries the build id, the step counter to set,
// and the full replacement value for the logs column (the query overwrites
// logs rather than appending).
type UpdateBuildProgressParams struct {
	ID          pgtype.UUID `json:"id"`
	CurrentStep int32       `json:"current_step"`
	Logs        []byte      `json:"logs"`
}

// UpdateBuildProgress updates a build's current step and log blob.
// It does not touch the build's status or timestamps.
func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error {
	_, err := q.db.Exec(ctx, updateBuildProgress, arg.ID, arg.CurrentStep, arg.Logs)
	return err
}
|
||||
|
||||
const updateBuildSandbox = `-- name: UpdateBuildSandbox :exec
UPDATE template_builds
SET sandbox_id = $2, host_id = $3
WHERE id = $1
`

// UpdateBuildSandboxParams links a build to the sandbox and host it is
// being executed on.
type UpdateBuildSandboxParams struct {
	ID        pgtype.UUID `json:"id"`
	SandboxID pgtype.UUID `json:"sandbox_id"`
	HostID    pgtype.UUID `json:"host_id"`
}

// UpdateBuildSandbox records which sandbox/host a build was scheduled onto.
// Status and timestamps are managed separately (see UpdateBuildStatus).
func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error {
	_, err := q.db.Exec(ctx, updateBuildSandbox, arg.ID, arg.SandboxID, arg.HostID)
	return err
}
|
||||
|
||||
const updateBuildStatus = `-- name: UpdateBuildStatus :one
UPDATE template_builds
SET status = $2,
    started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END,
    completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END
WHERE id = $1
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
`

// UpdateBuildStatusParams identifies the build and the status to transition
// it to.
type UpdateBuildStatusParams struct {
	ID     pgtype.UUID `json:"id"`
	Status string      `json:"status"`
}

// UpdateBuildStatus sets a build's status and maintains its lifecycle
// timestamps in the database: started_at is stamped on the first
// transition to 'running' (only when still NULL), and completed_at is
// stamped when entering a terminal state ('success', 'failed',
// 'cancelled'). Returns the updated row.
func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) {
	row := q.db.QueryRow(ctx, updateBuildStatus, arg.ID, arg.Status)
	var i TemplateBuild
	// Scan order mirrors the RETURNING column list exactly.
	err := row.Scan(
		&i.ID,
		&i.Name,
		&i.BaseTemplate,
		&i.Recipe,
		&i.Healthcheck,
		&i.Vcpus,
		&i.MemoryMb,
		&i.Status,
		&i.CurrentStep,
		&i.TotalSteps,
		&i.Logs,
		&i.Error,
		&i.SandboxID,
		&i.HostID,
		&i.CreatedAt,
		&i.StartedAt,
		&i.CompletedAt,
		&i.TemplateID,
		&i.TeamID,
		&i.SkipPrePost,
	)
	return i, err
}
|
||||
@ -12,11 +12,11 @@ import (
|
||||
)
|
||||
|
||||
const deleteTemplate = `-- name: DeleteTemplate :exec
|
||||
DELETE FROM templates WHERE name = $1
|
||||
DELETE FROM templates WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteTemplate(ctx context.Context, name string) error {
|
||||
_, err := q.db.Exec(ctx, deleteTemplate, name)
|
||||
func (q *Queries) DeleteTemplate(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteTemplate, id)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -25,8 +25,8 @@ DELETE FROM templates WHERE name = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type DeleteTemplateByTeamParams struct {
|
||||
Name string `json:"name"`
|
||||
TeamID string `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error {
|
||||
@ -34,12 +34,23 @@ func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateBy
|
||||
return err
|
||||
}
|
||||
|
||||
const getTemplate = `-- name: GetTemplate :one
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1
|
||||
const deleteTemplatesByTeam = `-- name: DeleteTemplatesByTeam :exec
|
||||
DELETE FROM templates WHERE team_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, getTemplate, name)
|
||||
// Bulk delete all templates owned by a team (for team soft-delete cleanup).
|
||||
func (q *Queries) DeleteTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteTemplatesByTeam, teamID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getPlatformTemplateByName = `-- name: GetPlatformTemplateByName :one
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1
|
||||
`
|
||||
|
||||
// Check if a global (platform) template exists with the given name.
|
||||
func (q *Queries) GetPlatformTemplateByName(ctx context.Context, name string) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, getPlatformTemplateByName, name)
|
||||
var i Template
|
||||
err := row.Scan(
|
||||
&i.Name,
|
||||
@ -49,19 +60,67 @@ func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTemplate = `-- name: GetTemplate :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE id = $1
`

// GetTemplate fetches a template by primary key id, regardless of which
// team owns it. If no row matches, Scan returns the driver's no-rows error.
func (q *Queries) GetTemplate(ctx context.Context, id pgtype.UUID) (Template, error) {
	row := q.db.QueryRow(ctx, getTemplate, id)
	var i Template
	// Scan order mirrors the SELECT column list (id is last).
	err := row.Scan(
		&i.Name,
		&i.Type,
		&i.Vcpus,
		&i.MemoryMb,
		&i.SizeBytes,
		&i.CreatedAt,
		&i.TeamID,
		&i.ID,
	)
	return i, err
}
|
||||
|
||||
const getTemplateByName = `-- name: GetTemplateByName :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 AND name = $2
`

// GetTemplateByNameParams scopes the name lookup to a single owning team.
type GetTemplateByNameParams struct {
	TeamID pgtype.UUID `json:"team_id"`
	Name   string      `json:"name"`
}

// Look up a template by team_id and name (exact team match, no global fallback).
// If no row matches, Scan returns the driver's no-rows error.
func (q *Queries) GetTemplateByName(ctx context.Context, arg GetTemplateByNameParams) (Template, error) {
	row := q.db.QueryRow(ctx, getTemplateByName, arg.TeamID, arg.Name)
	var i Template
	// Scan order mirrors the SELECT column list (id is last).
	err := row.Scan(
		&i.Name,
		&i.Type,
		&i.Vcpus,
		&i.MemoryMb,
		&i.SizeBytes,
		&i.CreatedAt,
		&i.TeamID,
		&i.ID,
	)
	return i, err
}
|
||||
|
||||
const getTemplateByTeam = `-- name: GetTemplateByTeam :one
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND team_id = $2
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000')
|
||||
`
|
||||
|
||||
type GetTemplateByTeamParams struct {
|
||||
Name string `json:"name"`
|
||||
TeamID string `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
// Platform templates (team_id = 00000000-...) are visible to all teams.
|
||||
func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID)
|
||||
var i Template
|
||||
@ -73,27 +132,30 @@ func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamPa
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertTemplate = `-- name: InsertTemplate :one
|
||||
INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id
|
||||
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id
|
||||
`
|
||||
|
||||
type InsertTemplateParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Vcpus pgtype.Int4 `json:"vcpus"`
|
||||
MemoryMb pgtype.Int4 `json:"memory_mb"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TeamID string `json:"team_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, insertTemplate,
|
||||
arg.ID,
|
||||
arg.Name,
|
||||
arg.Type,
|
||||
arg.Vcpus,
|
||||
@ -110,12 +172,13 @@ func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams)
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listTemplates = `-- name: ListTemplates :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates ORDER BY created_at DESC
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
|
||||
@ -135,6 +198,7 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -147,10 +211,11 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
|
||||
}
|
||||
|
||||
const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Template, error) {
|
||||
// Platform templates are visible to all teams.
|
||||
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -167,6 +232,7 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -179,14 +245,15 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem
|
||||
}
|
||||
|
||||
const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
type ListTemplatesByTeamAndTypeParams struct {
|
||||
TeamID string `json:"team_id"`
|
||||
Type string `json:"type"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// Platform templates are visible to all teams.
|
||||
func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type)
|
||||
if err != nil {
|
||||
@ -204,6 +271,41 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listTemplatesByTeamOnly = `-- name: ListTemplatesByTeamOnly :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
// List templates owned by a specific team (NOT including platform templates).
|
||||
func (q *Queries) ListTemplatesByTeamOnly(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByTeamOnly, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -216,7 +318,7 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla
|
||||
}
|
||||
|
||||
const listTemplatesByType = `-- name: ListTemplatesByType :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE type = $1 ORDER BY created_at DESC
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE type = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) {
|
||||
@ -236,6 +338,7 @@ func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Temp
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -16,8 +16,8 @@ DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2
|
||||
`
|
||||
|
||||
type DeleteAdminPermissionParams struct {
|
||||
UserID string `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error {
|
||||
@ -29,7 +29,7 @@ const getAdminPermissions = `-- name: GetAdminPermissions :many
|
||||
SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission
|
||||
`
|
||||
|
||||
func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]AdminPermission, error) {
|
||||
func (q *Queries) GetAdminPermissions(ctx context.Context, userID pgtype.UUID) ([]AdminPermission, error) {
|
||||
rows, err := q.db.Query(ctx, getAdminPermissions, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -55,7 +55,7 @@ func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]Adm
|
||||
}
|
||||
|
||||
const getAdminUsers = `-- name: GetAdminUsers :many
|
||||
SELECT id, email, password_hash, created_at, updated_at, is_admin FROM users WHERE is_admin = TRUE ORDER BY created_at
|
||||
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE is_admin = TRUE ORDER BY created_at
|
||||
`
|
||||
|
||||
func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
|
||||
@ -71,9 +71,10 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsAdmin,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -86,7 +87,7 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
|
||||
}
|
||||
|
||||
const getUserByEmail = `-- name: GetUserByEmail :one
|
||||
SELECT id, email, password_hash, created_at, updated_at, is_admin FROM users WHERE email = $1
|
||||
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE email = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
|
||||
@ -96,27 +97,29 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsAdmin,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByID = `-- name: GetUserByID :one
|
||||
SELECT id, email, password_hash, created_at, updated_at, is_admin FROM users WHERE id = $1
|
||||
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) {
|
||||
func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByID, id)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsAdmin,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -128,8 +131,8 @@ SELECT EXISTS(
|
||||
`
|
||||
|
||||
type HasAdminPermissionParams struct {
|
||||
UserID string `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) {
|
||||
@ -145,9 +148,9 @@ VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
type InsertAdminPermissionParams struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error {
|
||||
@ -156,66 +159,118 @@ func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPerm
|
||||
}
|
||||
|
||||
const insertUser = `-- name: InsertUser :one
|
||||
INSERT INTO users (id, email, password_hash)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, email, password_hash, created_at, updated_at, is_admin
|
||||
INSERT INTO users (id, email, password_hash, name)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
|
||||
`
|
||||
|
||||
type InsertUserParams struct {
|
||||
ID string `json:"id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash pgtype.Text `json:"password_hash"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
|
||||
row := q.db.QueryRow(ctx, insertUser, arg.ID, arg.Email, arg.PasswordHash)
|
||||
row := q.db.QueryRow(ctx, insertUser,
|
||||
arg.ID,
|
||||
arg.Email,
|
||||
arg.PasswordHash,
|
||||
arg.Name,
|
||||
)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsAdmin,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertUserOAuth = `-- name: InsertUserOAuth :one
|
||||
INSERT INTO users (id, email)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id, email, password_hash, created_at, updated_at, is_admin
|
||||
INSERT INTO users (id, email, name)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
|
||||
`
|
||||
|
||||
type InsertUserOAuthParams struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) {
|
||||
row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email)
|
||||
row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email, arg.Name)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsAdmin,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const searchUsersByEmailPrefix = `-- name: SearchUsersByEmailPrefix :many
|
||||
SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10
|
||||
`
|
||||
|
||||
type SearchUsersByEmailPrefixRow struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (q *Queries) SearchUsersByEmailPrefix(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUsersByEmailPrefixRow, error) {
|
||||
rows, err := q.db.Query(ctx, searchUsersByEmailPrefix, dollar_1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []SearchUsersByEmailPrefixRow
|
||||
for rows.Next() {
|
||||
var i SearchUsersByEmailPrefixRow
|
||||
if err := rows.Scan(&i.ID, &i.Email); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const setUserAdmin = `-- name: SetUserAdmin :exec
|
||||
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
type SetUserAdminParams struct {
|
||||
ID string `json:"id"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
}
|
||||
|
||||
func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error {
|
||||
_, err := q.db.Exec(ctx, setUserAdmin, arg.ID, arg.IsAdmin)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateUserName = `-- name: UpdateUserName :exec
|
||||
UPDATE users SET name = $2, updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateUserNameParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUserName(ctx context.Context, arg UpdateUserNameParams) error {
|
||||
_, err := q.db.Exec(ctx, updateUserName, arg.ID, arg.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -116,9 +116,10 @@ type SnapshotDevice struct {
|
||||
// writable CoW layer.
|
||||
//
|
||||
// The origin loop device must already exist (from LoopRegistry.Acquire).
|
||||
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) {
|
||||
// Create sparse CoW file sized to match the origin.
|
||||
if err := createSparseFile(cowPath, originSizeBytes); err != nil {
|
||||
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) {
|
||||
// Create sparse CoW file. The logical size limits how many blocks can be
|
||||
// modified; because the file is sparse, only written blocks use real disk.
|
||||
if err := createSparseFile(cowPath, cowSizeBytes); err != nil {
|
||||
return nil, fmt.Errorf("create cow file: %w", err)
|
||||
}
|
||||
|
||||
@ -128,6 +129,9 @@ func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64)
|
||||
return nil, fmt.Errorf("losetup cow: %w", err)
|
||||
}
|
||||
|
||||
// The dm-snapshot virtual device size must match the origin — the snapshot
|
||||
// target maps 1:1 onto origin sectors. The CoW file just needs enough
|
||||
// space to store all modified blocks (it's sparse, so 20GB costs nothing).
|
||||
sectors := originSizeBytes / 512
|
||||
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
|
||||
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
|
||||
@ -220,6 +224,7 @@ func FlattenSnapshot(dmDevPath, outputPath string) error {
|
||||
"if="+dmDevPath,
|
||||
"of="+outputPath,
|
||||
"bs=4M",
|
||||
"conv=sparse",
|
||||
"status=none",
|
||||
)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
|
||||
@ -3,14 +3,12 @@ package envdclient
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
@ -49,35 +47,6 @@ func (c *Client) BaseURL() string {
|
||||
return c.base
|
||||
}
|
||||
|
||||
// Init calls POST /init on envd to sync the guest clock with the host.
|
||||
// This is important after snapshot resume where the guest clock is frozen.
|
||||
func (c *Client) Init(ctx context.Context) error {
|
||||
now := time.Now().UTC()
|
||||
body, err := json.Marshal(map[string]any{"timestamp": now})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal init body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create init request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("init: status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecResult holds the output of a command execution.
|
||||
type ExecResult struct {
|
||||
Stdout []byte
|
||||
|
||||
73
internal/events/event.go
Normal file
73
internal/events/event.go
Normal file
@ -0,0 +1,73 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EventPublisher pushes events onto the notification stream.
|
||||
// Satisfied by *channels.Publisher.
|
||||
type EventPublisher interface {
|
||||
Publish(ctx context.Context, e Event)
|
||||
}
|
||||
|
||||
// ActorKind identifies what initiated an event.
|
||||
type ActorKind string
|
||||
|
||||
const (
|
||||
ActorUser ActorKind = "user"
|
||||
ActorAPIKey ActorKind = "api_key"
|
||||
ActorSystem ActorKind = "system"
|
||||
)
|
||||
|
||||
// Actor describes who triggered an event.
|
||||
type Actor struct {
|
||||
Type ActorKind `json:"type"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// Resource identifies the object the event relates to.
|
||||
type Resource struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// Event is the canonical notification payload published to the Redis stream
|
||||
// and delivered to channel subscribers.
|
||||
type Event struct {
|
||||
Event string `json:"event"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
TeamID string `json:"team_id"`
|
||||
Actor Actor `json:"actor"`
|
||||
Resource Resource `json:"resource"`
|
||||
}
|
||||
|
||||
// Event type constants.
|
||||
const (
|
||||
CapsuleCreated = "capsule.created"
|
||||
CapsuleRunning = "capsule.running"
|
||||
CapsulePaused = "capsule.paused"
|
||||
CapsuleDestroyed = "capsule.destroyed"
|
||||
SnapshotCreated = "template.snapshot.created"
|
||||
SnapshotDeleted = "template.snapshot.deleted"
|
||||
HostUp = "host.up"
|
||||
HostDown = "host.down"
|
||||
)
|
||||
|
||||
// AllEventTypes is the complete set of valid event type strings.
|
||||
var AllEventTypes = []string{
|
||||
CapsuleCreated,
|
||||
CapsuleRunning,
|
||||
CapsulePaused,
|
||||
CapsuleDestroyed,
|
||||
SnapshotCreated,
|
||||
SnapshotDeleted,
|
||||
HostUp,
|
||||
HostDown,
|
||||
}
|
||||
|
||||
// Now returns the current time formatted for event timestamps.
|
||||
func Now() string {
|
||||
return time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
42
internal/hostagent/certstore.go
Normal file
42
internal/hostagent/certstore.go
Normal file
@ -0,0 +1,42 @@
|
||||
package hostagent
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// CertStore provides lock-free read/write access to the agent's current TLS
|
||||
// certificate. It is used with tls.Config.GetCertificate to enable hot-swap
|
||||
// of the agent's cert on JWT refresh without restarting the server.
|
||||
//
|
||||
// The zero value is usable; GetCert returns an error until a cert is stored.
|
||||
type CertStore struct {
|
||||
ptr atomic.Pointer[tls.Certificate]
|
||||
}
|
||||
|
||||
// Store atomically replaces the current certificate.
|
||||
func (s *CertStore) Store(cert *tls.Certificate) {
|
||||
s.ptr.Store(cert)
|
||||
}
|
||||
|
||||
// ParseAndStore parses certPEM+keyPEM and atomically replaces the stored cert.
|
||||
// If parsing fails the existing cert is unchanged.
|
||||
func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error {
|
||||
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse TLS key pair: %w", err)
|
||||
}
|
||||
s.ptr.Store(&cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCert satisfies tls.Config.GetCertificate. Returns an error if no cert has
|
||||
// been stored yet.
|
||||
func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert := s.ptr.Load()
|
||||
if cert == nil {
|
||||
return nil, fmt.Errorf("no TLS certificate available")
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
89
internal/hostagent/proxy.go
Normal file
89
internal/hostagent/proxy.go
Normal file
@ -0,0 +1,89 @@
|
||||
package hostagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/sandbox"
|
||||
)
|
||||
|
||||
// ProxyHandler reverse-proxies HTTP requests to services running inside
|
||||
// sandboxes. It handles requests of the form:
|
||||
//
|
||||
// /proxy/{sandbox_id}/{port}/{path...}
|
||||
//
|
||||
// The sandbox's HostIP (routable on this machine) is used as the upstream.
|
||||
// This supports any protocol that rides on HTTP, including WebSocket upgrades.
|
||||
type ProxyHandler struct {
|
||||
mgr *sandbox.Manager
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
// NewProxyHandler creates a new sandbox proxy handler.
|
||||
func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
mgr: mgr,
|
||||
transport: http.DefaultTransport,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler.
|
||||
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Expected path: /proxy/{sandbox_id}/{port}/...
|
||||
// After trimming "/proxy/", we get "{sandbox_id}/{port}/..."
|
||||
trimmed := strings.TrimPrefix(r.URL.Path, "/proxy/")
|
||||
if trimmed == r.URL.Path {
|
||||
http.Error(w, "invalid proxy path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(trimmed, "/", 3)
|
||||
if len(parts) < 2 {
|
||||
http.Error(w, "expected /proxy/{sandbox_id}/{port}/...", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID := parts[0]
|
||||
port := parts[1]
|
||||
remainder := ""
|
||||
if len(parts) == 3 {
|
||||
remainder = parts[2]
|
||||
}
|
||||
|
||||
// Validate port is a number in the valid range.
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil || portNum < 1 || portNum > 65535 {
|
||||
http.Error(w, "invalid port", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hostIP, tracker, ok := h.mgr.AcquireProxyConn(sandboxID)
|
||||
if !ok {
|
||||
http.Error(w, "sandbox is not available", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer tracker.Release()
|
||||
|
||||
targetHost := fmt.Sprintf("%s:%d", hostIP, portNum)
|
||||
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Transport: h.transport,
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = targetHost
|
||||
req.URL.Path = "/" + remainder
|
||||
req.URL.RawQuery = r.URL.RawQuery
|
||||
req.Host = targetHost
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", err)
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
@ -17,11 +17,24 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json.
|
||||
// It holds all credentials the agent needs: the host JWT, refresh token, and
|
||||
// (when mTLS is enabled) the TLS certificate material for the agent's server.
|
||||
type TokenFile struct {
|
||||
HostID string `json:"host_id"`
|
||||
JWT string `json:"jwt"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// mTLS fields — empty when the CP has no CA configured.
|
||||
CertPEM string `json:"cert_pem,omitempty"`
|
||||
KeyPEM string `json:"key_pem,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
}
|
||||
|
||||
// RegistrationConfig holds the configuration for host registration.
|
||||
type RegistrationConfig struct {
|
||||
CPURL string // Control plane base URL (e.g., http://localhost:8000)
|
||||
RegistrationToken string // One-time registration token from the control plane
|
||||
TokenFile string // Path to persist the host JWT after registration
|
||||
TokenFile string // Path to persist the credentials after registration
|
||||
Address string // Externally-reachable address (ip:port) for this host
|
||||
}
|
||||
|
||||
@ -34,9 +47,18 @@ type registerRequest struct {
|
||||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
type registerResponse struct {
|
||||
Host json.RawMessage `json:"host"`
|
||||
Token string `json:"token"`
|
||||
// authResponse is the shared JSON shape for both register and refresh responses.
|
||||
type authResponse struct {
|
||||
Host json.RawMessage `json:"host"`
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
CertPEM string `json:"cert_pem,omitempty"`
|
||||
KeyPEM string `json:"key_pem,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
}
|
||||
|
||||
type refreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
@ -46,20 +68,47 @@ type errorResponse struct {
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// Register calls the control plane to register this host agent and persists
|
||||
// the returned JWT to disk. Returns the host JWT token string.
|
||||
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
||||
// Check if we already have a saved token.
|
||||
if data, err := os.ReadFile(cfg.TokenFile); err == nil {
|
||||
token := strings.TrimSpace(string(data))
|
||||
if token != "" {
|
||||
slog.Info("loaded existing host token", "file", cfg.TokenFile)
|
||||
return token, nil
|
||||
}
|
||||
// LoadTokenFile reads and parses the persisted credentials file.
|
||||
func LoadTokenFile(path string) (*TokenFile, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Support legacy format (raw JWT string) for backwards compatibility.
|
||||
trimmed := strings.TrimSpace(string(data))
|
||||
if !strings.HasPrefix(trimmed, "{") {
|
||||
// Old format: just the JWT, no refresh token.
|
||||
hostID, _ := hostIDFromJWT(trimmed)
|
||||
return &TokenFile{HostID: hostID, JWT: trimmed}, nil
|
||||
}
|
||||
var tf TokenFile
|
||||
if err := json.Unmarshal(data, &tf); err != nil {
|
||||
return nil, fmt.Errorf("parse credentials file: %w", err)
|
||||
}
|
||||
return &tf, nil
|
||||
}
|
||||
|
||||
// saveTokenFile writes the credentials file as JSON with 0600 permissions.
|
||||
func saveTokenFile(path string, tf TokenFile) error {
|
||||
data, err := json.MarshalIndent(tf, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal credentials file: %w", err)
|
||||
}
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
// Register calls the control plane to register this host agent and persists
|
||||
// the returned credentials to disk. Returns the full TokenFile on success.
|
||||
func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) {
|
||||
// If no explicit registration token was given, reuse the saved credentials.
|
||||
// A --register flag always overrides the local file so operators can
|
||||
// force re-registration without manually deleting the credentials file.
|
||||
if cfg.RegistrationToken == "" {
|
||||
return "", fmt.Errorf("no saved host token and no registration token provided")
|
||||
if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
|
||||
slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID)
|
||||
return tf, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)")
|
||||
}
|
||||
|
||||
arch := runtime.GOARCH
|
||||
@ -78,85 +127,238 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal registration request: %w", err)
|
||||
return nil, fmt.Errorf("marshal registration request: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create registration request: %w", err)
|
||||
return nil, fmt.Errorf("create registration request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("registration request failed: %w", err)
|
||||
return nil, fmt.Errorf("registration request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read registration response: %w", err)
|
||||
return nil, fmt.Errorf("read registration response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
var errResp errorResponse
|
||||
if err := json.Unmarshal(respBody, &errResp); err == nil {
|
||||
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||
}
|
||||
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var regResp registerResponse
|
||||
var regResp authResponse
|
||||
if err := json.Unmarshal(respBody, ®Resp); err != nil {
|
||||
return "", fmt.Errorf("parse registration response: %w", err)
|
||||
return nil, fmt.Errorf("parse registration response: %w", err)
|
||||
}
|
||||
|
||||
if regResp.Token == "" {
|
||||
return "", fmt.Errorf("registration response missing token")
|
||||
return nil, fmt.Errorf("registration response missing token")
|
||||
}
|
||||
|
||||
// Persist the token to disk for subsequent startups.
|
||||
if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil {
|
||||
return "", fmt.Errorf("save host token: %w", err)
|
||||
hostID, err := hostIDFromJWT(regResp.Token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("extract host ID from JWT: %w", err)
|
||||
}
|
||||
slog.Info("host registered and token saved", "file", cfg.TokenFile)
|
||||
|
||||
return regResp.Token, nil
|
||||
tf := TokenFile{
|
||||
HostID: hostID,
|
||||
JWT: regResp.Token,
|
||||
RefreshToken: regResp.RefreshToken,
|
||||
CertPEM: regResp.CertPEM,
|
||||
KeyPEM: regResp.KeyPEM,
|
||||
CACertPEM: regResp.CACertPEM,
|
||||
}
|
||||
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
|
||||
return nil, fmt.Errorf("save host credentials: %w", err)
|
||||
}
|
||||
slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID)
|
||||
|
||||
return &tf, nil
|
||||
}
|
||||
|
||||
// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh
|
||||
// token, and (when mTLS is enabled) a new TLS certificate. The credentials file
|
||||
// is updated in place. Returns the updated TokenFile.
|
||||
func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) {
|
||||
tf, err := LoadTokenFile(credentialsFilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load credentials file: %w", err)
|
||||
}
|
||||
if tf.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available; host must re-register")
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
|
||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read refresh response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp errorResponse
|
||||
if json.Unmarshal(respBody, &errResp) == nil {
|
||||
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var refResp authResponse
|
||||
if err := json.Unmarshal(respBody, &refResp); err != nil {
|
||||
return nil, fmt.Errorf("parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
tf.JWT = refResp.Token
|
||||
tf.RefreshToken = refResp.RefreshToken
|
||||
if refResp.CertPEM != "" {
|
||||
tf.CertPEM = refResp.CertPEM
|
||||
tf.KeyPEM = refResp.KeyPEM
|
||||
tf.CACertPEM = refResp.CACertPEM
|
||||
}
|
||||
if err := saveTokenFile(credentialsFilePath, *tf); err != nil {
|
||||
return nil, fmt.Errorf("save refreshed credentials: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("host credentials refreshed", "host_id", tf.HostID)
|
||||
return tf, nil
|
||||
}
|
||||
|
||||
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
|
||||
// to the control plane. It runs until the context is cancelled.
|
||||
func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) {
|
||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
|
||||
//
|
||||
// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh
|
||||
// also fails (expired refresh token), it calls pauseAll and stops.
|
||||
//
|
||||
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
|
||||
// retrying — the connection may recover and the host should resume heartbeating.
|
||||
//
|
||||
// onDeleted is called when CP returns 404, meaning this host record was deleted.
|
||||
// The credentials file is removed before calling onDeleted so subsequent starts
|
||||
// prompt for a new registration token.
|
||||
//
|
||||
// onCredsRefreshed is called after a successful credential refresh (JWT + cert).
|
||||
// It may be nil. The caller uses it to hot-swap the agent's TLS certificate.
|
||||
func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
consecutiveFailures := 0
|
||||
pausedDueToFailure := false
|
||||
currentJWT := ""
|
||||
|
||||
// Load the current JWT from the credentials file.
|
||||
if tf, err := LoadTokenFile(credentialsFilePath); err == nil {
|
||||
currentJWT = tf.JWT
|
||||
}
|
||||
|
||||
// beat sends one heartbeat. Returns true if the loop should stop.
|
||||
beat := func() (stop bool) {
|
||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: failed to create request", "error", err)
|
||||
return false
|
||||
}
|
||||
req.Header.Set("X-Host-Token", currentJWT)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
consecutiveFailures++
|
||||
slog.Warn("heartbeat: request failed", "error", err, "consecutive_failures", consecutiveFailures)
|
||||
if consecutiveFailures >= 3 && !pausedDueToFailure {
|
||||
slog.Error("heartbeat: CP unreachable after 3 failures — pausing all sandboxes")
|
||||
if pauseAll != nil {
|
||||
pauseAll()
|
||||
}
|
||||
pausedDueToFailure = true
|
||||
}
|
||||
return false
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusNoContent:
|
||||
if consecutiveFailures > 0 || pausedDueToFailure {
|
||||
slog.Info("heartbeat: CP connection restored")
|
||||
}
|
||||
consecutiveFailures = 0
|
||||
pausedDueToFailure = false
|
||||
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
slog.Warn("heartbeat: JWT rejected — attempting credentials refresh")
|
||||
newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath)
|
||||
if refreshErr != nil {
|
||||
slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required",
|
||||
"error", refreshErr)
|
||||
if pauseAll != nil && !pausedDueToFailure {
|
||||
pauseAll()
|
||||
pausedDueToFailure = true
|
||||
}
|
||||
// Stop the heartbeat loop — operator must re-register.
|
||||
return true
|
||||
}
|
||||
currentJWT = newCreds.JWT
|
||||
slog.Info("heartbeat: credentials refreshed successfully")
|
||||
if onCredsRefreshed != nil {
|
||||
onCredsRefreshed(newCreds)
|
||||
}
|
||||
|
||||
case http.StatusNotFound:
|
||||
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting")
|
||||
if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("heartbeat: failed to remove credentials file", "error", err)
|
||||
}
|
||||
if onDeleted != nil {
|
||||
onDeleted()
|
||||
}
|
||||
return true
|
||||
|
||||
default:
|
||||
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Send an immediate heartbeat on startup so the CP sees the host as
|
||||
// online without waiting for the first ticker tick.
|
||||
if beat() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: failed to create request", "error", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("X-Host-Token", hostToken)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: request failed", "error", err)
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
|
||||
if beat() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -166,6 +368,12 @@ func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interv
|
||||
// HostIDFromToken extracts the host_id claim from a host JWT without
|
||||
// verifying the signature (the agent doesn't have the signing secret).
|
||||
func HostIDFromToken(token string) (string, error) {
|
||||
return hostIDFromJWT(token)
|
||||
}
|
||||
|
||||
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
|
||||
// the credentials file loader.
|
||||
func hostIDFromJWT(token string) (string, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", fmt.Errorf("invalid JWT format")
|
||||
|
||||
@ -12,6 +12,8 @@ import (
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
@ -22,12 +24,28 @@ import (
|
||||
// Server implements the HostAgentService Connect RPC handler.
|
||||
type Server struct {
|
||||
hostagentv1connect.UnimplementedHostAgentServiceHandler
|
||||
mgr *sandbox.Manager
|
||||
mgr *sandbox.Manager
|
||||
terminate func() // called when the CP requests agent termination
|
||||
}
|
||||
|
||||
// NewServer creates a new host agent RPC server.
|
||||
func NewServer(mgr *sandbox.Manager) *Server {
|
||||
return &Server{mgr: mgr}
|
||||
// terminate is invoked (in a goroutine) when the CP calls the Terminate RPC,
|
||||
// allowing main to perform a clean shutdown.
|
||||
func NewServer(mgr *sandbox.Manager, terminate func()) *Server {
|
||||
return &Server{mgr: mgr, terminate: terminate}
|
||||
}
|
||||
|
||||
// parseUUIDString parses a UUID hex string into a pgtype.UUID.
|
||||
// An empty string yields an all-zeros UUID (valid).
|
||||
func parseUUIDString(s string) (pgtype.UUID, error) {
|
||||
if s == "" {
|
||||
return pgtype.UUID{Bytes: [16]byte{}, Valid: true}, nil
|
||||
}
|
||||
parsed, err := uuid.Parse(s)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid UUID %q: %w", s, err)
|
||||
}
|
||||
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
|
||||
}
|
||||
|
||||
func (s *Server) CreateSandbox(
|
||||
@ -36,7 +54,16 @@ func (s *Server) CreateSandbox(
|
||||
) (*connect.Response[pb.CreateSandboxResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec))
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
sb, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb))
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
|
||||
}
|
||||
@ -87,12 +114,21 @@ func (s *Server) CreateSnapshot(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.CreateSnapshotRequest],
|
||||
) (*connect.Response[pb.CreateSnapshotResponse], error) {
|
||||
sizeBytes, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, req.Msg.Name)
|
||||
msg := req.Msg
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
sizeBytes, err := s.mgr.CreateSnapshot(ctx, msg.SandboxId, teamID, templateID)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err))
|
||||
}
|
||||
return connect.NewResponse(&pb.CreateSnapshotResponse{
|
||||
Name: req.Msg.Name,
|
||||
SizeBytes: sizeBytes,
|
||||
}), nil
|
||||
}
|
||||
@ -101,12 +137,45 @@ func (s *Server) DeleteSnapshot(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.DeleteSnapshotRequest],
|
||||
) (*connect.Response[pb.DeleteSnapshotResponse], error) {
|
||||
if err := s.mgr.DeleteSnapshot(req.Msg.Name); err != nil {
|
||||
msg := req.Msg
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
if err := s.mgr.DeleteSnapshot(teamID, templateID); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err))
|
||||
}
|
||||
return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) FlattenRootfs(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.FlattenRootfsRequest],
|
||||
) (*connect.Response[pb.FlattenRootfsResponse], error) {
|
||||
msg := req.Msg
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
sizeBytes, err := s.mgr.FlattenRootfs(ctx, msg.SandboxId, teamID, templateID)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", err))
|
||||
}
|
||||
return connect.NewResponse(&pb.FlattenRootfsResponse{
|
||||
SizeBytes: sizeBytes,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) PingSandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PingSandboxRequest],
|
||||
@ -397,7 +466,8 @@ func (s *Server) ListSandboxes(
|
||||
infos[i] = &pb.SandboxInfo{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
Template: sb.Template,
|
||||
TeamId: uuid.UUID(sb.TemplateTeamID).String(),
|
||||
TemplateId: uuid.UUID(sb.TemplateID).String(),
|
||||
Vcpus: int32(sb.VCPUs),
|
||||
MemoryMb: int32(sb.MemoryMB),
|
||||
HostIp: sb.HostIP.String(),
|
||||
@ -412,3 +482,66 @@ func (s *Server) ListSandboxes(
|
||||
AutoPausedSandboxIds: s.mgr.DrainAutoPausedIDs(),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) Terminate(
|
||||
_ context.Context,
|
||||
_ *connect.Request[pb.TerminateRequest],
|
||||
) (*connect.Response[pb.TerminateResponse], error) {
|
||||
slog.Info("terminate RPC received — scheduling shutdown")
|
||||
if s.terminate != nil {
|
||||
go s.terminate()
|
||||
}
|
||||
return connect.NewResponse(&pb.TerminateResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) GetSandboxMetrics(
|
||||
_ context.Context,
|
||||
req *connect.Request[pb.GetSandboxMetricsRequest],
|
||||
) (*connect.Response[pb.GetSandboxMetricsResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
points, err := s.mgr.GetMetrics(msg.SandboxId, msg.Range)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "invalid range") {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.GetSandboxMetricsResponse{Points: metricPointsToPB(points)}), nil
|
||||
}
|
||||
|
||||
func (s *Server) FlushSandboxMetrics(
|
||||
_ context.Context,
|
||||
req *connect.Request[pb.FlushSandboxMetricsRequest],
|
||||
) (*connect.Response[pb.FlushSandboxMetricsResponse], error) {
|
||||
pts10m, pts2h, pts24h, err := s.mgr.FlushMetrics(req.Msg.SandboxId)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.FlushSandboxMetricsResponse{
|
||||
Points_10M: metricPointsToPB(pts10m),
|
||||
Points_2H: metricPointsToPB(pts2h),
|
||||
Points_24H: metricPointsToPB(pts24h),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func metricPointsToPB(pts []sandbox.MetricPoint) []*pb.MetricPoint {
|
||||
out := make([]*pb.MetricPoint, len(pts))
|
||||
for i, p := range pts {
|
||||
out[i] = &pb.MetricPoint{
|
||||
TimestampUnix: p.Timestamp.Unix(),
|
||||
CpuPct: p.CPUPct,
|
||||
MemBytes: p.MemBytes,
|
||||
DiskBytes: p.DiskBytes,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@ -4,8 +4,171 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const (
|
||||
base36Alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
base36IDLen = 25 // ceil(128 * log2 / log36) = 25 chars for a full UUID
|
||||
)
|
||||
|
||||
var base36Base = big.NewInt(36)
|
||||
|
||||
// --- Generation ---
|
||||
|
||||
// newUUID returns a new random (v4) UUID wrapped in pgtype.UUID for direct DB use.
|
||||
func newUUID() pgtype.UUID {
|
||||
return pgtype.UUID{Bytes: uuid.New(), Valid: true}
|
||||
}
|
||||
|
||||
func NewSandboxID() pgtype.UUID { return newUUID() }
|
||||
func NewUserID() pgtype.UUID { return newUUID() }
|
||||
func NewTeamID() pgtype.UUID { return newUUID() }
|
||||
func NewAPIKeyID() pgtype.UUID { return newUUID() }
|
||||
func NewHostID() pgtype.UUID { return newUUID() }
|
||||
func NewHostTokenID() pgtype.UUID { return newUUID() }
|
||||
func NewRefreshTokenID() pgtype.UUID { return newUUID() }
|
||||
func NewAuditLogID() pgtype.UUID { return newUUID() }
|
||||
func NewBuildID() pgtype.UUID { return newUUID() }
|
||||
func NewAdminPermissionID() pgtype.UUID { return newUUID() }
|
||||
func NewChannelID() pgtype.UUID { return newUUID() }
|
||||
|
||||
func NewTemplateID() pgtype.UUID { return newUUID() }
|
||||
|
||||
// NewSnapshotName generates a snapshot name: "template-" + 8 hex chars.
|
||||
func NewSnapshotName() string {
|
||||
return "template-" + hex8()
|
||||
}
|
||||
|
||||
// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy".
|
||||
func NewTeamSlug() string {
|
||||
b := make([]byte, 6)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:])
|
||||
}
|
||||
|
||||
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
|
||||
func NewRegistrationToken() string {
|
||||
return hexToken(32)
|
||||
}
|
||||
|
||||
// NewRefreshToken generates a 64-char hex token (32 bytes of entropy).
|
||||
func NewRefreshToken() string {
|
||||
return hexToken(32)
|
||||
}
|
||||
|
||||
// --- Formatting (pgtype.UUID → prefixed string for API/RPC output) ---
|
||||
|
||||
const (
|
||||
PrefixSandbox = "cl-"
|
||||
PrefixUser = "usr-"
|
||||
PrefixTeam = "team-"
|
||||
PrefixAPIKey = "key-"
|
||||
PrefixHost = "host-"
|
||||
PrefixHostToken = "htok-"
|
||||
PrefixRefreshToken = "hrt-"
|
||||
PrefixAuditLog = "log-"
|
||||
PrefixBuild = "bld-"
|
||||
PrefixAdminPermission = "perm-"
|
||||
PrefixChannel = "ch-"
|
||||
)
|
||||
|
||||
// UUIDToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z).
|
||||
func UUIDToBase36(b [16]byte) string {
|
||||
n := new(big.Int).SetBytes(b[:])
|
||||
buf := make([]byte, base36IDLen)
|
||||
mod := new(big.Int)
|
||||
for i := base36IDLen - 1; i >= 0; i-- {
|
||||
n.DivMod(n, base36Base, mod)
|
||||
buf[i] = base36Alphabet[mod.Int64()]
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// base36ToUUID decodes a 25-char base36 string back to 16 UUID bytes.
|
||||
func base36ToUUID(s string) ([16]byte, error) {
|
||||
if len(s) != base36IDLen {
|
||||
return [16]byte{}, fmt.Errorf("expected %d-char base36 ID, got %d", base36IDLen, len(s))
|
||||
}
|
||||
n := new(big.Int)
|
||||
for _, c := range s {
|
||||
idx := strings.IndexRune(base36Alphabet, c)
|
||||
if idx < 0 {
|
||||
return [16]byte{}, fmt.Errorf("invalid base36 character: %c", c)
|
||||
}
|
||||
n.Mul(n, base36Base)
|
||||
n.Add(n, big.NewInt(int64(idx)))
|
||||
}
|
||||
b := n.Bytes()
|
||||
var out [16]byte
|
||||
// big.Int.Bytes() strips leading zeros; right-align into 16-byte array.
|
||||
copy(out[16-len(b):], b)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func formatUUID(prefix string, id pgtype.UUID) string {
|
||||
return prefix + UUIDToBase36(id.Bytes)
|
||||
}
|
||||
|
||||
func FormatSandboxID(id pgtype.UUID) string { return formatUUID(PrefixSandbox, id) }
|
||||
func FormatUserID(id pgtype.UUID) string { return formatUUID(PrefixUser, id) }
|
||||
func FormatTeamID(id pgtype.UUID) string { return formatUUID(PrefixTeam, id) }
|
||||
func FormatAPIKeyID(id pgtype.UUID) string { return formatUUID(PrefixAPIKey, id) }
|
||||
func FormatHostID(id pgtype.UUID) string { return formatUUID(PrefixHost, id) }
|
||||
func FormatHostTokenID(id pgtype.UUID) string { return formatUUID(PrefixHostToken, id) }
|
||||
func FormatRefreshTokenID(id pgtype.UUID) string { return formatUUID(PrefixRefreshToken, id) }
|
||||
func FormatAuditLogID(id pgtype.UUID) string { return formatUUID(PrefixAuditLog, id) }
|
||||
func FormatBuildID(id pgtype.UUID) string { return formatUUID(PrefixBuild, id) }
|
||||
func FormatChannelID(id pgtype.UUID) string { return formatUUID(PrefixChannel, id) }
|
||||
|
||||
// --- Parsing (prefixed string from API/RPC input → pgtype.UUID) ---
|
||||
|
||||
func parseUUID(prefix, s string) (pgtype.UUID, error) {
|
||||
if !strings.HasPrefix(s, prefix) {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid ID: expected %q prefix, got %q", prefix, s)
|
||||
}
|
||||
b, err := base36ToUUID(strings.TrimPrefix(s, prefix))
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid ID %q: %w", s, err)
|
||||
}
|
||||
return pgtype.UUID{Bytes: b, Valid: true}, nil
|
||||
}
|
||||
|
||||
func ParseSandboxID(s string) (pgtype.UUID, error) { return parseUUID(PrefixSandbox, s) }
|
||||
func ParseUserID(s string) (pgtype.UUID, error) { return parseUUID(PrefixUser, s) }
|
||||
func ParseTeamID(s string) (pgtype.UUID, error) { return parseUUID(PrefixTeam, s) }
|
||||
func ParseAPIKeyID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAPIKey, s) }
|
||||
func ParseHostID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHost, s) }
|
||||
func ParseHostTokenID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHostToken, s) }
|
||||
func ParseAuditLogID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAuditLog, s) }
|
||||
func ParseBuildID(s string) (pgtype.UUID, error) { return parseUUID(PrefixBuild, s) }
|
||||
func ParseChannelID(s string) (pgtype.UUID, error) { return parseUUID(PrefixChannel, s) }
|
||||
|
||||
// --- Well-known IDs ---
|
||||
|
||||
// PlatformTeamID is the all-zeros UUID reserved for platform-owned resources
|
||||
// (e.g. base templates, shared infrastructure).
|
||||
var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
|
||||
|
||||
// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal"
|
||||
// template. When both team_id and template_id are zero, the host agent uses
|
||||
// the minimal rootfs at WRENN_DIR/images/minimal/.
|
||||
var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
|
||||
|
||||
// UUIDString converts a pgtype.UUID to a standard hyphenated UUID string
|
||||
// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format.
|
||||
func UUIDString(id pgtype.UUID) string {
|
||||
return uuid.UUID(id.Bytes).String()
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func hex8() string {
|
||||
b := make([]byte, 4)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
@ -14,44 +177,8 @@ func hex8() string {
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// NewSandboxID generates a new sandbox ID in the format "sb-" + 8 hex chars.
|
||||
func NewSandboxID() string {
|
||||
return "sb-" + hex8()
|
||||
}
|
||||
|
||||
// NewSnapshotName generates a snapshot name in the format "template-" + 8 hex chars.
|
||||
func NewSnapshotName() string {
|
||||
return "template-" + hex8()
|
||||
}
|
||||
|
||||
// NewUserID generates a new user ID in the format "usr-" + 8 hex chars.
|
||||
func NewUserID() string {
|
||||
return "usr-" + hex8()
|
||||
}
|
||||
|
||||
// NewTeamID generates a new team ID in the format "team-" + 8 hex chars.
|
||||
func NewTeamID() string {
|
||||
return "team-" + hex8()
|
||||
}
|
||||
|
||||
// NewAPIKeyID generates a new API key ID in the format "key-" + 8 hex chars.
|
||||
func NewAPIKeyID() string {
|
||||
return "key-" + hex8()
|
||||
}
|
||||
|
||||
// NewHostID generates a new host ID in the format "host-" + 8 hex chars.
|
||||
func NewHostID() string {
|
||||
return "host-" + hex8()
|
||||
}
|
||||
|
||||
// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars.
|
||||
func NewHostTokenID() string {
|
||||
return "htok-" + hex8()
|
||||
}
|
||||
|
||||
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
|
||||
func NewRegistrationToken() string {
|
||||
b := make([]byte, 32)
|
||||
func hexToken(nBytes int) string {
|
||||
b := make([]byte, nBytes)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
|
||||
118
internal/id/id_test.go
Normal file
118
internal/id/id_test.go
Normal file
@ -0,0 +1,118 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
func TestBase36RoundTrip(t *testing.T) {
|
||||
for i := 0; i < 1000; i++ {
|
||||
orig := uuid.New()
|
||||
encoded := UUIDToBase36(orig)
|
||||
|
||||
if len(encoded) != base36IDLen {
|
||||
t.Fatalf("expected %d chars, got %d: %s", base36IDLen, len(encoded), encoded)
|
||||
}
|
||||
|
||||
decoded, err := base36ToUUID(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
if decoded != orig {
|
||||
t.Fatalf("round-trip failed: %v → %s → %v", orig, encoded, decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase36ZeroUUID(t *testing.T) {
|
||||
var zero [16]byte
|
||||
encoded := UUIDToBase36(zero)
|
||||
if encoded != "0000000000000000000000000" {
|
||||
t.Fatalf("zero UUID should encode to all zeros, got %s", encoded)
|
||||
}
|
||||
decoded, err := base36ToUUID(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
if decoded != zero {
|
||||
t.Fatalf("round-trip failed for zero UUID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatParseRoundTrip(t *testing.T) {
|
||||
id := NewSandboxID()
|
||||
formatted := FormatSandboxID(id)
|
||||
|
||||
if formatted[:3] != "cl-" {
|
||||
t.Fatalf("expected cl- prefix, got %s", formatted)
|
||||
}
|
||||
if len(formatted) != 3+base36IDLen {
|
||||
t.Fatalf("expected %d chars total, got %d: %s", 3+base36IDLen, len(formatted), formatted)
|
||||
}
|
||||
|
||||
parsed, err := ParseSandboxID(formatted)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
if parsed != id {
|
||||
t.Fatalf("round-trip failed: %v → %s → %v", id, formatted, parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase36InvalidInput(t *testing.T) {
|
||||
// Wrong length.
|
||||
if _, err := base36ToUUID("abc"); err == nil {
|
||||
t.Fatal("expected error for short input")
|
||||
}
|
||||
// Invalid character.
|
||||
if _, err := base36ToUUID("000000000000000000000000!"); err == nil {
|
||||
t.Fatal("expected error for invalid character")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlatformTeamIDFormats(t *testing.T) {
|
||||
formatted := FormatTeamID(PlatformTeamID)
|
||||
parsed, err := ParseTeamID(formatted)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
if parsed != PlatformTeamID {
|
||||
t.Fatalf("platform team ID round-trip failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxUUID(t *testing.T) {
|
||||
max := [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
|
||||
encoded := UUIDToBase36(max)
|
||||
if len(encoded) != base36IDLen {
|
||||
t.Fatalf("max UUID encoding wrong length: %d", len(encoded))
|
||||
}
|
||||
decoded, err := base36ToUUID(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
if decoded != max {
|
||||
t.Fatalf("round-trip failed for max UUID")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFormatSandboxID(b *testing.B) {
|
||||
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
FormatSandboxID(id)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseSandboxID(b *testing.B) {
|
||||
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
|
||||
s := FormatSandboxID(id)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ParseSandboxID(s)
|
||||
}
|
||||
}
|
||||
58
internal/layout/layout.go
Normal file
58
internal/layout/layout.go
Normal file
@ -0,0 +1,58 @@
|
||||
package layout
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// IsMinimal reports whether the given team and template IDs represent the
|
||||
// built-in "minimal" template (both all-zeros).
|
||||
func IsMinimal(teamID, templateID pgtype.UUID) bool {
|
||||
return teamID.Bytes == id.PlatformTeamID.Bytes && templateID.Bytes == id.MinimalTemplateID.Bytes
|
||||
}
|
||||
|
||||
// TemplateDir returns the on-disk directory for a template.
|
||||
//
|
||||
// minimal (zeros, zeros): {wrennDir}/images/minimal
|
||||
// all others: {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)}
|
||||
func TemplateDir(wrennDir string, teamID, templateID pgtype.UUID) string {
|
||||
if IsMinimal(teamID, templateID) {
|
||||
return filepath.Join(wrennDir, "images", "minimal")
|
||||
}
|
||||
return filepath.Join(wrennDir, "images", "teams",
|
||||
id.UUIDToBase36(teamID.Bytes),
|
||||
id.UUIDToBase36(templateID.Bytes))
|
||||
}
|
||||
|
||||
// TemplateRootfs returns the path to a template's rootfs.ext4.
|
||||
func TemplateRootfs(wrennDir string, teamID, templateID pgtype.UUID) string {
|
||||
return filepath.Join(TemplateDir(wrennDir, teamID, templateID), "rootfs.ext4")
|
||||
}
|
||||
|
||||
// PauseSnapshotDir returns the directory for a paused sandbox's snapshot files.
|
||||
func PauseSnapshotDir(wrennDir, sandboxID string) string {
|
||||
return filepath.Join(wrennDir, "snapshots", sandboxID)
|
||||
}
|
||||
|
||||
// SandboxesDir returns the directory for running sandbox CoW files.
|
||||
func SandboxesDir(wrennDir string) string {
|
||||
return filepath.Join(wrennDir, "sandboxes")
|
||||
}
|
||||
|
||||
// KernelPath returns the path to the Firecracker kernel.
|
||||
func KernelPath(wrennDir string) string {
|
||||
return filepath.Join(wrennDir, "kernels", "vmlinux")
|
||||
}
|
||||
|
||||
// ImagesRoot returns the root images directory.
|
||||
func ImagesRoot(wrennDir string) string {
|
||||
return filepath.Join(wrennDir, "images")
|
||||
}
|
||||
|
||||
// TeamsDir returns the directory containing all team template subdirectories.
|
||||
func TeamsDir(wrennDir string) string {
|
||||
return filepath.Join(wrennDir, "images", "teams")
|
||||
}
|
||||
120
internal/layout/layout_test.go
Normal file
120
internal/layout/layout_test.go
Normal file
@ -0,0 +1,120 @@
|
||||
package layout
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// TestIsMinimal verifies that IsMinimal reports true only for the reserved
// platform-team / minimal-template UUID pair, and false as soon as either
// half differs.
func TestIsMinimal(t *testing.T) {
	tests := []struct {
		name       string
		teamID     pgtype.UUID
		templateID pgtype.UUID
		want       bool
	}{
		{
			name:       "both zeros",
			teamID:     id.PlatformTeamID,
			templateID: id.MinimalTemplateID,
			want:       true,
		},
		{
			name:       "non-zero team",
			teamID:     pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
			templateID: id.MinimalTemplateID,
			want:       false,
		},
		{
			name:       "non-zero template",
			teamID:     id.PlatformTeamID,
			templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
			want:       false,
		},
		{
			name:       "both non-zero",
			teamID:     pgtype.UUID{Bytes: [16]byte{1}, Valid: true},
			templateID: pgtype.UUID{Bytes: [16]byte{2}, Valid: true},
			want:       false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := IsMinimal(tt.teamID, tt.templateID); got != tt.want {
				t.Errorf("IsMinimal() = %v, want %v", got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestTemplateDir covers the three layout cases: the reserved minimal
// template maps to images/minimal, while team-owned and platform-owned
// (global) templates both map to images/teams/<team>/<template>.
func TestTemplateDir(t *testing.T) {
	wrennDir := "/var/lib/wrenn"

	t.Run("minimal", func(t *testing.T) {
		got := TemplateDir(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
		want := filepath.Join(wrennDir, "images", "minimal")
		if got != want {
			t.Errorf("TemplateDir() = %q, want %q", got, want)
		}
	})

	t.Run("team template", func(t *testing.T) {
		teamID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true}
		tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}, Valid: true}
		got := TemplateDir(wrennDir, teamID, tmplID)
		want := filepath.Join(wrennDir, "images", "teams",
			id.UUIDToBase36(teamID.Bytes),
			id.UUIDToBase36(tmplID.Bytes))
		if got != want {
			t.Errorf("TemplateDir() = %q, want %q", got, want)
		}
	})

	t.Run("global template (platform team, non-zero template)", func(t *testing.T) {
		tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5}, Valid: true}
		got := TemplateDir(wrennDir, id.PlatformTeamID, tmplID)
		want := filepath.Join(wrennDir, "images", "teams",
			id.UUIDToBase36(id.PlatformTeamID.Bytes),
			id.UUIDToBase36(tmplID.Bytes))
		if got != want {
			t.Errorf("TemplateDir() = %q, want %q", got, want)
		}
	})
}
|
||||
|
||||
// TestTemplateRootfs checks that the minimal template's rootfs resolves to
// rootfs.ext4 inside the minimal template directory.
func TestTemplateRootfs(t *testing.T) {
	wrennDir := "/var/lib/wrenn"
	got := TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
	want := filepath.Join(wrennDir, "images", "minimal", "rootfs.ext4")
	if got != want {
		t.Errorf("TemplateRootfs() = %q, want %q", got, want)
	}
}
|
||||
|
||||
// TestPauseSnapshotDir checks the snapshots/<sandboxID> layout.
func TestPauseSnapshotDir(t *testing.T) {
	got := PauseSnapshotDir("/var/lib/wrenn", "cl-abc123")
	want := "/var/lib/wrenn/snapshots/cl-abc123"
	if got != want {
		t.Errorf("PauseSnapshotDir() = %q, want %q", got, want)
	}
}
|
||||
|
||||
// TestSandboxesDir checks the running-sandbox CoW directory layout.
func TestSandboxesDir(t *testing.T) {
	got := SandboxesDir("/var/lib/wrenn")
	want := "/var/lib/wrenn/sandboxes"
	if got != want {
		t.Errorf("SandboxesDir() = %q, want %q", got, want)
	}
}
|
||||
|
||||
// TestKernelPath checks the Firecracker kernel location layout.
func TestKernelPath(t *testing.T) {
	got := KernelPath("/var/lib/wrenn")
	want := "/var/lib/wrenn/kernels/vmlinux"
	if got != want {
		t.Errorf("KernelPath() = %q, want %q", got, want)
	}
}
|
||||
125
internal/lifecycle/hostpool.go
Normal file
125
internal/lifecycle/hostpool.go
Normal file
@ -0,0 +1,125 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// HostClientPool maintains a cache of Connect RPC clients keyed by host ID.
// Clients are created lazily on first access and evicted when a host is removed
// or goes unreachable. The pool is safe for concurrent use.
type HostClientPool struct {
	mu         sync.RWMutex                                          // guards clients
	clients    map[string]hostagentv1connect.HostAgentServiceClient // cache keyed by formatted host ID
	httpClient *http.Client                                          // shared transport for every client this pool creates
	scheme     string                                                // "http://" or "https://"
}
|
||||
|
||||
// NewHostClientPool creates a pool that connects to agents over plain HTTP.
|
||||
// Use NewHostClientPoolTLS when mTLS is required.
|
||||
func NewHostClientPool() *HostClientPool {
|
||||
return &HostClientPool{
|
||||
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
|
||||
httpClient: &http.Client{Timeout: 10 * time.Minute},
|
||||
scheme: "http://",
|
||||
}
|
||||
}
|
||||
|
||||
// NewHostClientPoolTLS creates a pool that connects to agents over mTLS.
|
||||
// tlsCfg should already carry the CP client cert and CA trust anchor
|
||||
// (use auth.CPClientTLSConfig to construct it).
|
||||
func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: tlsCfg,
|
||||
}
|
||||
return &HostClientPool{
|
||||
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
Transport: transport,
|
||||
},
|
||||
scheme: "https://",
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a Connect RPC client for the given host, creating one if necessary.
|
||||
// address is the host agent address (ip:port or full URL). The scheme is added if absent.
|
||||
func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgentServiceClient {
|
||||
p.mu.RLock()
|
||||
c, ok := p.clients[hostID]
|
||||
p.mu.RUnlock()
|
||||
if ok {
|
||||
return c
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
// Double-check after acquiring write lock.
|
||||
if c, ok = p.clients[hostID]; ok {
|
||||
return c
|
||||
}
|
||||
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address))
|
||||
p.clients[hostID] = c
|
||||
return c
|
||||
}
|
||||
|
||||
// GetForHost is a convenience wrapper that extracts the address from a db.Host
|
||||
// and returns an error if the host has no address recorded yet.
|
||||
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
|
||||
if h.Address == "" {
|
||||
return nil, fmt.Errorf("host %s has no address", id.FormatHostID(h.ID))
|
||||
}
|
||||
return p.Get(id.FormatHostID(h.ID), h.Address), nil
|
||||
}
|
||||
|
||||
// Evict removes the cached client for the given host, forcing a new client to be
|
||||
// created on the next call to Get. Call this when a host's address changes or when
|
||||
// a host is deleted.
|
||||
func (p *HostClientPool) Evict(hostID string) {
|
||||
p.mu.Lock()
|
||||
delete(p.clients, hostID)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// ensureScheme prepends the pool's configured scheme if the address has none.
|
||||
func (p *HostClientPool) ensureScheme(addr string) string {
|
||||
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
|
||||
return addr
|
||||
}
|
||||
return p.scheme + addr
|
||||
}
|
||||
|
||||
// Transport returns the http.RoundTripper used by this pool. Use this when you
|
||||
// need to make raw HTTP requests to agent addresses with the same TLS settings
|
||||
// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy).
|
||||
func (p *HostClientPool) Transport() http.RoundTripper {
|
||||
if p.httpClient.Transport != nil {
|
||||
return p.httpClient.Transport
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
// ResolveAddr prepends the pool's configured scheme to addr if it has none.
// Use this when constructing URLs that must use the same transport as the pool
// (e.g., the sandbox proxy handler). Calling Get/GetForHost internally does
// the same thing, but ResolveAddr exposes it for callers that only need the URL.
func (p *HostClientPool) ResolveAddr(addr string) string {
	return p.ensureScheme(addr)
}
|
||||
|
||||
// EnsureScheme adds "http://" when addr carries no scheme.
//
// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting.
func EnsureScheme(addr string) string {
	switch {
	case strings.HasPrefix(addr, "http://"), strings.HasPrefix(addr, "https://"):
		return addr
	default:
		return "http://" + addr
	}
}
|
||||
@ -18,15 +18,16 @@ const (
|
||||
|
||||
// Sandbox holds all state for a running sandbox on this host.
|
||||
type Sandbox struct {
|
||||
ID string
|
||||
Status SandboxStatus
|
||||
Template string
|
||||
VCPUs int
|
||||
MemoryMB int
|
||||
TimeoutSec int
|
||||
SlotIndex int
|
||||
HostIP net.IP
|
||||
RootfsPath string
|
||||
CreatedAt time.Time
|
||||
LastActiveAt time.Time
|
||||
ID string
|
||||
Status SandboxStatus
|
||||
TemplateTeamID [16]byte
|
||||
TemplateID [16]byte
|
||||
VCPUs int
|
||||
MemoryMB int
|
||||
TimeoutSec int
|
||||
SlotIndex int
|
||||
HostIP net.IP
|
||||
RootfsPath string
|
||||
CreatedAt time.Time
|
||||
LastActiveAt time.Time
|
||||
}
|
||||
|
||||
@ -5,13 +5,91 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"github.com/vishvananda/netns"
|
||||
)
|
||||
|
||||
// nsPrefix is the name prefix of network namespaces created by wrenn; only
// namespaces carrying it are ever touched by cleanup.
const nsPrefix = "wrenn-ns-"

// CleanupStaleNamespaces removes leftover wrenn network namespaces from a
// previous crash. Called once at agent startup. Best-effort: individual
// failures are logged and skipped.
func CleanupStaleNamespaces() {
	entries, err := os.ReadDir("/run/netns")
	if err != nil {
		return // no /run/netns or unreadable — nothing to clean
	}
	for _, e := range entries {
		name := e.Name()
		if !strings.HasPrefix(name, nsPrefix) {
			continue
		}
		// Also remove the associated veth from the host side. The name mirrors
		// the "wrenn-veth-<index>" convention used when slots are created.
		vethName := "wrenn-veth-" + strings.TrimPrefix(name, nsPrefix)
		if link, err := netlink.LinkByName(vethName); err == nil {
			// Best-effort: the link may already be gone.
			_ = netlink.LinkDel(link)
		}
		if err := netns.DeleteNamed(name); err != nil {
			slog.Warn("failed to remove stale namespace", "ns", name, "error", err)
		} else {
			slog.Info("removed stale namespace", "ns", name)
		}
	}

	// Clean up any stale wrenn iptables rules referencing old veth interfaces.
	cleanupStaleIptablesRules()
}
|
||||
|
||||
// cleanupStaleIptablesRules removes host iptables rules that reference
// wrenn-veth interfaces no longer present on the system, then drops stale
// IPv4 host routes that egress via wrenn-veth links. Best-effort throughout:
// errors are ignored or logged at debug level.
func cleanupStaleIptablesRules() {
	for _, table := range []string{"filter", "nat"} {
		cmd := exec.Command("iptables-save", "-t", table)
		out, err := cmd.Output()
		if err != nil {
			// iptables-save unavailable or failed; skip this table.
			continue
		}
		for _, line := range strings.Split(string(out), "\n") {
			if !strings.Contains(line, "wrenn-veth-") {
				continue
			}
			// Lines look like "-A FORWARD -i wrenn-veth-1 -o wlo1 -j ACCEPT"
			// Convert -A to -D to delete the rule.
			if !strings.HasPrefix(line, "-A ") {
				continue
			}
			delRule := "-D " + line[3:]
			args := strings.Fields(delRule)
			delCmd := exec.Command("iptables", append([]string{"-t", table}, args...)...)
			if err := delCmd.Run(); err != nil {
				// Expected when the rule was already removed; keep going.
				slog.Debug("failed to remove stale iptables rule", "rule", line, "error", err)
			}
		}
	}

	// Also remove stale host routes to 10.11.0.x via wrenn-veth interfaces.
	routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
	if err != nil {
		return
	}
	for _, r := range routes {
		if r.LinkIndex == 0 {
			// Route not bound to a specific interface; cannot match it.
			continue
		}
		link, err := netlink.LinkByIndex(r.LinkIndex)
		if err != nil {
			// Interface vanished since RouteList; nothing to compare against.
			continue
		}
		if strings.HasPrefix(link.Attrs().Name, "wrenn-veth-") {
			_ = netlink.RouteDel(&r)
		}
	}
}
|
||||
|
||||
const (
|
||||
// Fixed addresses inside each network namespace (safe because each
|
||||
// sandbox gets its own netns).
|
||||
@ -59,19 +137,20 @@ func NewSlot(index int) *Slot {
|
||||
|
||||
hostIP := make(net.IP, 4)
|
||||
copy(hostIP, hostBaseIP)
|
||||
hostIP[2] += byte(index / 256)
|
||||
hostIP[3] += byte(index % 256)
|
||||
hostIP[2] += byte(index >> 8)
|
||||
hostIP[3] += byte(index & 0xFF)
|
||||
|
||||
vethOffset := index * vrtAddressesPerSlot
|
||||
vethIP := make(net.IP, 4)
|
||||
copy(vethIP, vrtBaseIP)
|
||||
vethIP[2] += byte(vethOffset / 256)
|
||||
vethIP[3] += byte(vethOffset % 256)
|
||||
vethIP[2] += byte(vethOffset >> 8)
|
||||
vethIP[3] += byte(vethOffset & 0xFF)
|
||||
|
||||
vpeerOffset := vethOffset + 1
|
||||
vpeerIP := make(net.IP, 4)
|
||||
copy(vpeerIP, vrtBaseIP)
|
||||
vpeerIP[2] += byte((vethOffset + 1) / 256)
|
||||
vpeerIP[3] += byte((vethOffset + 1) % 256)
|
||||
vpeerIP[2] += byte(vpeerOffset >> 8)
|
||||
vpeerIP[3] += byte(vpeerOffset & 0xFF)
|
||||
|
||||
return &Slot{
|
||||
Index: index,
|
||||
@ -84,8 +163,8 @@ func NewSlot(index int) *Slot {
|
||||
GuestIP: guestIP,
|
||||
GuestNetMask: guestNetMask,
|
||||
TapName: tapName,
|
||||
NamespaceID: fmt.Sprintf("ns-%d", index),
|
||||
VethName: fmt.Sprintf("veth-%d", index),
|
||||
NamespaceID: fmt.Sprintf("wrenn-ns-%d", index),
|
||||
VethName: fmt.Sprintf("wrenn-veth-%d", index),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
104
internal/recipe/context.go
Normal file
104
internal/recipe/context.go
Normal file
@ -0,0 +1,104 @@
|
||||
package recipe
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExecContext carries the mutable build state that persists across recipe
// steps. It starts empty and is updated by ENV and WORKDIR steps.
type ExecContext struct {
	WorkDir string            // current working directory ("" = unset)
	EnvVars map[string]string // accumulated environment variables
}

// envRegex matches, in order of alternation:
//  1. $$ (escaped dollar)
//  2. ${VAR} or ${} (braced variable, possibly empty)
//  3. $VAR (bare variable)
var envRegex = regexp.MustCompile(`\$\$|\$\{([a-zA-Z0-9_]*)\}|\$([a-zA-Z0-9_]+)`)

// WrappedCommand returns the full shell command for a RUN step with the
// context applied. The result is passed as the argument to /bin/sh -c.
//
// When WORKDIR and/or ENV are set, they are prepended as a shell preamble:
//
//	cd '/the/dir' && KEY='val' /bin/sh -c 'original command'
func (c *ExecContext) WrappedCommand(cmd string) string {
	if preamble := c.shellPrefix(); preamble != "" {
		return preamble + "/bin/sh -c " + shellescape(cmd)
	}
	return cmd
}

// StartCommand returns the shell command for a START step. The process is
// launched in the background via nohup so the outer shell exits immediately
// and the build can continue; stdout/stderr of the background process are
// discarded (the process keeps running in the VM).
//
// Multiple START steps can be issued to run several background processes
// simultaneously before a healthcheck is evaluated.
func (c *ExecContext) StartCommand(cmd string) string {
	return c.shellPrefix() + "nohup /bin/sh -c " + shellescape(cmd) + " >/dev/null 2>&1 &"
}

// shellPrefix builds the "cd ... && KEY=val " preamble for a shell command,
// or the empty string when no context is set. Env keys are emitted in sorted
// order for determinism.
func (c *ExecContext) shellPrefix() string {
	if c.WorkDir == "" && len(c.EnvVars) == 0 {
		return ""
	}
	var pieces []string
	if c.WorkDir != "" {
		pieces = append(pieces, "cd "+shellescape(c.WorkDir)+" &&")
	}
	names := make([]string, 0, len(c.EnvVars))
	for name := range c.EnvVars {
		names = append(names, name)
	}
	slices.Sort(names)
	for _, name := range names {
		pieces = append(pieces, name+"="+shellescape(c.EnvVars[name]))
	}
	// The trailing space separates the preamble from the command that follows.
	return strings.Join(pieces, " ") + " "
}

// expandEnv replaces $var and ${var} placeholders in s with values from vars.
// $$ escapes to a single $, and unknown variables expand to the empty string.
func expandEnv(s string, vars map[string]string) string {
	return envRegex.ReplaceAllStringFunc(s, func(m string) string {
		switch {
		case m == "$$":
			return "$"
		case m[1] == '{':
			// ${NAME}: strip the "${" prefix and "}" suffix. Missing keys
			// (and a nil map) yield "" via the zero-value lookup.
			return vars[m[2:len(m)-1]]
		default:
			// $NAME: strip the leading "$".
			return vars[m[1:]]
		}
	})
}

// shellescape wraps s in single quotes, escaping embedded single quotes with
// the POSIX-safe '\'' sequence. Safe for paths, env values, and commands.
func shellescape(s string) string {
	const q = "'"
	return q + strings.ReplaceAll(s, q, `'\''`) + q
}
|
||||
237
internal/recipe/context_test.go
Normal file
237
internal/recipe/context_test.go
Normal file
@ -0,0 +1,237 @@
|
||||
package recipe
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExecContext_WrappedCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx ExecContext
|
||||
cmd string
|
||||
want string
|
||||
wantOneOf []string
|
||||
}{
|
||||
{
|
||||
name: "no context",
|
||||
ctx: ExecContext{},
|
||||
cmd: "apt install -y curl",
|
||||
want: "apt install -y curl",
|
||||
},
|
||||
{
|
||||
name: "workdir only",
|
||||
ctx: ExecContext{WorkDir: "/app"},
|
||||
cmd: "npm install",
|
||||
want: "cd '/app' && /bin/sh -c 'npm install'",
|
||||
},
|
||||
{
|
||||
name: "env only",
|
||||
ctx: ExecContext{EnvVars: map[string]string{"PORT": "8080"}},
|
||||
cmd: "node server.js",
|
||||
want: "PORT='8080' /bin/sh -c 'node server.js'",
|
||||
},
|
||||
{
|
||||
name: "workdir with space",
|
||||
ctx: ExecContext{WorkDir: "/my project"},
|
||||
cmd: "make build",
|
||||
want: "cd '/my project' && /bin/sh -c 'make build'",
|
||||
},
|
||||
{
|
||||
name: "command with single quotes",
|
||||
ctx: ExecContext{WorkDir: "/app"},
|
||||
cmd: "echo 'hello'",
|
||||
want: "cd '/app' && /bin/sh -c 'echo '\\''hello'\\'''",
|
||||
},
|
||||
{
|
||||
name: "env value with single quotes",
|
||||
ctx: ExecContext{EnvVars: map[string]string{"MSG": "it's fine"}},
|
||||
cmd: "echo $MSG",
|
||||
want: "MSG='it'\\''s fine' /bin/sh -c 'echo $MSG'",
|
||||
},
|
||||
{
|
||||
name: "env expansion with pre-expanded PATH",
|
||||
ctx: ExecContext{
|
||||
EnvVars: map[string]string{"PATH": "/usr/bin", "FOO": "/opt/venv/bin:/usr/bin"},
|
||||
},
|
||||
cmd: "make build",
|
||||
want: "FOO='/opt/venv/bin:/usr/bin' PATH='/usr/bin' /bin/sh -c 'make build'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.ctx.WrappedCommand(tc.cmd)
|
||||
if len(tc.wantOneOf) > 0 {
|
||||
matched := false
|
||||
for _, w := range tc.wantOneOf {
|
||||
if got == w {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
t.Errorf("WrappedCommand(%q)\n got %q\n want one of %q", tc.cmd, got, tc.wantOneOf)
|
||||
}
|
||||
} else if got != tc.want {
|
||||
t.Errorf("WrappedCommand(%q)\n got %q\n want %q", tc.cmd, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExecContext_StartCommand checks nohup-based background wrapping for
// START steps with empty, workdir-only, and env-only contexts.
func TestExecContext_StartCommand(t *testing.T) {
	tests := []struct {
		name string
		ctx  ExecContext
		cmd  string
		want string
	}{
		{
			name: "no context",
			ctx:  ExecContext{},
			cmd:  "python3 app.py",
			want: "nohup /bin/sh -c 'python3 app.py' >/dev/null 2>&1 &",
		},
		{
			name: "with workdir",
			ctx:  ExecContext{WorkDir: "/app"},
			cmd:  "python3 server.py",
			want: "cd '/app' && nohup /bin/sh -c 'python3 server.py' >/dev/null 2>&1 &",
		},
		{
			name: "with env",
			ctx:  ExecContext{EnvVars: map[string]string{"PORT": "9000"}},
			cmd:  "node index.js",
			want: "PORT='9000' nohup /bin/sh -c 'node index.js' >/dev/null 2>&1 &",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := tc.ctx.StartCommand(tc.cmd)
			if got != tc.want {
				t.Errorf("StartCommand(%q)\n got %q\n want %q", tc.cmd, got, tc.want)
			}
		})
	}
}
|
||||
|
||||
// TestExpandEnv covers $VAR and ${VAR} substitution, $$ escaping, unset
// variables expanding to "", and degenerate inputs ("$", "${", "${}").
func TestExpandEnv(t *testing.T) {
	tests := []struct {
		s    string
		vars map[string]string
		want string
	}{
		{
			s:    "hello",
			vars: nil,
			want: "hello",
		},
		{
			s:    "$PATH",
			vars: map[string]string{"PATH": "/usr/bin"},
			want: "/usr/bin",
		},
		{
			s:    "${PATH}",
			vars: map[string]string{"PATH": "/usr/bin"},
			want: "/usr/bin",
		},
		{
			s:    "/opt/venv/bin:$PATH",
			vars: map[string]string{"PATH": "/usr/bin"},
			want: "/opt/venv/bin:/usr/bin",
		},
		{
			s:    "${HOME}/code",
			vars: map[string]string{"HOME": "/root"},
			want: "/root/code",
		},
		{
			s:    "hello $USER",
			vars: map[string]string{"USER": "admin"},
			want: "hello admin",
		},
		{
			s:    "$UNSET",
			vars: map[string]string{"PATH": "/usr/bin"},
			want: "",
		},
		{
			s:    "${UNSET}",
			vars: map[string]string{"PATH": "/usr/bin"},
			want: "",
		},
		{
			s:    "$$",
			vars: map[string]string{"PATH": "/usr/bin"},
			want: "$",
		},
		{
			s:    "price is $$100",
			vars: nil,
			want: "price is $100",
		},
		{
			s:    "$FOO:$BAR",
			vars: map[string]string{"FOO": "a", "BAR": "b"},
			want: "a:b",
		},
		{
			s:    "${FOO}_${BAR}",
			vars: map[string]string{"FOO": "hello", "BAR": "world"},
			want: "hello_world",
		},
		{
			s:    "no vars here",
			vars: nil,
			want: "no vars here",
		},
		{
			s:    "$",
			vars: nil,
			want: "$",
		},
		{
			s:    "${",
			vars: nil,
			want: "${",
		},
		{
			s:    "${}",
			vars: nil,
			want: "",
		},
		{
			s:    "$VAR1$VAR2",
			vars: map[string]string{"VAR1": "a", "VAR2": "b"},
			want: "ab",
		},
	}

	for _, tc := range tests {
		t.Run(tc.s, func(t *testing.T) {
			got := expandEnv(tc.s, tc.vars)
			if got != tc.want {
				t.Errorf("expandEnv(%q, %v)\n got %q\n want %q", tc.s, tc.vars, got, tc.want)
			}
		})
	}
}
|
||||
|
||||
// TestShellescape checks POSIX single-quote wrapping, including embedded
// single quotes and the empty string.
func TestShellescape(t *testing.T) {
	tests := []struct {
		input string
		want  string
	}{
		{"simple", "'simple'"},
		{"/path/to/dir", "'/path/to/dir'"},
		{"it's fine", "'it'\\''s fine'"},
		{"", "''"},
		{"a'b'c", "'a'\\''b'\\''c'"},
	}
	for _, tc := range tests {
		got := shellescape(tc.input)
		if got != tc.want {
			t.Errorf("shellescape(%q) = %q, want %q", tc.input, got, tc.want)
		}
	}
}
|
||||
185
internal/recipe/executor.go
Normal file
185
internal/recipe/executor.go
Normal file
@ -0,0 +1,185 @@
|
||||
package recipe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// DefaultStepTimeout is the fallback timeout for RUN steps that carry no
// explicit --timeout flag.
//
// NOTE(review): Execute's own zero-value fallback is 10 minutes, not this
// constant — confirm which callers pass DefaultStepTimeout, since the two
// defaults differ.
const DefaultStepTimeout = 30 * time.Second

// BuildLogEntry is the per-step record stored in template_builds.logs (JSONB).
type BuildLogEntry struct {
	Step    int    `json:"step"`       // global step number across phases (see Execute's startStep)
	Phase   string `json:"phase"`      // e.g. "pre-build", "recipe", "post-build"
	Cmd     string `json:"cmd"`        // raw recipe instruction text
	Stdout  string `json:"stdout"`     // captured stdout (RUN steps only)
	Stderr  string `json:"stderr"`     // captured stderr or executor error text
	Exit    int32  `json:"exit"`       // process exit code; 0 on success
	Ok      bool   `json:"ok"`         // whether the step succeeded
	Elapsed int64  `json:"elapsed_ms"` // wall-clock duration in milliseconds
}
|
||||
|
||||
// ExecFunc is the agent.Exec call signature used by the executor. It matches
// the method on the hostagent Connect RPC client, so the generated client's
// Exec method can be passed in directly.
type ExecFunc func(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
|
||||
|
||||
// Execute runs steps sequentially against sandboxID using execFn.
//
//   - phase labels the log entries (e.g., "pre-build", "recipe", "post-build").
//   - startStep is the 1-based offset so entries are globally numbered across phases.
//   - defaultTimeout applies to RUN steps with no per-step --timeout; 0 → 10 minutes.
//   - bctx is mutated in place as ENV/WORKDIR steps execute, and carries forward
//     into subsequent phases when the caller passes the same pointer.
//
// Returns all log entries appended during this call, the next step counter
// value, and whether all steps succeeded. On false the last entry contains
// failure details; the caller is responsible for destroying the sandbox and
// recording the build error.
func Execute(
	ctx context.Context,
	phase string,
	steps []Step,
	sandboxID string,
	startStep int,
	defaultTimeout time.Duration,
	bctx *ExecContext,
	execFn ExecFunc,
) (entries []BuildLogEntry, nextStep int, ok bool) {
	if defaultTimeout <= 0 {
		defaultTimeout = 10 * time.Minute
	}

	step := startStep
	for _, st := range steps {
		step++
		slog.Info("executing build step", "phase", phase, "step", step, "instruction", st.Raw)

		switch st.Kind {
		case KindENV:
			// ENV values may reference previously set variables ($VAR/${VAR});
			// they are expanded once, at assignment time.
			if bctx.EnvVars == nil {
				bctx.EnvVars = make(map[string]string)
			}
			bctx.EnvVars[st.Key] = expandEnv(st.Value, bctx.EnvVars)
			entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true})

		case KindWORKDIR:
			bctx.WorkDir = st.Path
			entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true})

		case KindUSER, KindCOPY:
			// Unsupported instructions fail the build explicitly rather than
			// producing an image that silently ignored them.
			verb := strings.ToUpper(strings.Fields(st.Raw)[0])
			entries = append(entries, BuildLogEntry{
				Step:   step,
				Phase:  phase,
				Cmd:    st.Raw,
				Stderr: verb + " is not yet supported",
				Ok:     false,
			})
			return entries, step, false

		case KindSTART:
			entry, succeeded := execStart(ctx, st, sandboxID, phase, step, bctx, execFn)
			entries = append(entries, entry)
			if !succeeded {
				return entries, step, false
			}

		case KindRUN:
			timeout := defaultTimeout
			if st.Timeout > 0 {
				timeout = st.Timeout
			}
			entry, succeeded := execRun(ctx, st, sandboxID, phase, step, timeout, bctx, execFn)
			entries = append(entries, entry)
			if !succeeded {
				return entries, step, false
			}
		}
		// NOTE(review): step kinds not listed above are counted but produce no
		// log entry — confirm the parser cannot emit any other kind.
	}
	return entries, step, true
}
|
||||
|
||||
func execRun(
|
||||
ctx context.Context,
|
||||
st Step,
|
||||
sandboxID, phase string,
|
||||
step int,
|
||||
timeout time.Duration,
|
||||
bctx *ExecContext,
|
||||
execFn ExecFunc,
|
||||
) (BuildLogEntry, bool) {
|
||||
execCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxID,
|
||||
Cmd: "/bin/sh",
|
||||
Args: []string{"-c", bctx.WrappedCommand(st.Shell)},
|
||||
TimeoutSec: int32(timeout.Seconds()),
|
||||
}))
|
||||
|
||||
entry := BuildLogEntry{
|
||||
Step: step,
|
||||
Phase: phase,
|
||||
Cmd: st.Raw,
|
||||
Elapsed: time.Since(start).Milliseconds(),
|
||||
}
|
||||
if err != nil {
|
||||
entry.Stderr = fmt.Sprintf("exec error: %v", err)
|
||||
return entry, false
|
||||
}
|
||||
entry.Stdout = string(resp.Msg.Stdout)
|
||||
entry.Stderr = string(resp.Msg.Stderr)
|
||||
entry.Exit = resp.Msg.ExitCode
|
||||
entry.Ok = resp.Msg.ExitCode == 0
|
||||
return entry, entry.Ok
|
||||
}
|
||||
|
||||
// execStart launches a START step's process in the background inside the
// sandbox and reports whether the launching shell itself exited cleanly.
// The background process keeps running in the VM after this returns.
func execStart(
	ctx context.Context,
	st Step,
	sandboxID, phase string,
	step int,
	bctx *ExecContext,
	execFn ExecFunc,
) (BuildLogEntry, bool) {
	// START uses a short timeout: just long enough for the shell to fork and
	// return. The background process itself runs indefinitely inside the VM.
	execCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	start := time.Now()
	resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{
		SandboxId:  sandboxID,
		Cmd:        "/bin/sh",
		Args:       []string{"-c", bctx.StartCommand(st.Shell)},
		TimeoutSec: 10,
	}))

	entry := BuildLogEntry{
		Step:    step,
		Phase:   phase,
		Cmd:     st.Raw,
		Elapsed: time.Since(start).Milliseconds(),
	}
	if err != nil {
		entry.Stderr = fmt.Sprintf("start error: %v", err)
		return entry, false
	}
	entry.Exit = resp.Msg.ExitCode
	entry.Ok = resp.Msg.ExitCode == 0
	if !entry.Ok {
		// Surface the shell's stderr so the build log explains the failure.
		entry.Stderr = fmt.Sprintf("start failed with exit code %d: %s", resp.Msg.ExitCode, string(resp.Msg.Stderr))
	}
	return entry, entry.Ok
}
|
||||
94
internal/recipe/healthcheck.go
Normal file
94
internal/recipe/healthcheck.go
Normal file
@ -0,0 +1,94 @@
|
||||
package recipe
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthcheckConfig holds the parsed configuration for a build healthcheck.
// A healthcheck is a shell command that is executed repeatedly inside the
// sandbox until it succeeds or the retry/timeout budget is exhausted.
//
// Retries of 0 means unlimited retries (bounded only by the overall deadline).
type HealthcheckConfig struct {
	Cmd         string        // shell command to run inside the sandbox
	Interval    time.Duration // delay between attempts (default 3s)
	Timeout     time.Duration // per-attempt timeout (default 10s)
	StartPeriod time.Duration // grace period before attempts begin (default 0)
	Retries     int // 0 = unlimited
}
|
||||
|
||||
// ParseHealthcheck parses a healthcheck string with optional flag prefix into
|
||||
// a HealthcheckConfig. The syntax is:
|
||||
//
|
||||
// [--interval=<duration>] [--timeout=<duration>] [--start-period=<duration>]
|
||||
// [--retries=<n>] <command>
|
||||
//
|
||||
// Flags must use the form --flag=value. The first token that does not start
|
||||
// with "--" and everything after it is treated as the command. Defaults:
|
||||
// interval=3s, timeout=10s, start-period=0, retries=0 (unlimited)
|
||||
func ParseHealthcheck(s string) (HealthcheckConfig, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return HealthcheckConfig{}, fmt.Errorf("empty healthcheck")
|
||||
}
|
||||
|
||||
hc := HealthcheckConfig{
|
||||
Interval: 3 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
tokens := strings.Fields(s)
|
||||
cmdIndex := -1
|
||||
|
||||
for i, token := range tokens {
|
||||
if !strings.HasPrefix(token, "--") {
|
||||
cmdIndex = i
|
||||
break
|
||||
}
|
||||
|
||||
parts := strings.SplitN(token, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return HealthcheckConfig{}, fmt.Errorf("malformed flag (missing '='): %q", token)
|
||||
}
|
||||
|
||||
key, val := parts[0], parts[1]
|
||||
switch key {
|
||||
case "--interval":
|
||||
d, err := time.ParseDuration(val)
|
||||
if err != nil {
|
||||
return HealthcheckConfig{}, fmt.Errorf("parse interval: %w", err)
|
||||
}
|
||||
hc.Interval = d
|
||||
case "--timeout":
|
||||
d, err := time.ParseDuration(val)
|
||||
if err != nil {
|
||||
return HealthcheckConfig{}, fmt.Errorf("parse timeout: %w", err)
|
||||
}
|
||||
hc.Timeout = d
|
||||
case "--start-period":
|
||||
d, err := time.ParseDuration(val)
|
||||
if err != nil {
|
||||
return HealthcheckConfig{}, fmt.Errorf("parse start period: %w", err)
|
||||
}
|
||||
hc.StartPeriod = d
|
||||
case "--retries":
|
||||
r, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return HealthcheckConfig{}, fmt.Errorf("parse retries: %w", err)
|
||||
}
|
||||
hc.Retries = r
|
||||
default:
|
||||
return HealthcheckConfig{}, fmt.Errorf("unknown healthcheck flag: %q", token)
|
||||
}
|
||||
}
|
||||
|
||||
if cmdIndex == -1 {
|
||||
return HealthcheckConfig{}, fmt.Errorf("healthcheck has no command")
|
||||
}
|
||||
|
||||
hc.Cmd = strings.Join(tokens[cmdIndex:], " ")
|
||||
return hc, nil
|
||||
}
|
||||
126
internal/recipe/healthcheck_test.go
Normal file
126
internal/recipe/healthcheck_test.go
Normal file
@ -0,0 +1,126 @@
|
||||
package recipe
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestParseHealthcheck exercises ParseHealthcheck with a table of valid and
// invalid healthcheck strings: default timings, flag overrides, and the
// various malformed-input error paths.
func TestParseHealthcheck(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		want    HealthcheckConfig
		wantErr bool
	}{
		{
			name:  "plain command",
			input: "curl -f http://localhost:8080",
			want: HealthcheckConfig{
				Cmd:      "curl -f http://localhost:8080",
				Interval: 3 * time.Second,
				Timeout:  10 * time.Second,
			},
			wantErr: false,
		},
		{
			name:  "all flags",
			input: "--interval=5s --timeout=2s --start-period=15s --retries=3 ping -c 1 8.8.8.8",
			want: HealthcheckConfig{
				Cmd:         "ping -c 1 8.8.8.8",
				Interval:    5 * time.Second,
				Timeout:     2 * time.Second,
				StartPeriod: 15 * time.Second,
				Retries:     3,
			},
			wantErr: false,
		},
		{
			name:  "partial flags",
			input: "--timeout=5s my-custom-check --verbose",
			want: HealthcheckConfig{
				Cmd:      "my-custom-check --verbose",
				Interval: 3 * time.Second,
				Timeout:  5 * time.Second,
			},
			wantErr: false,
		},
		{
			name:  "retries only",
			input: "--retries=5 test.sh",
			want: HealthcheckConfig{
				Cmd:      "test.sh",
				Interval: 3 * time.Second,
				Timeout:  10 * time.Second,
				Retries:  5,
			},
			wantErr: false,
		},
		{
			name:    "empty string",
			input:   "",
			wantErr: true,
		},
		{
			name:    "whitespace only",
			input:   " \t \n ",
			wantErr: true,
		},
		{
			name:    "flags but no command",
			input:   "--interval=5s --retries=2",
			wantErr: true,
		},
		{
			name:    "unknown flag",
			input:   "--magic=true my-check",
			wantErr: true,
		},
		{
			name:    "invalid duration",
			input:   "--interval=5smiles check.sh",
			wantErr: true,
		},
		{
			name:    "invalid retries",
			input:   "--retries=five check.sh",
			wantErr: true,
		},
		{
			// A "--"-prefixed token after the command must be treated as
			// command text, not as a flag.
			name:  "command with dashes",
			input: "--interval=2s command-with-dash --flag=value",
			want: HealthcheckConfig{
				Cmd:      "command-with-dash --flag=value",
				Interval: 2 * time.Second,
				Timeout:  10 * time.Second,
			},
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseHealthcheck(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("ParseHealthcheck() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			// Per-field comparison gives a precise message for each mismatch.
			if !tt.wantErr {
				if got.Cmd != tt.want.Cmd {
					t.Errorf("Cmd got = %v, want %v", got.Cmd, tt.want.Cmd)
				}
				if got.Interval != tt.want.Interval {
					t.Errorf("Interval got = %v, want %v", got.Interval, tt.want.Interval)
				}
				if got.Timeout != tt.want.Timeout {
					t.Errorf("Timeout got = %v, want %v", got.Timeout, tt.want.Timeout)
				}
				if got.StartPeriod != tt.want.StartPeriod {
					t.Errorf("StartPeriod got = %v, want %v", got.StartPeriod, tt.want.StartPeriod)
				}
				if got.Retries != tt.want.Retries {
					t.Errorf("Retries got = %v, want %v", got.Retries, tt.want.Retries)
				}
			}
		})
	}
}
|
||||
129
internal/recipe/step.go
Normal file
129
internal/recipe/step.go
Normal file
@ -0,0 +1,129 @@
|
||||
package recipe
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Kind identifies the instruction type in a recipe line.
// The zero value is KindRUN.
type Kind int

const (
	KindRUN     Kind = iota // Execute a command and wait for it to exit.
	KindSTART               // Start a command in the background (non-blocking).
	KindENV                 // Set an environment variable for subsequent steps.
	KindWORKDIR             // Set the working directory for subsequent steps.
	KindUSER                // Switch the unix user for subsequent steps. (stub)
	KindCOPY                // Copy files into the sandbox. (stub)
)
|
||||
|
||||
// Step is the parsed representation of one recipe instruction.
// Only the fields relevant to a given Kind are populated; the rest are
// left at their zero values.
type Step struct {
	Kind Kind
	Raw  string // original string, preserved for logging

	Shell   string        // KindRUN, KindSTART: the shell command text
	Timeout time.Duration // KindRUN: 0 means use caller's default

	Key   string // KindENV: variable name
	Value string // KindENV: variable value

	Path string // KindWORKDIR: directory path
}
|
||||
|
||||
// ParseStep parses a single recipe instruction string into a Step.
|
||||
// Instructions are Dockerfile-like: a keyword followed by arguments.
|
||||
//
|
||||
// Supported syntax:
|
||||
//
|
||||
// RUN <cmd> — run command, wait for exit
|
||||
// RUN --timeout=<d> <cmd> — run command with explicit timeout (e.g. --timeout=5m)
|
||||
// START <cmd> — start command in background, return immediately
|
||||
// ENV <key>=<value> — set environment variable
|
||||
// WORKDIR <path> — set working directory
|
||||
// USER <name> — not yet supported
|
||||
// COPY <src> <dst> — not yet supported
|
||||
func ParseStep(s string) (Step, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return Step{}, fmt.Errorf("empty step")
|
||||
}
|
||||
|
||||
// Split on first space to get the keyword.
|
||||
keyword, rest, _ := strings.Cut(s, " ")
|
||||
rest = strings.TrimSpace(rest)
|
||||
|
||||
switch strings.ToUpper(keyword) {
|
||||
case "RUN":
|
||||
return parseRUN(s, rest)
|
||||
case "START":
|
||||
return parseSTART(s, rest)
|
||||
case "ENV":
|
||||
return parseENV(s, rest)
|
||||
case "WORKDIR":
|
||||
return parseWORKDIR(s, rest)
|
||||
case "USER":
|
||||
return Step{Kind: KindUSER, Raw: s}, nil
|
||||
case "COPY":
|
||||
return Step{Kind: KindCOPY, Raw: s}, nil
|
||||
default:
|
||||
return Step{}, fmt.Errorf("unknown instruction %q (expected RUN, START, ENV, WORKDIR, USER, or COPY)", keyword)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseRecipe parses all recipe lines, returning on the first error.
|
||||
func ParseRecipe(lines []string) ([]Step, error) {
|
||||
steps := make([]Step, 0, len(lines))
|
||||
for i, line := range lines {
|
||||
st, err := ParseStep(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("recipe line %d: %w", i+1, err)
|
||||
}
|
||||
steps = append(steps, st)
|
||||
}
|
||||
return steps, nil
|
||||
}
|
||||
|
||||
func parseRUN(raw, rest string) (Step, error) {
|
||||
var timeout time.Duration
|
||||
if strings.HasPrefix(rest, "--timeout=") {
|
||||
rest = rest[len("--timeout="):]
|
||||
flag, cmd, found := strings.Cut(rest, " ")
|
||||
if !found || strings.TrimSpace(cmd) == "" {
|
||||
return Step{}, fmt.Errorf("RUN --timeout= flag has no command: %q", raw)
|
||||
}
|
||||
d, err := time.ParseDuration(flag)
|
||||
if err != nil {
|
||||
return Step{}, fmt.Errorf("RUN --timeout= invalid duration %q: %w", flag, err)
|
||||
}
|
||||
timeout = d
|
||||
rest = strings.TrimSpace(cmd)
|
||||
}
|
||||
if rest == "" {
|
||||
return Step{}, fmt.Errorf("RUN requires a command: %q", raw)
|
||||
}
|
||||
return Step{Kind: KindRUN, Raw: raw, Shell: rest, Timeout: timeout}, nil
|
||||
}
|
||||
|
||||
func parseSTART(raw, rest string) (Step, error) {
|
||||
if rest == "" {
|
||||
return Step{}, fmt.Errorf("START requires a command: %q", raw)
|
||||
}
|
||||
return Step{Kind: KindSTART, Raw: raw, Shell: rest}, nil
|
||||
}
|
||||
|
||||
func parseENV(raw, rest string) (Step, error) {
|
||||
key, value, found := strings.Cut(rest, "=")
|
||||
if !found {
|
||||
return Step{}, fmt.Errorf("ENV requires KEY=VALUE format: %q", raw)
|
||||
}
|
||||
if key == "" {
|
||||
return Step{}, fmt.Errorf("ENV key is empty: %q", raw)
|
||||
}
|
||||
return Step{Kind: KindENV, Raw: raw, Key: key, Value: value}, nil
|
||||
}
|
||||
|
||||
func parseWORKDIR(raw, path string) (Step, error) {
|
||||
if path == "" {
|
||||
return Step{}, fmt.Errorf("WORKDIR requires a path: %q", raw)
|
||||
}
|
||||
return Step{Kind: KindWORKDIR, Raw: raw, Path: path}, nil
|
||||
}
|
||||
208
internal/recipe/step_test.go
Normal file
208
internal/recipe/step_test.go
Normal file
@ -0,0 +1,208 @@
|
||||
package recipe
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestParseStep drives ParseStep through every supported instruction kind,
// keyword case-insensitivity, the RUN --timeout flag, the USER/COPY stubs,
// and each argument-validation error path.
func TestParseStep(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		want    Step
		wantErr bool
	}{
		// RUN
		{
			name:  "RUN basic",
			input: "RUN apt install -y curl",
			want:  Step{Kind: KindRUN, Raw: "RUN apt install -y curl", Shell: "apt install -y curl"},
		},
		{
			name:  "RUN lowercase",
			input: "run echo hello",
			want:  Step{Kind: KindRUN, Raw: "run echo hello", Shell: "echo hello"},
		},
		{
			name:  "RUN with timeout",
			input: "RUN --timeout=5m npm install",
			want:  Step{Kind: KindRUN, Raw: "RUN --timeout=5m npm install", Shell: "npm install", Timeout: 5 * time.Minute},
		},
		{
			name:  "RUN with timeout seconds",
			input: "RUN --timeout=30s make build",
			want:  Step{Kind: KindRUN, Raw: "RUN --timeout=30s make build", Shell: "make build", Timeout: 30 * time.Second},
		},
		{
			name:    "RUN no command",
			input:   "RUN",
			wantErr: true,
		},
		{
			name:    "RUN timeout no command",
			input:   "RUN --timeout=5m",
			wantErr: true,
		},
		{
			name:    "RUN invalid timeout",
			input:   "RUN --timeout=notaduration echo hi",
			wantErr: true,
		},
		// START
		{
			name:  "START basic",
			input: "START python3 app.py",
			want:  Step{Kind: KindSTART, Raw: "START python3 app.py", Shell: "python3 app.py"},
		},
		{
			name:  "START uppercase",
			input: "START node server.js --port=8080",
			want:  Step{Kind: KindSTART, Raw: "START node server.js --port=8080", Shell: "node server.js --port=8080"},
		},
		{
			name:    "START no command",
			input:   "START",
			wantErr: true,
		},
		// ENV
		{
			name:  "ENV basic",
			input: "ENV FOO=bar",
			want:  Step{Kind: KindENV, Raw: "ENV FOO=bar", Key: "FOO", Value: "bar"},
		},
		{
			name:  "ENV value with spaces",
			input: "ENV GREETING=hello world",
			want:  Step{Kind: KindENV, Raw: "ENV GREETING=hello world", Key: "GREETING", Value: "hello world"},
		},
		{
			// Only the first '=' splits key from value.
			name:  "ENV value with equals sign",
			input: "ENV URL=http://example.com?a=1",
			want:  Step{Kind: KindENV, Raw: "ENV URL=http://example.com?a=1", Key: "URL", Value: "http://example.com?a=1"},
		},
		{
			name:  "ENV empty value",
			input: "ENV FOO=",
			want:  Step{Kind: KindENV, Raw: "ENV FOO=", Key: "FOO", Value: ""},
		},
		{
			name:    "ENV missing equals",
			input:   "ENV FOO",
			wantErr: true,
		},
		{
			name:    "ENV empty key",
			input:   "ENV =value",
			wantErr: true,
		},
		// WORKDIR
		{
			name:  "WORKDIR basic",
			input: "WORKDIR /app",
			want:  Step{Kind: KindWORKDIR, Raw: "WORKDIR /app", Path: "/app"},
		},
		{
			name:  "WORKDIR with spaces in path",
			input: "WORKDIR /my project",
			want:  Step{Kind: KindWORKDIR, Raw: "WORKDIR /my project", Path: "/my project"},
		},
		{
			name:    "WORKDIR empty",
			input:   "WORKDIR",
			wantErr: true,
		},
		// USER and COPY stubs
		{
			name:  "USER stub",
			input: "USER www-data",
			want:  Step{Kind: KindUSER, Raw: "USER www-data"},
		},
		{
			name:  "COPY stub",
			input: "COPY config.yaml /etc/app/config.yaml",
			want:  Step{Kind: KindCOPY, Raw: "COPY config.yaml /etc/app/config.yaml"},
		},
		// Unknown keyword
		{
			name:    "unknown keyword",
			input:   "FROBNICATE something",
			wantErr: true,
		},
		// Empty input
		{
			name:    "empty string",
			input:   "",
			wantErr: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := ParseStep(tc.input)
			if tc.wantErr {
				if err == nil {
					t.Fatalf("ParseStep(%q) expected error, got %+v", tc.input, got)
				}
				return
			}
			if err != nil {
				t.Fatalf("ParseStep(%q) unexpected error: %v", tc.input, err)
			}
			// Step is comparable, so a single struct equality covers all fields.
			if got != tc.want {
				t.Errorf("ParseStep(%q)\n got  %+v\n want %+v", tc.input, got, tc.want)
			}
		})
	}
}
|
||||
|
||||
// TestParseRecipe covers whole-recipe parsing: a valid multi-line recipe,
// early abort on the first invalid line, and the empty-recipe case.
func TestParseRecipe(t *testing.T) {
	t.Run("valid recipe", func(t *testing.T) {
		lines := []string{
			"RUN apt update",
			"WORKDIR /app",
			"ENV PORT=8080",
			"START python3 server.py",
			"RUN --timeout=2m pip install -r requirements.txt",
		}
		steps, err := ParseRecipe(lines)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if len(steps) != 5 {
			t.Fatalf("expected 5 steps, got %d", len(steps))
		}
		// Spot-check that kinds and the timeout flag survived parsing.
		if steps[0].Kind != KindRUN {
			t.Errorf("step 0: want KindRUN, got %v", steps[0].Kind)
		}
		if steps[1].Kind != KindWORKDIR {
			t.Errorf("step 1: want KindWORKDIR, got %v", steps[1].Kind)
		}
		if steps[3].Kind != KindSTART {
			t.Errorf("step 3: want KindSTART, got %v", steps[3].Kind)
		}
		if steps[4].Timeout != 2*time.Minute {
			t.Errorf("step 4: want 2m timeout, got %v", steps[4].Timeout)
		}
	})

	t.Run("error on invalid line", func(t *testing.T) {
		lines := []string{
			"RUN apt update",
			"BADCMD something",
		}
		_, err := ParseRecipe(lines)
		if err == nil {
			t.Fatal("expected error for invalid line, got nil")
		}
	})

	t.Run("empty recipe", func(t *testing.T) {
		steps, err := ParseRecipe(nil)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if len(steps) != 0 {
			t.Fatalf("expected 0 steps, got %d", len(steps))
		}
	})
}
|
||||
85
internal/sandbox/conntracker.go
Normal file
85
internal/sandbox/conntracker.go
Normal file
@ -0,0 +1,85 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnTracker tracks active proxy connections for a single sandbox and
|
||||
// provides a drain mechanism for pre-pause graceful shutdown.
|
||||
// It is safe for concurrent use.
|
||||
type ConnTracker struct {
|
||||
draining atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
|
||||
// cancelMu protects cancelDrain so Reset can signal a timed-out Drain
|
||||
// goroutine to exit, preventing goroutine leaks on repeated pause failures.
|
||||
cancelMu sync.Mutex
|
||||
cancelDrain chan struct{}
|
||||
}
|
||||
|
||||
// Acquire registers one in-flight connection. Returns false if the tracker
|
||||
// is already draining; the caller must not call Release in that case.
|
||||
func (t *ConnTracker) Acquire() bool {
|
||||
if t.draining.Load() {
|
||||
return false
|
||||
}
|
||||
t.wg.Add(1)
|
||||
// Re-check after Add: Drain may have set draining between our Load
|
||||
// and Add. If so, undo the Add and reject the connection.
|
||||
if t.draining.Load() {
|
||||
t.wg.Done()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Release marks one connection as complete. Must be called exactly once
|
||||
// per successful Acquire.
|
||||
func (t *ConnTracker) Release() {
|
||||
t.wg.Done()
|
||||
}
|
||||
|
||||
// Drain marks the tracker as draining (all future Acquire calls return
|
||||
// false) and waits up to timeout for in-flight connections to finish.
|
||||
func (t *ConnTracker) Drain(timeout time.Duration) {
|
||||
t.draining.Store(true)
|
||||
|
||||
cancel := make(chan struct{})
|
||||
t.cancelMu.Lock()
|
||||
t.cancelDrain = cancel
|
||||
t.cancelMu.Unlock()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
t.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-cancel:
|
||||
// Reset was called; stop waiting.
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
|
||||
// Reset re-enables the tracker after a failed drain. This allows the
|
||||
// sandbox to accept proxy connections again if the pause operation fails
|
||||
// and the VM is resumed. It also cancels any lingering Drain goroutine.
|
||||
func (t *ConnTracker) Reset() {
|
||||
t.cancelMu.Lock()
|
||||
if t.cancelDrain != nil {
|
||||
select {
|
||||
case <-t.cancelDrain:
|
||||
// Already closed.
|
||||
default:
|
||||
close(t.cancelDrain)
|
||||
}
|
||||
t.cancelDrain = nil
|
||||
}
|
||||
t.cancelMu.Unlock()
|
||||
|
||||
t.draining.Store(false)
|
||||
}
|
||||
106
internal/sandbox/images.go
Normal file
106
internal/sandbox/images.go
Normal file
@ -0,0 +1,106 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/layout"
|
||||
)
|
||||
|
||||
// DefaultDiskSizeMB is the standard disk size for base images. Images smaller
|
||||
// than this are expanded at startup so that dm-snapshot sandboxes see the full
|
||||
// size without per-sandbox copies. The expansion is sparse — only metadata
|
||||
// changes; no physical disk is consumed beyond the original content.
|
||||
const DefaultDiskSizeMB = 5120 // 5 GB
|
||||
|
||||
// EnsureImageSizes walks template directories and expands any rootfs.ext4 that
|
||||
// is smaller than the target size. This is idempotent: images already at or
|
||||
// above the target size are left untouched. Should be called once at host agent
|
||||
// startup before any sandboxes are created.
|
||||
func EnsureImageSizes(wrennDir string, targetMB int) error {
|
||||
if targetMB <= 0 {
|
||||
targetMB = DefaultDiskSizeMB
|
||||
}
|
||||
targetBytes := int64(targetMB) * 1024 * 1024
|
||||
|
||||
// Expand the built-in minimal image.
|
||||
minimalRootfs := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
|
||||
if err := expandImage(minimalRootfs, targetBytes, targetMB); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Walk teams/{teamDir}/{templateDir}/rootfs.ext4 two levels deep.
|
||||
teamsDir := layout.TeamsDir(wrennDir)
|
||||
teamEntries, err := os.ReadDir(teamsDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // teams dir doesn't exist yet — nothing to expand
|
||||
}
|
||||
return fmt.Errorf("read teams dir: %w", err)
|
||||
}
|
||||
|
||||
for _, teamEntry := range teamEntries {
|
||||
if !teamEntry.IsDir() {
|
||||
continue
|
||||
}
|
||||
teamPath := filepath.Join(teamsDir, teamEntry.Name())
|
||||
templateEntries, err := os.ReadDir(teamPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, tmplEntry := range templateEntries {
|
||||
if !tmplEntry.IsDir() {
|
||||
continue
|
||||
}
|
||||
rootfs := filepath.Join(teamPath, tmplEntry.Name(), "rootfs.ext4")
|
||||
if err := expandImage(rootfs, targetBytes, targetMB); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandImage expands a single rootfs image if it is smaller than targetBytes.
// A missing file is not an error (not every template dir has a rootfs.ext4).
func expandImage(rootfs string, targetBytes int64, targetMB int) error {
	info, err := os.Stat(rootfs)
	if err != nil {
		return nil // not every template dir has a rootfs.ext4
	}

	if info.Size() >= targetBytes {
		return nil // already large enough
	}

	slog.Info("expanding base image",
		"path", rootfs,
		"from_mb", info.Size()/(1024*1024),
		"to_mb", targetMB,
	)

	// Expand the file (sparse — instant, no physical disk used).
	if err := os.Truncate(rootfs, targetBytes); err != nil {
		return fmt.Errorf("truncate %s: %w", rootfs, err)
	}

	// Check filesystem before resize. e2fsck exits 0 (clean) or 1 (errors
	// corrected), both of which are fine. Anything else is fatal — including
	// errors that are not *exec.ExitError (e.g. e2fsck missing from PATH),
	// which the previous version silently ignored and then failed later in
	// resize2fs with a confusing message.
	if out, err := exec.Command("e2fsck", "-fy", rootfs).CombinedOutput(); err != nil {
		exitErr, ok := err.(*exec.ExitError)
		if !ok || exitErr.ExitCode() > 1 {
			return fmt.Errorf("e2fsck %s: %s: %w", rootfs, string(out), err)
		}
	}

	// Grow the ext4 filesystem to fill the new file size.
	if out, err := exec.Command("resize2fs", rootfs).CombinedOutput(); err != nil {
		return fmt.Errorf("resize2fs %s: %s: %w", rootfs, string(out), err)
	}

	slog.Info("base image expanded", "path", rootfs, "size_mb", targetMB)
	return nil
}
|
||||
File diff suppressed because it is too large
Load Diff
178
internal/sandbox/metrics.go
Normal file
178
internal/sandbox/metrics.go
Normal file
@ -0,0 +1,178 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MetricPoint holds one metrics sample.
|
||||
type MetricPoint struct {
|
||||
Timestamp time.Time
|
||||
CPUPct float64
|
||||
MemBytes int64
|
||||
DiskBytes int64
|
||||
}
|
||||
|
||||
// Ring buffer capacity constants.
|
||||
const (
|
||||
ring10mCap = 1200 // 500ms × 1200 = 10 min
|
||||
ring2hCap = 240 // 30s × 240 = 2 h
|
||||
ring24hCap = 288 // 5min × 288 = 24 h
|
||||
|
||||
downsample2hEvery = 60 // 60 × 500ms = 30s
|
||||
downsample24hEvery = 10 // 10 × 30s = 5min
|
||||
)
|
||||
|
||||
// metricsRing holds three tiered ring buffers with automatic downsampling
|
||||
// from the finest tier into coarser tiers.
|
||||
type metricsRing struct {
|
||||
mu sync.Mutex
|
||||
|
||||
// 10-minute tier: 500ms samples.
|
||||
buf10m [ring10mCap]MetricPoint
|
||||
idx10m int
|
||||
count10m int
|
||||
|
||||
// 2-hour tier: 30s averages.
|
||||
buf2h [ring2hCap]MetricPoint
|
||||
idx2h int
|
||||
count2h int
|
||||
|
||||
// 24-hour tier: 5min averages.
|
||||
buf24h [ring24hCap]MetricPoint
|
||||
idx24h int
|
||||
count24h int
|
||||
|
||||
// Accumulators for downsampling.
|
||||
acc500ms [downsample2hEvery]MetricPoint
|
||||
acc500msN int
|
||||
|
||||
acc30s [downsample24hEvery]MetricPoint
|
||||
acc30sN int
|
||||
}
|
||||
|
||||
// newMetricsRing creates an empty metrics ring buffer.
|
||||
func newMetricsRing() *metricsRing {
|
||||
return &metricsRing{}
|
||||
}
|
||||
|
||||
// Push adds a 500ms sample to the finest tier and triggers downsampling
|
||||
// into coarser tiers when enough samples have accumulated.
|
||||
func (r *metricsRing) Push(p MetricPoint) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Write to 10m ring.
|
||||
r.buf10m[r.idx10m] = p
|
||||
r.idx10m = (r.idx10m + 1) % ring10mCap
|
||||
if r.count10m < ring10mCap {
|
||||
r.count10m++
|
||||
}
|
||||
|
||||
// Accumulate for 2h downsample.
|
||||
r.acc500ms[r.acc500msN] = p
|
||||
r.acc500msN++
|
||||
if r.acc500msN == downsample2hEvery {
|
||||
avg := averagePoints(r.acc500ms[:downsample2hEvery])
|
||||
r.push2h(avg)
|
||||
r.acc500msN = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (r *metricsRing) push2h(p MetricPoint) {
|
||||
r.buf2h[r.idx2h] = p
|
||||
r.idx2h = (r.idx2h + 1) % ring2hCap
|
||||
if r.count2h < ring2hCap {
|
||||
r.count2h++
|
||||
}
|
||||
|
||||
// Accumulate for 24h downsample.
|
||||
r.acc30s[r.acc30sN] = p
|
||||
r.acc30sN++
|
||||
if r.acc30sN == downsample24hEvery {
|
||||
avg := averagePoints(r.acc30s[:downsample24hEvery])
|
||||
r.push24h(avg)
|
||||
r.acc30sN = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (r *metricsRing) push24h(p MetricPoint) {
|
||||
r.buf24h[r.idx24h] = p
|
||||
r.idx24h = (r.idx24h + 1) % ring24hCap
|
||||
if r.count24h < ring24hCap {
|
||||
r.count24h++
|
||||
}
|
||||
}
|
||||
|
||||
// Get10m returns the 10-minute tier points in chronological order.
|
||||
func (r *metricsRing) Get10m() []MetricPoint {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.readRing(r.buf10m[:], r.idx10m, r.count10m)
|
||||
}
|
||||
|
||||
// Get2h returns the 2-hour tier points in chronological order.
|
||||
func (r *metricsRing) Get2h() []MetricPoint {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.readRing(r.buf2h[:], r.idx2h, r.count2h)
|
||||
}
|
||||
|
||||
// Get24h returns the 24-hour tier points in chronological order.
|
||||
func (r *metricsRing) Get24h() []MetricPoint {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.readRing(r.buf24h[:], r.idx24h, r.count24h)
|
||||
}
|
||||
|
||||
// Flush returns all three tiers and resets the ring buffer.
|
||||
func (r *metricsRing) Flush() (pts10m, pts2h, pts24h []MetricPoint) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
pts10m = r.readRing(r.buf10m[:], r.idx10m, r.count10m)
|
||||
pts2h = r.readRing(r.buf2h[:], r.idx2h, r.count2h)
|
||||
pts24h = r.readRing(r.buf24h[:], r.idx24h, r.count24h)
|
||||
|
||||
// Reset all state.
|
||||
r.idx10m, r.count10m = 0, 0
|
||||
r.idx2h, r.count2h = 0, 0
|
||||
r.idx24h, r.count24h = 0, 0
|
||||
r.acc500msN = 0
|
||||
r.acc30sN = 0
|
||||
|
||||
return pts10m, pts2h, pts24h
|
||||
}
|
||||
|
||||
// readRing extracts elements from a circular buffer in chronological order.
|
||||
func (r *metricsRing) readRing(buf []MetricPoint, nextIdx, count int) []MetricPoint {
|
||||
if count == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]MetricPoint, count)
|
||||
bufLen := len(buf)
|
||||
start := (nextIdx - count + bufLen) % bufLen
|
||||
for i := range count {
|
||||
result[i] = buf[(start+i)%bufLen]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// averagePoints computes the average of a slice of MetricPoints.
|
||||
// The timestamp is set to the last point's timestamp.
|
||||
func averagePoints(pts []MetricPoint) MetricPoint {
|
||||
n := float64(len(pts))
|
||||
var cpu float64
|
||||
var mem, disk int64
|
||||
for _, p := range pts {
|
||||
cpu += p.CPUPct
|
||||
mem += p.MemBytes
|
||||
disk += p.DiskBytes
|
||||
}
|
||||
return MetricPoint{
|
||||
Timestamp: pts[len(pts)-1].Timestamp,
|
||||
CPUPct: cpu / n,
|
||||
MemBytes: int64(float64(mem) / n),
|
||||
DiskBytes: int64(float64(disk) / n),
|
||||
}
|
||||
}
|
||||
83
internal/sandbox/proc.go
Normal file
83
internal/sandbox/proc.go
Normal file
@ -0,0 +1,83 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// cpuStat holds raw CPU jiffies read from /proc/{pid}/stat.
type cpuStat struct {
	utime uint64 // clock ticks spent in user mode
	stime uint64 // clock ticks spent in kernel mode
}

// readCPUStat reads user and system CPU jiffies from /proc/{pid}/stat.
// Fields 14 (utime) and 15 (stime) are 1-indexed in the man page;
// after splitting on space, they are at indices 13 and 14.
func readCPUStat(pid int) (cpuStat, error) {
	path := fmt.Sprintf("/proc/%d/stat", pid)
	data, err := os.ReadFile(path)
	if err != nil {
		return cpuStat{}, fmt.Errorf("read stat: %w", err)
	}

	// /proc/{pid}/stat format: pid (comm) state fields...
	// The comm field may contain spaces and parens, so find the last ')' first.
	content := string(data)
	idx := strings.LastIndex(content, ")")
	if idx < 0 {
		return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: no closing paren", pid)
	}
	// Slice from idx+1 (always in range) rather than idx+2: content that
	// ends exactly at ')' previously caused an index-out-of-range panic.
	// strings.Fields discards the leading space, so field indices are
	// unchanged: state is fields[0], utime fields[11], stime fields[12].
	fields := strings.Fields(content[idx+1:])
	if len(fields) < 13 {
		return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: too few fields (%d)", pid, len(fields))
	}
	utime, err := strconv.ParseUint(fields[11], 10, 64)
	if err != nil {
		return cpuStat{}, fmt.Errorf("parse utime: %w", err)
	}
	stime, err := strconv.ParseUint(fields[12], 10, 64)
	if err != nil {
		return cpuStat{}, fmt.Errorf("parse stime: %w", err)
	}
	return cpuStat{utime: utime, stime: stime}, nil
}
|
||||
|
||||
// readMemRSS reads VmRSS from /proc/{pid}/status and returns bytes.
// /proc reports VmRSS in kB, so the parsed value is scaled by 1024.
func readMemRSS(pid int) (int64, error) {
	statusPath := fmt.Sprintf("/proc/%d/status", pid)
	data, err := os.ReadFile(statusPath)
	if err != nil {
		return 0, fmt.Errorf("read status: %w", err)
	}
	for _, ln := range strings.Split(string(data), "\n") {
		if !strings.HasPrefix(ln, "VmRSS:") {
			continue
		}
		parts := strings.Fields(ln)
		if len(parts) < 2 {
			return 0, fmt.Errorf("malformed VmRSS line")
		}
		kb, err := strconv.ParseInt(parts[1], 10, 64)
		if err != nil {
			return 0, fmt.Errorf("parse VmRSS: %w", err)
		}
		return kb * 1024, nil
	}
	return 0, fmt.Errorf("VmRSS not found in /proc/%d/status", pid)
}
|
||||
|
||||
// readDiskAllocated returns the number of bytes actually allocated on disk
// for the file at path (stat's 512-byte block count), as opposed to the
// apparent size — relevant for sparse files.
func readDiskAllocated(path string) (int64, error) {
	var st syscall.Stat_t
	if err := syscall.Stat(path, &st); err != nil {
		return 0, fmt.Errorf("stat %s: %w", path, err)
	}
	return st.Blocks * 512, nil
}
|
||||
71
internal/scheduler/round_robin.go
Normal file
71
internal/scheduler/round_robin.go
Normal file
@ -0,0 +1,71 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// HostScheduler selects a host for a new sandbox. Implementations may use
// different strategies (round-robin, least-loaded, tag-based, etc.).
// Implementations must be safe for concurrent use by multiple API handlers.
type HostScheduler interface {
	// SelectHost returns a host that can accept a new sandbox.
	// For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID
	// are considered. For non-BYOC teams, only online regular (platform) hosts
	// are considered. Returns an error if no suitable host is available.
	SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error)
}
|
||||
|
||||
// RoundRobinScheduler cycles through eligible online hosts in round-robin order.
// It re-fetches the host list on every call so that newly registered or
// recovered hosts are considered immediately.
type RoundRobinScheduler struct {
	db *db.Queries
	// counter is a monotonically increasing pick index; the atomic type
	// makes concurrent SelectHost calls safe without a mutex.
	counter atomic.Int64
}

// NewRoundRobinScheduler creates a RoundRobinScheduler backed by the given DB.
func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
	return &RoundRobinScheduler{db: queries}
}
|
||||
|
||||
// SelectHost returns the next eligible online host in round-robin order.
|
||||
func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error) {
|
||||
hosts, err := s.db.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
|
||||
var eligible []db.Host
|
||||
for _, h := range hosts {
|
||||
if h.Status != "online" || h.Address == "" {
|
||||
continue
|
||||
}
|
||||
if isByoc {
|
||||
// BYOC team: only use hosts belonging to this team.
|
||||
if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID != teamID {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// Non-BYOC team: only use platform (regular) hosts.
|
||||
if h.Type != "regular" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
eligible = append(eligible, h)
|
||||
}
|
||||
|
||||
if len(eligible) == 0 {
|
||||
if isByoc {
|
||||
return db.Host{}, fmt.Errorf("no online BYOC hosts available for team")
|
||||
}
|
||||
return db.Host{}, fmt.Errorf("no online platform hosts available")
|
||||
}
|
||||
|
||||
idx := s.counter.Add(1) - 1
|
||||
return eligible[int(idx%int64(len(eligible)))], nil
|
||||
}
|
||||
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
@ -22,7 +24,7 @@ type APIKeyCreateResult struct {
|
||||
}
|
||||
|
||||
// Create generates a new API key for the given team.
|
||||
func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) (APIKeyCreateResult, error) {
|
||||
func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
|
||||
if name == "" {
|
||||
name = "Unnamed API Key"
|
||||
}
|
||||
@ -48,16 +50,16 @@ func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string)
|
||||
}
|
||||
|
||||
// List returns all API keys belonging to the given team.
|
||||
func (s *APIKeyService) List(ctx context.Context, teamID string) ([]db.TeamApiKey, error) {
|
||||
func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
|
||||
return s.DB.ListAPIKeysByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// ListWithCreator returns all API keys for the team, joined with the creator's email.
|
||||
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID string) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
|
||||
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
|
||||
return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
|
||||
}
|
||||
|
||||
// Delete removes an API key by ID, scoped to the given team.
|
||||
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID string) error {
|
||||
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
|
||||
return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
|
||||
}
|
||||
|
||||
113
internal/service/audit.go
Normal file
113
internal/service/audit.go
Normal file
@ -0,0 +1,113 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// auditMaxLimit caps the page size callers can request from AuditService.List.
const auditMaxLimit = 200
|
||||
// AuditEntry is a single audit log record returned by List.
type AuditEntry struct {
	ID           string // formatted audit-log ID (see id.FormatAuditLogID)
	TeamID       string // formatted team ID (see id.FormatTeamID)
	ActorType    string
	ActorID      string // empty for system
	ActorName    string // empty for system
	ResourceType string
	ResourceID   string // empty when not applicable
	Action       string
	Scope        string
	Status       string // 'success', 'info', 'warning', 'error'
	Metadata     map[string]any // decoded from the row's JSON metadata; nil when absent or malformed
	CreatedAt    time.Time
}
|
||||
|
||||
// AuditListParams controls the ListAuditLogs query.
type AuditListParams struct {
	// TeamID scopes the query to a single team's audit log.
	TeamID        pgtype.UUID
	AdminScoped   bool     // true → include admin-scoped events; false → team-scoped only
	ResourceTypes []string // empty = no filter; multiple values = OR match
	Actions       []string // empty = no filter; multiple values = OR match
	Before        time.Time   // zero = no cursor (start from latest)
	BeforeID      pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
	Limit         int // clamped to auditMaxLimit by the handler
}
|
||||
|
||||
// AuditService provides the read side of the audit log.
type AuditService struct {
	// DB is the sqlc-generated query layer used to read audit rows.
	DB *db.Queries
}
|
||||
|
||||
// List returns a page of audit log entries for the given team.
|
||||
func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntry, error) {
|
||||
limit := p.Limit
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if limit > auditMaxLimit {
|
||||
limit = auditMaxLimit
|
||||
}
|
||||
|
||||
scopes := []string{"team"}
|
||||
if p.AdminScoped {
|
||||
scopes = append(scopes, "admin")
|
||||
}
|
||||
|
||||
var before pgtype.Timestamptz
|
||||
if !p.Before.IsZero() {
|
||||
before = pgtype.Timestamptz{Time: p.Before, Valid: true}
|
||||
}
|
||||
|
||||
resourceTypes := p.ResourceTypes
|
||||
if resourceTypes == nil {
|
||||
resourceTypes = []string{}
|
||||
}
|
||||
actions := p.Actions
|
||||
if actions == nil {
|
||||
actions = []string{}
|
||||
}
|
||||
|
||||
rows, err := s.DB.ListAuditLogs(ctx, db.ListAuditLogsParams{
|
||||
TeamID: p.TeamID,
|
||||
Column2: scopes,
|
||||
Column3: resourceTypes,
|
||||
Column4: actions,
|
||||
Column5: before,
|
||||
ID: p.BeforeID,
|
||||
Limit: int32(limit),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list audit logs: %w", err)
|
||||
}
|
||||
|
||||
entries := make([]AuditEntry, len(rows))
|
||||
for i, row := range rows {
|
||||
var meta map[string]any
|
||||
if len(row.Metadata) > 0 {
|
||||
_ = json.Unmarshal(row.Metadata, &meta)
|
||||
}
|
||||
entries[i] = AuditEntry{
|
||||
ID: id.FormatAuditLogID(row.ID),
|
||||
TeamID: id.FormatTeamID(row.TeamID),
|
||||
ActorType: row.ActorType,
|
||||
ActorID: row.ActorID.String,
|
||||
ActorName: row.ActorName,
|
||||
ResourceType: row.ResourceType,
|
||||
ResourceID: row.ResourceID.String,
|
||||
Action: row.Action,
|
||||
Scope: row.Scope,
|
||||
Status: row.Status,
|
||||
Metadata: meta,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
}
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
605
internal/service/build.go
Normal file
605
internal/service/build.go
Normal file
@ -0,0 +1,605 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/recipe"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
const (
	// buildQueueKey is the Redis list used as the FIFO build queue.
	buildQueueKey = "wrenn:build_queue"
	// buildCommandTimeout is the default per-step timeout passed to the
	// recipe executor for user recipe commands. Pre/post phases pass 0 —
	// presumably the executor's own default; confirm in recipe.Execute.
	buildCommandTimeout = 30 * time.Second
)
|
||||
|
||||
// preBuildCmds run before the user recipe to prepare the build environment.
// Skipped entirely when a build has SkipPrePost set.
var preBuildCmds = []string{
	"RUN apt update",
}
|
||||
|
||||
// postBuildCmds run after the user recipe to clean up caches and reduce image size.
// Skipped entirely when a build has SkipPrePost set.
var postBuildCmds = []string{
	"RUN apt clean",
	"RUN apt autoremove -y",
	"RUN rm -rf /var/lib/apt/lists/*",
}
|
||||
|
||||
// buildAgentClient is the subset of the host agent client used by the build worker.
// Declaring the consumer-side subset keeps the dependency narrow and allows the
// worker to be exercised with a stub implementation.
type buildAgentClient interface {
	CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
	DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
	Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
	CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
	FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
}
|
||||
|
||||
// BuildService handles template build orchestration: creating/queueing builds,
// running build workers, and cancelling in-flight builds.
type BuildService struct {
	DB        *db.Queries
	Redis     *redis.Client // backs the FIFO build queue (buildQueueKey)
	Pool      *lifecycle.HostClientPool
	Scheduler scheduler.HostScheduler

	// mu guards cancelMap; workers register a per-build cancel func here so
	// Cancel can signal a running build from another goroutine.
	mu        sync.Mutex
	cancelMap map[string]context.CancelFunc // buildID → per-build cancel func
}
|
||||
|
||||
// BuildCreateParams holds the parameters for creating a template build.
// Zero values for BaseTemplate, VCPUs, and MemoryMB are replaced with defaults
// ("minimal", 1, 512) by Create.
type BuildCreateParams struct {
	Name         string
	BaseTemplate string
	Recipe       []string // recipe lines, stored as JSON on the build row
	Healthcheck  string   // optional; empty means image-only (rootfs) template
	VCPUs        int32
	MemoryMB     int32
	SkipPrePost  bool // skip the hardcoded pre/post build commands
}
|
||||
|
||||
// Create inserts a new build record and enqueues it to Redis.
|
||||
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
|
||||
if p.BaseTemplate == "" {
|
||||
p.BaseTemplate = "minimal"
|
||||
}
|
||||
if p.VCPUs <= 0 {
|
||||
p.VCPUs = 1
|
||||
}
|
||||
if p.MemoryMB <= 0 {
|
||||
p.MemoryMB = 512
|
||||
}
|
||||
|
||||
recipeJSON, err := json.Marshal(p.Recipe)
|
||||
if err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
|
||||
}
|
||||
|
||||
buildID := id.NewBuildID()
|
||||
buildIDStr := id.FormatBuildID(buildID)
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
defaultSteps := len(preBuildCmds) + len(postBuildCmds)
|
||||
if p.SkipPrePost {
|
||||
defaultSteps = 0
|
||||
}
|
||||
|
||||
build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
|
||||
ID: buildID,
|
||||
Name: p.Name,
|
||||
BaseTemplate: p.BaseTemplate,
|
||||
Recipe: recipeJSON,
|
||||
Healthcheck: p.Healthcheck,
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TotalSteps: int32(len(p.Recipe) + defaultSteps),
|
||||
TemplateID: newTemplateID,
|
||||
TeamID: id.PlatformTeamID,
|
||||
SkipPrePost: p.SkipPrePost,
|
||||
})
|
||||
if err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
|
||||
}
|
||||
|
||||
// Enqueue build ID (as formatted string) to Redis for workers to pick up.
|
||||
if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
|
||||
}
|
||||
|
||||
return build, nil
|
||||
}
|
||||
|
||||
// Get returns a single build by ID. Errors from the query layer (including
// not-found) are returned unwrapped.
func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
	return s.DB.GetTemplateBuild(ctx, buildID)
}
|
||||
|
||||
// List returns all builds ordered by creation time (ordering is defined by the
// underlying ListTemplateBuilds query).
func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
	return s.DB.ListTemplateBuilds(ctx)
}
|
||||
|
||||
// Cancel cancels a pending or running build. For pending builds the status is
|
||||
// updated in the DB and the worker skips it when dequeued. For running builds
|
||||
// the per-build context is cancelled, which causes the current exec step to
|
||||
// abort; executeBuild then detects the cancellation and records the status.
|
||||
func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
|
||||
build, err := s.DB.GetTemplateBuild(ctx, buildID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get build: %w", err)
|
||||
}
|
||||
switch build.Status {
|
||||
case "success", "failed", "cancelled":
|
||||
return fmt.Errorf("build is already %s", build.Status)
|
||||
}
|
||||
|
||||
// Mark cancelled in DB first. This handles both pending builds (which haven't
|
||||
// been picked up yet) and acts as a flag for executeBuild to check on start.
|
||||
if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{
|
||||
ID: buildID, Status: "cancelled",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update build status: %w", err)
|
||||
}
|
||||
|
||||
// If the build is currently running, signal its context.
|
||||
buildIDStr := id.FormatBuildID(buildID)
|
||||
s.mu.Lock()
|
||||
cancel, running := s.cancelMap[buildIDStr]
|
||||
s.mu.Unlock()
|
||||
if running {
|
||||
cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartWorkers launches n goroutines that consume from the Redis build queue.
|
||||
// The returned cancel function stops all workers.
|
||||
func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
for i := range n {
|
||||
go s.worker(ctx, i)
|
||||
}
|
||||
slog.Info("build workers started", "count", n)
|
||||
return cancel
|
||||
}
|
||||
|
||||
func (s *BuildService) worker(ctx context.Context, workerID int) {
|
||||
log := slog.With("worker", workerID)
|
||||
for {
|
||||
// BLPOP blocks until a build ID is available or context is cancelled.
|
||||
result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
log.Info("build worker shutting down")
|
||||
return
|
||||
}
|
||||
log.Error("redis BLPOP error", "error", err)
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
// result[0] is the key, result[1] is the build ID (formatted string).
|
||||
buildIDStr := result[1]
|
||||
log.Info("picked up build", "build_id", buildIDStr)
|
||||
s.executeBuild(ctx, buildIDStr)
|
||||
}
|
||||
}
|
||||
|
||||
// executeBuild runs one dequeued build end to end: it registers a per-build
// cancel func, creates a sandbox on a scheduled platform host, executes the
// pre/recipe/post phases, then either snapshots (with healthcheck) or flattens
// the rootfs, and finally records the resulting template and build status.
// All failure paths record the error via failBuild and clean up the sandbox.
func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
	log := slog.With("build_id", buildIDStr)

	buildID, err := id.ParseBuildID(buildIDStr)
	if err != nil {
		log.Error("invalid build ID from queue", "error", err)
		return
	}

	// Create a per-build context so this build can be cancelled independently of
	// the worker. Register in cancelMap before fetching the build so that a
	// concurrent Cancel call can always find and signal it.
	buildCtx, buildCancel := context.WithCancel(ctx)
	defer buildCancel()

	s.mu.Lock()
	if s.cancelMap == nil {
		s.cancelMap = make(map[string]context.CancelFunc)
	}
	s.cancelMap[buildIDStr] = buildCancel
	s.mu.Unlock()
	// Deregister on exit so Cancel stops signalling a finished build.
	defer func() {
		s.mu.Lock()
		delete(s.cancelMap, buildIDStr)
		s.mu.Unlock()
	}()

	build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
	if err != nil {
		log.Error("failed to fetch build", "error", err)
		return
	}

	// Skip if already cancelled (Cancel was called before we dequeued).
	if build.Status == "cancelled" {
		log.Info("build already cancelled, skipping")
		return
	}

	// Mark as running.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "running",
	}); err != nil {
		log.Error("failed to update build status", "error", err)
		return
	}

	// Parse user recipe.
	var userRecipe []string
	if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err))
		return
	}

	// Pick a platform host and create a sandbox.
	host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false)
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err))
		return
	}

	agent, err := s.Pool.GetForHost(host)
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err))
		return
	}

	sandboxID := id.NewSandboxID()
	sandboxIDStr := id.FormatSandboxID(sandboxID)
	log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))

	// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
	baseTeamID := id.PlatformTeamID
	baseTemplateID := id.MinimalTemplateID
	if build.BaseTemplate != "minimal" {
		baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
		if err != nil {
			s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
			return
		}
		baseTeamID = baseTmpl.TeamID
		baseTemplateID = baseTmpl.ID
	}

	resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
		SandboxId:  sandboxIDStr,
		Template:   build.BaseTemplate,
		TeamId:     id.UUIDString(baseTeamID),
		TemplateId: id.UUIDString(baseTemplateID),
		Vcpus:      build.Vcpus,
		MemoryMb:   build.MemoryMb,
		TimeoutSec: 0,    // no auto-pause for builds
		DiskSizeMb: 5120, // 5 GB for template builds
	}))
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
		return
	}
	// Response payload is unused; only success matters here.
	_ = resp

	// Record sandbox/host association. Best-effort: failure to record does not
	// abort the build.
	_ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
		ID:        buildID,
		SandboxID: sandboxID,
		HostID:    host.ID,
	})

	// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
	// valid; panic on error is appropriate here since it would be a programmer mistake.
	preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
	}
	userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
	if err != nil {
		s.destroySandbox(buildCtx, agent, sandboxIDStr)
		s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err))
		return
	}
	postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid post-build recipe: %v", err))
	}

	// Accumulated log entries and the global step counter shared across phases.
	var logs []recipe.BuildLogEntry
	step := 0

	// Seed the exec environment from the sandbox; fall back to sane defaults.
	envVars, err := s.fetchSandboxEnv(buildCtx, agent, sandboxIDStr)
	if err != nil {
		log.Warn("failed to fetch sandbox env, using defaults", "error", err)
		envVars = map[string]string{
			"PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
			"HOME": "/root",
		}
	}
	bctx := &recipe.ExecContext{EnvVars: envVars}

	// runPhase executes one phase's steps, persists progress, and on failure
	// destroys the sandbox and records the failure. Returns false to abort.
	runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
		newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec)
		logs = append(logs, newEntries...)
		step = nextStep
		s.updateLogs(buildCtx, buildID, step, logs)
		if !ok {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			// If the build was cancelled, status is already set — don't overwrite with "failed".
			if buildCtx.Err() != nil {
				return false
			}
			last := newEntries[len(newEntries)-1]
			reason := last.Stderr
			if reason == "" {
				reason = fmt.Sprintf("exit code %d", last.Exit)
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
		}
		return ok
	}

	if !build.SkipPrePost {
		if !runPhase("pre-build", preBuildSteps, 0) {
			return
		}
	}
	if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
		return
	}
	if !build.SkipPrePost {
		if !runPhase("post-build", postBuildSteps, 0) {
			return
		}
	}

	// Healthcheck or direct snapshot.
	var sizeBytes int64
	if build.Healthcheck != "" {
		hc, err := recipe.ParseHealthcheck(build.Healthcheck)
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
			return
		}
		log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
		if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc); err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
			return
		}

		// Healthcheck passed → full snapshot (with memory/CPU state).
		log.Info("healthcheck passed, creating snapshot")
		snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
			return
		}
		sizeBytes = snapResp.Msg.SizeBytes
	} else {
		// No healthcheck → image-only template (rootfs only).
		log.Info("no healthcheck, flattening rootfs")
		flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
			return
		}
		sizeBytes = flatResp.Msg.SizeBytes
	}

	// Insert into templates table as a global (platform) template.
	templateType := "base"
	if build.Healthcheck != "" {
		templateType = "snapshot"
	}

	if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
		ID:        build.TemplateID,
		Name:      build.Name,
		Type:      templateType,
		Vcpus:     build.Vcpus,
		MemoryMb:  build.MemoryMb,
		SizeBytes: sizeBytes,
		TeamID:    id.PlatformTeamID,
	}); err != nil {
		log.Error("failed to insert template record", "error", err)
		// Build succeeded on disk, just DB record failed — don't mark as failed.
	}

	// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
	// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
	// No additional destroy needed.

	// Mark build as success.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "success",
	}); err != nil {
		log.Error("failed to mark build as success", "error", err)
	}

	log.Info("template build completed successfully", "name", build.Name)
}
|
||||
|
||||
// waitForHealthcheck repeatedly executes the healthcheck command inside the
|
||||
// sandbox according to the config's interval, timeout, start-period, and
|
||||
// retries.
|
||||
// During the start period, failures are not counted toward the retry budget.
|
||||
// Returns nil on the first successful check, or an error if retries are
|
||||
// exhausted, the deadline passes, or the context is cancelled.
|
||||
func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr string, hc recipe.HealthcheckConfig) error {
|
||||
ticker := time.NewTicker(hc.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// When retries > 0, set a deadline based on the retry budget.
|
||||
// When retries == 0 (unlimited), rely solely on the parent context deadline.
|
||||
var deadlineCh <-chan time.Time
|
||||
if hc.Retries > 0 {
|
||||
deadline := time.NewTimer(hc.StartPeriod + time.Duration(hc.Retries+1)*hc.Interval)
|
||||
defer deadline.Stop()
|
||||
deadlineCh = deadline.C
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
failCount := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-deadlineCh:
|
||||
return fmt.Errorf("healthcheck timed out: exceeded %d attempts over %s", failCount, time.Since(startedAt))
|
||||
case <-ticker.C:
|
||||
execCtx, cancel := context.WithTimeout(ctx, hc.Timeout)
|
||||
resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: "/bin/sh",
|
||||
Args: []string{"-c", hc.Cmd},
|
||||
TimeoutSec: int32(hc.Timeout.Seconds()),
|
||||
}))
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
slog.Debug("healthcheck exec error (retrying)", "error", err)
|
||||
if time.Since(startedAt) >= hc.StartPeriod {
|
||||
failCount++
|
||||
if hc.Retries > 0 && failCount >= hc.Retries {
|
||||
return fmt.Errorf("healthcheck failed after %d retries: exec error: %w", failCount, err)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if resp.Msg.ExitCode == 0 {
|
||||
return nil
|
||||
}
|
||||
slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode)
|
||||
if time.Since(startedAt) >= hc.StartPeriod {
|
||||
failCount++
|
||||
if hc.Retries > 0 && failCount >= hc.Retries {
|
||||
return fmt.Errorf("healthcheck failed after %d retries: exit code %d", failCount, resp.Msg.ExitCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
|
||||
logsJSON, err := json.Marshal(logs)
|
||||
if err != nil {
|
||||
slog.Warn("failed to marshal build logs", "error", err)
|
||||
return
|
||||
}
|
||||
if err := s.DB.UpdateBuildProgress(ctx, db.UpdateBuildProgressParams{
|
||||
ID: buildID,
|
||||
CurrentStep: int32(step),
|
||||
Logs: logsJSON,
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update build progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
|
||||
slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg)
|
||||
// Use a detached context so DB writes survive parent context cancellation (e.g. shutdown).
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{
|
||||
ID: buildID,
|
||||
Error: errMsg,
|
||||
}); err != nil {
|
||||
slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
|
||||
// Use a detached context so cleanup succeeds even during shutdown.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil {
|
||||
slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchSandboxEnv executes the 'env' command inside the specified sandbox via
|
||||
// the build agent and returns environment variables
|
||||
func (s *BuildService) fetchSandboxEnv(ctx context.Context,
|
||||
agent buildAgentClient, sandboxIDStr string) (map[string]string, error) {
|
||||
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: "/bin/sh",
|
||||
Args: []string{"-c", "env"},
|
||||
TimeoutSec: 10,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch env: %w", err)
|
||||
}
|
||||
|
||||
if resp.Msg.ExitCode != 0 {
|
||||
return nil, fmt.Errorf("fetch env: command exited with code %d",
|
||||
resp.Msg.ExitCode)
|
||||
}
|
||||
|
||||
return parseSandboxEnv(string(resp.Msg.Stdout)), nil
|
||||
}
|
||||
|
||||
// parseSandboxEnv converts the raw newline-separated output of an 'env'
// command into a map.
// It skips empty lines and malformed entries (no '='), and splits on the
// first '=' only so values containing '=' are preserved intact.
func parseSandboxEnv(raw string) map[string]string {
	vars := make(map[string]string)
	for _, line := range strings.Split(raw, "\n") {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		key, value, found := strings.Cut(line, "=")
		if !found {
			continue
		}
		vars[key] = value
	}
	return vars
}
|
||||
@ -2,12 +2,14 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@ -15,6 +17,8 @@ import (
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// HostService provides host management operations.
|
||||
@ -22,15 +26,17 @@ type HostService struct {
|
||||
DB *db.Queries
|
||||
Redis *redis.Client
|
||||
JWT []byte
|
||||
Pool *lifecycle.HostClientPool
|
||||
CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
|
||||
}
|
||||
|
||||
// HostCreateParams holds the parameters for creating a host.
|
||||
type HostCreateParams struct {
|
||||
Type string
|
||||
TeamID string // required for BYOC, empty for regular
|
||||
TeamID pgtype.UUID // required for BYOC, zero value for regular
|
||||
Provider string
|
||||
AvailabilityZone string
|
||||
RequestingUserID string
|
||||
RequestingUserID pgtype.UUID
|
||||
IsRequestorAdmin bool
|
||||
}
|
||||
|
||||
@ -50,10 +56,34 @@ type HostRegisterParams struct {
|
||||
Address string
|
||||
}
|
||||
|
||||
// HostRegisterResult holds the registered host and its long-lived JWT.
|
||||
// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
|
||||
// refresh token, and optionally the host's mTLS certificate material.
|
||||
type HostRegisterResult struct {
|
||||
Host db.Host
|
||||
JWT string
|
||||
Host db.Host
|
||||
JWT string
|
||||
RefreshToken string
|
||||
// mTLS cert material — empty when CA is not configured.
|
||||
CertPEM string
|
||||
KeyPEM string
|
||||
CACertPEM string
|
||||
}
|
||||
|
||||
// HostRefreshResult holds a new JWT and rotated refresh token after a successful
|
||||
// refresh, plus refreshed mTLS certificate material when CA is configured.
|
||||
type HostRefreshResult struct {
|
||||
Host db.Host
|
||||
JWT string
|
||||
RefreshToken string
|
||||
// mTLS cert material — empty when CA is not configured.
|
||||
CertPEM string
|
||||
KeyPEM string
|
||||
CACertPEM string
|
||||
}
|
||||
|
||||
// HostDeletePreview describes what will be affected by deleting a host.
type HostDeletePreview struct {
	// Host is the host record that would be deleted.
	Host db.Host
	// SandboxIDs lists the sandboxes that would be impacted by the deletion.
	SandboxIDs []string
}
|
||||
|
||||
// regTokenPayload is the JSON stored in Redis for registration tokens.
|
||||
@ -64,6 +94,14 @@ type regTokenPayload struct {
|
||||
|
||||
const regTokenTTL = time.Hour
|
||||
|
||||
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
// Any other role yields a forbidden error suitable for returning to callers.
func requireAdminOrOwner(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
	}
}
|
||||
|
||||
// Create creates a new host record and generates a one-time registration token.
|
||||
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
|
||||
if p.Type != "regular" && p.Type != "byoc" {
|
||||
@ -75,8 +113,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
|
||||
}
|
||||
} else {
|
||||
// BYOC: admin or team owner.
|
||||
if p.TeamID == "" {
|
||||
// BYOC: platform admin, or team owner/admin.
|
||||
if !p.TeamID.Valid {
|
||||
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
|
||||
}
|
||||
if !p.IsRequestorAdmin {
|
||||
@ -90,40 +128,31 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if membership.Role != "owner" {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts")
|
||||
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||
return HostCreateResult{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate team exists for BYOC hosts.
|
||||
if p.TeamID != "" {
|
||||
if _, err := s.DB.GetTeam(ctx, p.TeamID); err != nil {
|
||||
// Validate team exists, is not deleted, and has BYOC enabled.
|
||||
if p.TeamID.Valid {
|
||||
team, err := s.DB.GetTeam(ctx, p.TeamID)
|
||||
if err != nil || team.DeletedAt.Valid {
|
||||
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
|
||||
}
|
||||
if !team.IsByoc {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
|
||||
}
|
||||
}
|
||||
|
||||
hostID := id.NewHostID()
|
||||
|
||||
var teamID pgtype.Text
|
||||
if p.TeamID != "" {
|
||||
teamID = pgtype.Text{String: p.TeamID, Valid: true}
|
||||
}
|
||||
var provider pgtype.Text
|
||||
if p.Provider != "" {
|
||||
provider = pgtype.Text{String: p.Provider, Valid: true}
|
||||
}
|
||||
var az pgtype.Text
|
||||
if p.AvailabilityZone != "" {
|
||||
az = pgtype.Text{String: p.AvailabilityZone, Valid: true}
|
||||
}
|
||||
|
||||
host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
|
||||
ID: hostID,
|
||||
Type: p.Type,
|
||||
TeamID: teamID,
|
||||
Provider: provider,
|
||||
AvailabilityZone: az,
|
||||
TeamID: p.TeamID,
|
||||
Provider: p.Provider,
|
||||
AvailabilityZone: p.AvailabilityZone,
|
||||
CreatedBy: p.RequestingUserID,
|
||||
})
|
||||
if err != nil {
|
||||
@ -135,8 +164,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
tokenID := id.NewHostTokenID()
|
||||
|
||||
payload, _ := json.Marshal(regTokenPayload{
|
||||
HostID: hostID,
|
||||
TokenID: tokenID,
|
||||
HostID: id.FormatHostID(hostID),
|
||||
TokenID: id.FormatHostTokenID(tokenID),
|
||||
})
|
||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
||||
@ -149,7 +178,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
CreatedBy: p.RequestingUserID,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
|
||||
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
||||
}
|
||||
|
||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||
@ -158,7 +187,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
// RegenerateToken issues a new registration token for a host still in "pending"
|
||||
// status. This allows retry when a previous registration attempt failed after
|
||||
// the original token was consumed.
|
||||
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) {
|
||||
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
|
||||
@ -167,12 +196,11 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
||||
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
|
||||
}
|
||||
|
||||
// Same permission model as Create/Delete.
|
||||
if !isAdmin {
|
||||
if host.Type != "byoc" {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
|
||||
}
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
if !host.TeamID.Valid || host.TeamID != teamID {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
|
||||
}
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
@ -185,8 +213,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if membership.Role != "owner" {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens")
|
||||
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||
return HostCreateResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -194,8 +222,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
||||
tokenID := id.NewHostTokenID()
|
||||
|
||||
payload, _ := json.Marshal(regTokenPayload{
|
||||
HostID: hostID,
|
||||
TokenID: tokenID,
|
||||
HostID: id.FormatHostID(hostID),
|
||||
TokenID: id.FormatHostTokenID(tokenID),
|
||||
})
|
||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
||||
@ -208,14 +236,14 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
||||
CreatedBy: userID,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
|
||||
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
||||
}
|
||||
|
||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||
}
|
||||
|
||||
// Register validates a one-time registration token, updates the host with
|
||||
// machine specs, and returns a long-lived host JWT.
|
||||
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
|
||||
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
|
||||
// Atomic consume: GetDel returns the value and deletes in one operation,
|
||||
// preventing concurrent requests from consuming the same token.
|
||||
@ -232,24 +260,44 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
|
||||
return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
|
||||
}
|
||||
|
||||
if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil {
|
||||
hostID, err := id.ParseHostID(payload.HostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
|
||||
}
|
||||
tokenID, err := id.ParseHostTokenID(payload.TokenID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
|
||||
}
|
||||
|
||||
if _, err := s.DB.GetHost(ctx, hostID); err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
|
||||
// Sign JWT before mutating DB — if signing fails, the host stays pending.
|
||||
hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID)
|
||||
hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
|
||||
}
|
||||
|
||||
// Issue mTLS certificate if CA is configured.
|
||||
var hc auth.HostCert
|
||||
if s.CA != nil {
|
||||
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Atomically update only if still pending (defense-in-depth against races).
|
||||
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
|
||||
ID: payload.HostID,
|
||||
Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""},
|
||||
CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0},
|
||||
MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0},
|
||||
DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0},
|
||||
Address: pgtype.Text{String: p.Address, Valid: p.Address != ""},
|
||||
ID: hostID,
|
||||
Arch: p.Arch,
|
||||
CpuCores: p.CPUCores,
|
||||
MemoryMb: p.MemoryMB,
|
||||
DiskGb: p.DiskGB,
|
||||
Address: p.Address,
|
||||
CertFingerprint: hc.Fingerprint,
|
||||
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
|
||||
})
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
|
||||
@ -259,82 +307,293 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
|
||||
}
|
||||
|
||||
// Mark audit trail.
|
||||
if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil {
|
||||
if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
|
||||
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
|
||||
}
|
||||
|
||||
// Issue a long-lived refresh token.
|
||||
refreshToken, err := s.issueRefreshToken(ctx, hostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Re-fetch the host to get the updated state.
|
||||
host, err := s.DB.GetHost(ctx, payload.HostID)
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
|
||||
}
|
||||
|
||||
return HostRegisterResult{Host: host, JWT: hostJWT}, nil
|
||||
result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
|
||||
if s.CA != nil {
|
||||
result.CertPEM = hc.CertPEM
|
||||
result.KeyPEM = hc.KeyPEM
|
||||
result.CACertPEM = s.CA.PEM
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the last heartbeat timestamp for a host.
|
||||
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
|
||||
return s.DB.UpdateHostHeartbeat(ctx, hostID)
|
||||
// Refresh validates a refresh token, rotates it (revokes old, issues new),
|
||||
// and returns a fresh JWT plus the new refresh token.
|
||||
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
|
||||
hash := hashToken(refreshToken)
|
||||
|
||||
row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
|
||||
}
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
|
||||
}
|
||||
|
||||
host, err := s.DB.GetHost(ctx, row.HostID)
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
|
||||
// Sign new JWT.
|
||||
hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
|
||||
}
|
||||
|
||||
// Renew mTLS certificate if CA is configured.
|
||||
var hc auth.HostCert
|
||||
if s.CA != nil {
|
||||
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
|
||||
}
|
||||
if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
|
||||
ID: host.ID,
|
||||
CertFingerprint: hc.Fingerprint,
|
||||
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
|
||||
}); err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Issue-then-revoke rotation: insert new token first so a crash between
|
||||
// the two DB calls leaves the host with two valid tokens rather than zero.
|
||||
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Revoke old refresh token after the new one is safely persisted.
|
||||
if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
|
||||
}
|
||||
|
||||
result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
|
||||
if s.CA != nil {
|
||||
result.CertPEM = hc.CertPEM
|
||||
result.KeyPEM = hc.KeyPEM
|
||||
result.CACertPEM = s.CA.PEM
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// issueRefreshToken creates a new refresh token record in the DB and returns
|
||||
// the opaque token string.
|
||||
func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
|
||||
token := id.NewRefreshToken()
|
||||
hash := hashToken(token)
|
||||
now := time.Now()
|
||||
|
||||
if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
|
||||
ID: id.NewRefreshTokenID(),
|
||||
HostID: hostID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("insert refresh token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// hashToken returns the hex-encoded SHA-256 hash of the token.
|
||||
func hashToken(token string) string {
|
||||
h := sha256.Sum256([]byte(token))
|
||||
return fmt.Sprintf("%x", h)
|
||||
}
|
||||
|
||||
// Heartbeat updates the last heartbeat timestamp for a host and transitions
|
||||
// any 'unreachable' host back to 'online'. Returns a "host not found" error
|
||||
// (which becomes 404) if the host record no longer exists (e.g., was deleted).
|
||||
func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
|
||||
n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return fmt.Errorf("host not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns hosts visible to the caller.
|
||||
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
|
||||
func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) {
|
||||
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
|
||||
if isAdmin {
|
||||
return s.DB.ListHosts(ctx)
|
||||
}
|
||||
return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true})
|
||||
return s.DB.ListHostsByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// Get returns a single host, enforcing access control.
|
||||
func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) {
|
||||
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
if !isAdmin {
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
if !host.TeamID.Valid || host.TeamID != teamID {
|
||||
return db.Host{}, fmt.Errorf("host not found")
|
||||
}
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// Delete removes a host. Admins can delete any host. Team owners can delete
|
||||
// BYOC hosts belonging to their team.
|
||||
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
// DeletePreview returns what would be affected by deleting the host, without
|
||||
// making any changes. Use this to show the user a confirmation prompt.
|
||||
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
|
||||
host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("host not found: %w", err)
|
||||
return HostDeletePreview{}, err
|
||||
}
|
||||
|
||||
if !isAdmin {
|
||||
if host.Type != "byoc" {
|
||||
return fmt.Errorf("forbidden: only admins can delete regular hosts")
|
||||
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: hostID,
|
||||
Column2: []string{"pending", "starting", "running", "missing"},
|
||||
})
|
||||
if err != nil {
|
||||
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
|
||||
}
|
||||
|
||||
ids := make([]string, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
ids[i] = id.FormatSandboxID(sb.ID)
|
||||
}
|
||||
|
||||
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
|
||||
}
|
||||
|
||||
// Delete removes a host. Without force it returns an error listing active
|
||||
// sandboxes so the caller can present a confirmation. With force it gracefully
|
||||
// destroys all running sandboxes before deleting the host record.
|
||||
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
|
||||
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: hostID,
|
||||
Column2: []string{"pending", "starting", "running", "missing"},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("list sandboxes: %w", err)
|
||||
}
|
||||
|
||||
if len(sandboxes) > 0 && !force {
|
||||
ids := make([]string, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
ids[i] = id.FormatSandboxID(sb.ID)
|
||||
}
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
return fmt.Errorf("forbidden: host does not belong to your team")
|
||||
return &HostHasSandboxesError{SandboxIDs: ids}
|
||||
}
|
||||
|
||||
hostIDStr := id.FormatHostID(hostID)
|
||||
|
||||
// Gracefully destroy running sandboxes and terminate the agent (best-effort).
|
||||
if host.Address != "" {
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err == nil {
|
||||
for _, sb := range sandboxes {
|
||||
if sb.Status == "running" || sb.Status == "starting" {
|
||||
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: id.FormatSandboxID(sb.ID),
|
||||
}))
|
||||
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
|
||||
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tell the agent to shut itself down immediately.
|
||||
if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
|
||||
slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
|
||||
}
|
||||
}
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return fmt.Errorf("forbidden: not a member of the specified team")
|
||||
}
|
||||
|
||||
// Mark all affected sandboxes as stopped in DB.
|
||||
if len(sandboxes) > 0 {
|
||||
sbIDs := make([]pgtype.UUID, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
sbIDs[i] = sb.ID
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if membership.Role != "owner" {
|
||||
return fmt.Errorf("forbidden: only team owners can delete BYOC hosts")
|
||||
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: sbIDs,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Revoke all refresh tokens for this host.
|
||||
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
|
||||
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
|
||||
}
|
||||
|
||||
// Evict the client from the pool so no further RPCs are sent.
|
||||
if s.Pool != nil {
|
||||
s.Pool.Evict(id.FormatHostID(hostID))
|
||||
}
|
||||
|
||||
return s.DB.DeleteHost(ctx, hostID)
|
||||
}
|
||||
|
||||
// checkDeletePermission verifies the caller has permission to delete the given
|
||||
// host and returns the host record on success.
|
||||
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
|
||||
if isAdmin {
|
||||
return host, nil
|
||||
}
|
||||
|
||||
if host.Type != "byoc" {
|
||||
return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
|
||||
}
|
||||
if !host.TeamID.Valid || host.TeamID != teamID {
|
||||
return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
|
||||
}
|
||||
|
||||
if userID.Valid {
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
|
||||
}
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||
return db.Host{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// AddTag adds a tag to a host.
|
||||
func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
|
||||
func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -342,7 +601,7 @@ func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin
|
||||
}
|
||||
|
||||
// RemoveTag removes a tag from a host.
|
||||
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
|
||||
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -350,9 +609,20 @@ func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAd
|
||||
}
|
||||
|
||||
// ListTags returns all tags for a host.
|
||||
func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) {
|
||||
func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.DB.GetHostTags(ctx, hostID)
|
||||
}
|
||||
|
||||
// HostHasSandboxesError is returned by Delete when the host has active sandboxes
|
||||
// and force was not set. The caller should present the list to the user and
|
||||
// re-call Delete with force=true if they confirm.
|
||||
type HostHasSandboxesError struct {
|
||||
SandboxIDs []string
|
||||
}
|
||||
|
||||
func (e *HostHasSandboxesError) Error() string {
|
||||
return fmt.Sprintf("host has %d active sandbox(es): %v", len(e.SandboxIDs), e.SandboxIDs)
|
||||
}
|
||||
|
||||
@ -11,29 +11,60 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// SandboxService provides sandbox lifecycle operations shared between the
|
||||
// REST API and the dashboard.
|
||||
type SandboxService struct {
|
||||
DB *db.Queries
|
||||
Agent hostagentv1connect.HostAgentServiceClient
|
||||
DB *db.Queries
|
||||
Pool *lifecycle.HostClientPool
|
||||
Scheduler scheduler.HostScheduler
|
||||
}
|
||||
|
||||
// SandboxCreateParams holds the parameters for creating a sandbox.
|
||||
type SandboxCreateParams struct {
|
||||
TeamID string
|
||||
TeamID pgtype.UUID
|
||||
Template string
|
||||
VCPUs int32
|
||||
MemoryMB int32
|
||||
TimeoutSec int32
|
||||
DiskSizeMB int32
|
||||
}
|
||||
|
||||
// Create creates a new sandbox: inserts a pending DB record, calls the host agent,
|
||||
// and updates the record to running. Returns the sandbox DB row.
|
||||
// agentForSandbox looks up the host for the given sandbox and returns a client.
|
||||
func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
host, err := s.DB.GetHost(ctx, sb.HostID)
|
||||
if err != nil {
|
||||
return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err)
|
||||
}
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err != nil {
|
||||
return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
|
||||
}
|
||||
return agent, sb, nil
|
||||
}
|
||||
|
||||
// hostagentClient is a local alias to avoid the full package path in signatures.
|
||||
type hostagentClient = interface {
|
||||
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
|
||||
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
|
||||
PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
|
||||
ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
|
||||
PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
|
||||
GetSandboxMetrics(ctx context.Context, req *connect.Request[pb.GetSandboxMetricsRequest]) (*connect.Response[pb.GetSandboxMetricsResponse], error)
|
||||
FlushSandboxMetrics(ctx context.Context, req *connect.Request[pb.FlushSandboxMetricsRequest]) (*connect.Response[pb.FlushSandboxMetricsResponse], error)
|
||||
}
|
||||
|
||||
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
|
||||
// DB record, calls the host agent, and updates the record to running.
|
||||
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
|
||||
if p.Template == "" {
|
||||
p.Template = "minimal"
|
||||
@ -47,44 +78,82 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
||||
if p.MemoryMB <= 0 {
|
||||
p.MemoryMB = 512
|
||||
}
|
||||
if p.DiskSizeMB <= 0 {
|
||||
p.DiskSizeMB = 5120 // 5 GB default
|
||||
}
|
||||
|
||||
// If the template is a snapshot, use its baked-in vcpus/memory.
|
||||
if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" {
|
||||
if tmpl.Vcpus.Valid {
|
||||
p.VCPUs = tmpl.Vcpus.Int32
|
||||
// Resolve template name → (teamID, templateID).
|
||||
templateTeamID := id.PlatformTeamID
|
||||
templateID := id.MinimalTemplateID
|
||||
if p.Template != "minimal" {
|
||||
tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
|
||||
}
|
||||
if tmpl.MemoryMb.Valid {
|
||||
p.MemoryMB = tmpl.MemoryMb.Int32
|
||||
templateTeamID = tmpl.TeamID
|
||||
templateID = tmpl.ID
|
||||
// If the template is a snapshot, use its baked-in vcpus/memory.
|
||||
if tmpl.Type == "snapshot" {
|
||||
p.VCPUs = tmpl.Vcpus
|
||||
p.MemoryMB = tmpl.MemoryMb
|
||||
}
|
||||
}
|
||||
|
||||
if !p.TeamID.Valid {
|
||||
return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
|
||||
}
|
||||
|
||||
// Determine whether this team uses BYOC hosts or platform hosts.
|
||||
team, err := s.DB.GetTeam(ctx, p.TeamID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("team not found: %w", err)
|
||||
}
|
||||
|
||||
// Pick a host for this sandbox.
|
||||
host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
|
||||
}
|
||||
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
|
||||
}
|
||||
|
||||
sandboxID := id.NewSandboxID()
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
|
||||
ID: sandboxID,
|
||||
TeamID: p.TeamID,
|
||||
HostID: "default",
|
||||
Template: p.Template,
|
||||
Status: "pending",
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TimeoutSec: p.TimeoutSec,
|
||||
ID: sandboxID,
|
||||
TeamID: p.TeamID,
|
||||
HostID: host.ID,
|
||||
Template: p.Template,
|
||||
Status: "pending",
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TimeoutSec: p.TimeoutSec,
|
||||
DiskSizeMb: p.DiskSizeMB,
|
||||
TemplateID: templateID,
|
||||
TemplateTeamID: templateTeamID,
|
||||
}); err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.Agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Template: p.Template,
|
||||
TeamId: id.UUIDString(templateTeamID),
|
||||
TemplateId: id.UUIDString(templateID),
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TimeoutSec: p.TimeoutSec,
|
||||
DiskSizeMb: p.DiskSizeMB,
|
||||
}))
|
||||
if err != nil {
|
||||
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "error",
|
||||
}); dbErr != nil {
|
||||
slog.Warn("failed to update sandbox status to error", "id", sandboxID, "error", dbErr)
|
||||
slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
|
||||
}
|
||||
@ -107,17 +176,17 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
||||
}
|
||||
|
||||
// List returns active sandboxes (excludes stopped/error) belonging to the given team.
|
||||
func (s *SandboxService) List(ctx context.Context, teamID string) ([]db.Sandbox, error) {
|
||||
func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) {
|
||||
return s.DB.ListSandboxesByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// Get returns a single sandbox by ID, scoped to the given team.
|
||||
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
|
||||
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
|
||||
return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
}
|
||||
|
||||
// Pause snapshots and freezes a running sandbox to disk.
|
||||
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
|
||||
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
@ -126,23 +195,45 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
||||
}
|
||||
|
||||
if _, err := s.Agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
// Pre-mark as "paused" in DB before the RPC so the reconciler does not
|
||||
// mark the sandbox "stopped" while the host agent processes the pause.
|
||||
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err)
|
||||
}
|
||||
|
||||
// Flush all metrics tiers before pausing so data survives in DB.
|
||||
s.flushAndPersistMetrics(ctx, agent, sandboxID, true)
|
||||
|
||||
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil {
|
||||
// Revert status on failure.
|
||||
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
|
||||
}
|
||||
|
||||
sb, err = s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
})
|
||||
sb, err = s.DB.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("update status: %w", err)
|
||||
return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err)
|
||||
}
|
||||
return sb, nil
|
||||
}
|
||||
|
||||
// Resume restores a paused sandbox from snapshot.
|
||||
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
|
||||
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
@ -151,8 +242,15 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
|
||||
}
|
||||
|
||||
resp, err := s.Agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
TimeoutSec: sb.TimeoutSec,
|
||||
}))
|
||||
if err != nil {
|
||||
@ -176,18 +274,41 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
|
||||
}
|
||||
|
||||
// Destroy stops a sandbox and marks it as stopped.
|
||||
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) error {
|
||||
if _, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}); err != nil {
|
||||
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
// If running, flush 24h tier metrics for analytics before destroying.
|
||||
if sb.Status == "running" {
|
||||
s.flushAndPersistMetrics(ctx, agent, sandboxID, false)
|
||||
}
|
||||
|
||||
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
|
||||
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||
return fmt.Errorf("agent destroy: %w", err)
|
||||
}
|
||||
|
||||
// For a paused sandbox, only keep 24h tier; remove the finer-grained tiers.
|
||||
if sb.Status == "paused" {
|
||||
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
|
||||
SandboxID: sandboxID, Tier: "10m",
|
||||
})
|
||||
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
|
||||
SandboxID: sandboxID, Tier: "2h",
|
||||
})
|
||||
}
|
||||
|
||||
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "stopped",
|
||||
}); err != nil {
|
||||
@ -196,8 +317,45 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// flushAndPersistMetrics calls FlushSandboxMetrics on the agent and stores
|
||||
// the returned data to DB. If allTiers is true, all three tiers are saved;
|
||||
// otherwise only the 24h tier (for post-destroy analytics).
|
||||
func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) {
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
resp, err := agent.FlushSandboxMetrics(ctx, connect.NewRequest(&pb.FlushSandboxMetricsRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
}))
|
||||
if err != nil {
|
||||
slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxIDStr, "error", err)
|
||||
return
|
||||
}
|
||||
msg := resp.Msg
|
||||
|
||||
if allTiers {
|
||||
s.persistMetricPoints(ctx, sandboxID, "10m", msg.Points_10M)
|
||||
s.persistMetricPoints(ctx, sandboxID, "2h", msg.Points_2H)
|
||||
}
|
||||
s.persistMetricPoints(ctx, sandboxID, "24h", msg.Points_24H)
|
||||
}
|
||||
|
||||
func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) {
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
for _, p := range points {
|
||||
if err := s.DB.InsertSandboxMetricPoint(ctx, db.InsertSandboxMetricPointParams{
|
||||
SandboxID: sandboxID,
|
||||
Tier: tier,
|
||||
Ts: p.TimestampUnix,
|
||||
CpuPct: p.CpuPct,
|
||||
MemBytes: p.MemBytes,
|
||||
DiskBytes: p.DiskBytes,
|
||||
}); err != nil {
|
||||
slog.Warn("persist metric point failed", "sandbox_id", sandboxIDStr, "tier", tier, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ping resets the inactivity timer for a running sandbox.
|
||||
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) error {
|
||||
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sandbox not found: %w", err)
|
||||
@ -206,8 +364,15 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
|
||||
return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
||||
}
|
||||
|
||||
if _, err := s.Agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil {
|
||||
return fmt.Errorf("agent ping: %w", err)
|
||||
}
|
||||
@ -219,7 +384,7 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxID, "error", err)
|
||||
slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
160
internal/service/stats.go
Normal file
160
internal/service/stats.go
Normal file
@ -0,0 +1,160 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// TimeRange identifies a chart time window.
type TimeRange string

// Supported chart windows, shortest to longest lookback. Each value has a
// matching entry in rangeConfigs that fixes its bucket width and SQL interval.
const (
	Range5m  TimeRange = "5m"
	Range1h  TimeRange = "1h"
	Range6h  TimeRange = "6h"
	Range24h TimeRange = "24h"
	Range30d TimeRange = "30d"
)
|
||||
|
||||
// rangeConfig describes how one TimeRange is aggregated.
type rangeConfig struct {
	bucketSec       int    // bucket width in seconds for time-series aggregation
	intervalLiteral string // PostgreSQL interval literal for the lookback window
}

// rangeConfigs maps each supported TimeRange to its aggregation parameters.
// Bucket widths are chosen so every range yields roughly 100-120 points.
var rangeConfigs = map[TimeRange]rangeConfig{
	Range5m:  {bucketSec: 3, intervalLiteral: "5 minutes"},
	Range1h:  {bucketSec: 30, intervalLiteral: "1 hour"},
	Range6h:  {bucketSec: 180, intervalLiteral: "6 hours"},
	Range24h: {bucketSec: 720, intervalLiteral: "24 hours"},
	Range30d: {bucketSec: 21600, intervalLiteral: "30 days"},
}
|
||||
|
||||
// ValidRange returns true if r is a known TimeRange value.
|
||||
func ValidRange(r TimeRange) bool {
|
||||
_, ok := rangeConfigs[r]
|
||||
return ok
|
||||
}
|
||||
|
||||
// StatPoint is one bucketed data point in the time-series.
type StatPoint struct {
	Bucket           time.Time // start of the aggregation bucket
	RunningCount     int32     // max running sandboxes observed in the bucket
	VCPUsReserved    int32     // max vCPUs reserved in the bucket
	MemoryMBReserved int32     // max memory (MB) reserved in the bucket
}

// CurrentStats holds the live values for a team, read directly from sandboxes.
type CurrentStats struct {
	RunningCount     int32
	VCPUsReserved    int32
	MemoryMBReserved int32
}

// PeakStats holds the 30-day maximum values for a team.
type PeakStats struct {
	RunningCount int32
	VCPUs        int32
	MemoryMB     int32
}
|
||||
|
||||
// StatsService computes sandbox metrics for the dashboard.
type StatsService struct {
	DB   *db.Queries   // generated query layer for simple lookups
	Pool *pgxpool.Pool // raw pool for the dynamic time-series query
}
|
||||
|
||||
// GetStats returns current stats, 30-day peaks, and a time-series for the
|
||||
// given team and time range. If no snapshots exist yet, zeros are returned.
|
||||
func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) {
|
||||
cfg, ok := rangeConfigs[r]
|
||||
if !ok {
|
||||
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("unknown range: %s", r)
|
||||
}
|
||||
|
||||
// Current live values — read directly from sandboxes so we always reflect
|
||||
// the true state even when no capsules are running.
|
||||
cur, err := s.DB.GetLiveMetrics(ctx, teamID)
|
||||
if err != nil {
|
||||
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get live metrics: %w", err)
|
||||
}
|
||||
current := CurrentStats{
|
||||
RunningCount: cur.RunningCount,
|
||||
VCPUsReserved: cur.VcpusReserved,
|
||||
MemoryMBReserved: cur.MemoryMbReserved,
|
||||
}
|
||||
|
||||
// 30-day peaks.
|
||||
var peaks PeakStats
|
||||
pk, err := s.DB.GetPeakMetrics(ctx, teamID)
|
||||
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get peak metrics: %w", err)
|
||||
}
|
||||
if err == nil {
|
||||
peaks = PeakStats{
|
||||
RunningCount: pk.PeakRunningCount,
|
||||
VCPUs: pk.PeakVcpus,
|
||||
MemoryMB: pk.PeakMemoryMb,
|
||||
}
|
||||
}
|
||||
|
||||
// Time-series — dynamic bucket width, executed via pgx directly.
|
||||
series, err := s.queryTimeSeries(ctx, teamID, cfg)
|
||||
if err != nil {
|
||||
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get time series: %w", err)
|
||||
}
|
||||
|
||||
return current, peaks, series, nil
|
||||
}
|
||||
|
||||
// timeSeriesSQL uses an epoch-floor trick to bucket rows by an arbitrary
// integer number of seconds without requiring TimescaleDB.
//
// MAX is used instead of AVG so that short-lived running states are not
// averaged down to zero within a bucket. For capacity metrics the peak
// value in each bucket is what matters — AVG with ::INTEGER rounding
// caused running_count, vcpus, and memory to become inconsistent with
// each other (e.g. running=0 but vcpus=1).
//
// $1 = bucket width in seconds (integer)
// $2 = team_id
// $3 = lookback interval literal (e.g. '1 hour')
const timeSeriesSQL = `
SELECT
  to_timestamp(floor(extract(epoch FROM sampled_at) / $1) * $1) AS bucket,
  MAX(running_count) AS running_count,
  MAX(vcpus_reserved) AS vcpus_reserved,
  MAX(memory_mb_reserved) AS memory_mb_reserved
FROM sandbox_metrics_snapshots
WHERE team_id = $2
  AND sampled_at >= NOW() - $3::INTERVAL
GROUP BY bucket
ORDER BY bucket ASC
`
|
||||
|
||||
func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) {
|
||||
rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var points []StatPoint
|
||||
for rows.Next() {
|
||||
var p StatPoint
|
||||
var bucket time.Time
|
||||
if err := rows.Scan(&bucket, &p.RunningCount, &p.VCPUsReserved, &p.MemoryMBReserved); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Bucket = bucket
|
||||
points = append(points, p)
|
||||
}
|
||||
return points, rows.Err()
|
||||
}
|
||||
443
internal/service/team.go
Normal file
443
internal/service/team.go
Normal file
@ -0,0 +1,443 @@
|
||||
package service
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"regexp"
	"time"

	"connectrpc.com/connect"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"

	"git.omukk.dev/wrenn/sandbox/internal/db"
	"git.omukk.dev/wrenn/sandbox/internal/id"
	"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
	pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
|
||||
|
||||
// teamNameRE validates team names: 1-128 characters drawn from letters,
// digits, spaces, underscores, hyphens, "@", and apostrophes.
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
|
||||
|
||||
// TeamService provides team management operations.
type TeamService struct {
	DB       *db.Queries              // generated query layer
	Pool     *pgxpool.Pool            // raw pool, used for multi-statement transactions
	HostPool *lifecycle.HostClientPool // host-agent RPC client cache
}
|
||||
|
||||
// TeamWithRole pairs a team with the calling user's role in it.
type TeamWithRole struct {
	db.Team
	Role string `json:"role"`
}
|
||||
|
||||
// MemberInfo is a team member with resolved user details.
type MemberInfo struct {
	UserID   string    `json:"user_id"`
	Name     string    `json:"name"`
	Email    string    `json:"email"`
	Role     string    `json:"role"`
	JoinedAt time.Time `json:"joined_at"` // zero value when the join timestamp is NULL
}
|
||||
|
||||
// callerRole fetches the calling user's role in the given team from DB.
|
||||
// Returns an error wrapping "forbidden" if the caller is not a member.
|
||||
func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) {
|
||||
m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: callerUserID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return "", fmt.Errorf("forbidden: not a member of this team")
|
||||
}
|
||||
return "", fmt.Errorf("get membership: %w", err)
|
||||
}
|
||||
return m.Role, nil
|
||||
}
|
||||
|
||||
// requireAdmin returns an error if the caller is not an admin or owner.
func requireAdmin(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: admin or owner role required")
	}
}
|
||||
|
||||
// GetTeam returns the team by ID. Returns an error if the team is deleted or not found.
|
||||
func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) {
|
||||
team, err := s.DB.GetTeam(ctx, teamID)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return db.Team{}, fmt.Errorf("team not found")
|
||||
}
|
||||
return db.Team{}, fmt.Errorf("get team: %w", err)
|
||||
}
|
||||
if team.DeletedAt.Valid {
|
||||
return db.Team{}, fmt.Errorf("team not found")
|
||||
}
|
||||
return team, nil
|
||||
}
|
||||
|
||||
// ListTeamsForUser returns all active teams the user belongs to, with their role in each.
|
||||
func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) {
|
||||
rows, err := s.DB.GetTeamsForUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list teams: %w", err)
|
||||
}
|
||||
result := make([]TeamWithRole, len(rows))
|
||||
for i, r := range rows {
|
||||
result[i] = TeamWithRole{
|
||||
Team: db.Team{ID: r.ID, Name: r.Name, CreatedAt: r.CreatedAt, IsByoc: r.IsByoc, Slug: r.Slug, DeletedAt: r.DeletedAt},
|
||||
Role: r.Role,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CreateTeam creates a new team owned by the given user.
|
||||
func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) {
|
||||
if !teamNameRE.MatchString(name) {
|
||||
return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
|
||||
}
|
||||
|
||||
tx, err := s.Pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return TeamWithRole{}, fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx) //nolint:errcheck
|
||||
|
||||
qtx := s.DB.WithTx(tx)
|
||||
|
||||
teamID := id.NewTeamID()
|
||||
team, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: name,
|
||||
Slug: id.NewTeamSlug(),
|
||||
})
|
||||
if err != nil {
|
||||
return TeamWithRole{}, fmt.Errorf("insert team: %w", err)
|
||||
}
|
||||
|
||||
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: ownerUserID,
|
||||
TeamID: teamID,
|
||||
IsDefault: false,
|
||||
Role: "owner",
|
||||
}); err != nil {
|
||||
return TeamWithRole{}, fmt.Errorf("insert owner: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return TeamWithRole{}, fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
return TeamWithRole{Team: team, Role: "owner"}, nil
|
||||
}
|
||||
|
||||
// RenameTeam updates the team name. Caller must be admin or owner (verified from DB).
|
||||
func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error {
|
||||
if !teamNameRE.MatchString(newName) {
|
||||
return fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
|
||||
}
|
||||
|
||||
role, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := requireAdmin(role); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.DB.UpdateTeamName(ctx, db.UpdateTeamNameParams{ID: teamID, Name: newName}); err != nil {
|
||||
return fmt.Errorf("update name: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTeam soft-deletes the team and destroys all running/paused/starting sandboxes.
|
||||
// Caller must be owner (verified from DB). All DB records (sandboxes, keys, templates)
|
||||
// are preserved; only the team's deleted_at is set and active VMs are stopped.
|
||||
func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
|
||||
role, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if role != "owner" {
|
||||
return fmt.Errorf("forbidden: only the owner can delete a team")
|
||||
}
|
||||
|
||||
// Collect active sandboxes and stop them.
|
||||
sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list active sandboxes: %w", err)
|
||||
}
|
||||
|
||||
var stopIDs []pgtype.UUID
|
||||
for _, sb := range sandboxes {
|
||||
host, hostErr := s.DB.GetHost(ctx, sb.HostID)
|
||||
if hostErr == nil {
|
||||
agent, agentErr := s.HostPool.GetForHost(host)
|
||||
if agentErr == nil {
|
||||
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: id.FormatSandboxID(sb.ID),
|
||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", id.FormatSandboxID(sb.ID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
stopIDs = append(stopIDs, sb.ID)
|
||||
}
|
||||
|
||||
if len(stopIDs) > 0 {
|
||||
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: stopIDs,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
// Do not proceed to soft-delete if sandbox statuses couldn't be updated,
|
||||
// as that would leave orphaned "running" records for a deleted team.
|
||||
return fmt.Errorf("update sandbox statuses: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up team-owned templates from all hosts in the background.
|
||||
go s.cleanupTeamTemplates(context.Background(), teamID)
|
||||
|
||||
if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil {
|
||||
return fmt.Errorf("soft delete team: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupTeamTemplates deletes all template files for a team from all online hosts,
|
||||
// then removes the DB records. Called asynchronously during team deletion.
|
||||
func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) {
|
||||
templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID)
|
||||
if err != nil {
|
||||
slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err)
|
||||
return
|
||||
}
|
||||
if len(templates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
hosts, err := s.DB.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("team delete: failed to list hosts for template cleanup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, tmpl := range templates {
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
}
|
||||
agent, err := s.HostPool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: id.UUIDString(tmpl.TeamID),
|
||||
TemplateId: id.UUIDString(tmpl.ID),
|
||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("team delete: failed to delete template on host",
|
||||
"host_id", id.FormatHostID(host.ID),
|
||||
"template", tmpl.Name,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove DB records.
|
||||
if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil {
|
||||
slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMembers returns all members of the team with their emails and roles.
|
||||
func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) {
|
||||
rows, err := s.DB.GetTeamMembers(ctx, teamID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get members: %w", err)
|
||||
}
|
||||
members := make([]MemberInfo, len(rows))
|
||||
for i, r := range rows {
|
||||
var joinedAt time.Time
|
||||
if r.JoinedAt.Valid {
|
||||
joinedAt = r.JoinedAt.Time
|
||||
}
|
||||
members[i] = MemberInfo{
|
||||
UserID: id.FormatUserID(r.ID),
|
||||
Name: r.Name,
|
||||
Email: r.Email,
|
||||
Role: r.Role,
|
||||
JoinedAt: joinedAt,
|
||||
}
|
||||
}
|
||||
return members, nil
|
||||
}
|
||||
|
||||
// AddMember adds an existing user (looked up by email) to the team as a member.
|
||||
// Caller must be admin or owner (verified from DB).
|
||||
func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) {
|
||||
role, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return MemberInfo{}, err
|
||||
}
|
||||
if err := requireAdmin(role); err != nil {
|
||||
return MemberInfo{}, err
|
||||
}
|
||||
|
||||
target, err := s.DB.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return MemberInfo{}, fmt.Errorf("user not found: no account with that email")
|
||||
}
|
||||
return MemberInfo{}, fmt.Errorf("look up user: %w", err)
|
||||
}
|
||||
|
||||
// Check if already a member.
|
||||
_, memberCheckErr := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: target.ID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if memberCheckErr == nil {
|
||||
return MemberInfo{}, fmt.Errorf("invalid: user is already a member of this team")
|
||||
} else if memberCheckErr != pgx.ErrNoRows {
|
||||
return MemberInfo{}, fmt.Errorf("check membership: %w", memberCheckErr)
|
||||
}
|
||||
|
||||
if err := s.DB.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: target.ID,
|
||||
TeamID: teamID,
|
||||
IsDefault: false,
|
||||
Role: "member",
|
||||
}); err != nil {
|
||||
return MemberInfo{}, fmt.Errorf("insert member: %w", err)
|
||||
}
|
||||
|
||||
return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil
|
||||
}
|
||||
|
||||
// RemoveMember removes a user from the team.
|
||||
// Caller must be admin or owner (verified from DB). Owner cannot be removed.
|
||||
func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error {
|
||||
callerRole, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := requireAdmin(callerRole); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: targetUserID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return fmt.Errorf("not found: user is not a member of this team")
|
||||
}
|
||||
return fmt.Errorf("get target membership: %w", err)
|
||||
}
|
||||
|
||||
if targetMembership.Role == "owner" {
|
||||
return fmt.Errorf("forbidden: the owner cannot be removed from the team")
|
||||
}
|
||||
|
||||
if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
|
||||
TeamID: teamID,
|
||||
UserID: targetUserID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("delete member: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateMemberRole changes a member's role to admin or member.
|
||||
// Caller must be admin or owner (verified from DB). Owner's role cannot be changed.
|
||||
// Valid target roles: "admin", "member".
|
||||
func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error {
|
||||
if newRole != "admin" && newRole != "member" {
|
||||
return fmt.Errorf("invalid: role must be admin or member")
|
||||
}
|
||||
|
||||
callerRole, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := requireAdmin(callerRole); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: targetUserID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return fmt.Errorf("not found: user is not a member of this team")
|
||||
}
|
||||
return fmt.Errorf("get target membership: %w", err)
|
||||
}
|
||||
|
||||
if targetMembership.Role == "owner" {
|
||||
return fmt.Errorf("forbidden: the owner's role cannot be changed")
|
||||
}
|
||||
|
||||
if err := s.DB.UpdateMemberRole(ctx, db.UpdateMemberRoleParams{
|
||||
TeamID: teamID,
|
||||
UserID: targetUserID,
|
||||
Role: newRole,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update role: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LeaveTeam removes the calling user from the team.
|
||||
// The owner cannot leave; they must delete the team instead.
|
||||
func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
|
||||
role, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if role == "owner" {
|
||||
return fmt.Errorf("forbidden: the owner cannot leave the team; delete the team instead")
|
||||
}
|
||||
|
||||
if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
|
||||
TeamID: teamID,
|
||||
UserID: callerUserID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("leave team: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot
|
||||
// be disabled — it is a one-way transition.
|
||||
// Admin-only — the caller must verify admin status before invoking this.
|
||||
func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error {
|
||||
team, err := s.DB.GetTeam(ctx, teamID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("team not found: %w", err)
|
||||
}
|
||||
if team.DeletedAt.Valid {
|
||||
return fmt.Errorf("team not found")
|
||||
}
|
||||
if !enabled {
|
||||
return fmt.Errorf("invalid request: BYOC cannot be disabled once enabled")
|
||||
}
|
||||
if team.IsByoc {
|
||||
// Already enabled — idempotent, no-op.
|
||||
return nil
|
||||
}
|
||||
if err := s.DB.SetTeamBYOC(ctx, db.SetTeamBYOCParams{ID: teamID, IsByoc: true}); err != nil {
|
||||
return fmt.Errorf("set byoc: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
@ -14,7 +16,7 @@ type TemplateService struct {
|
||||
|
||||
// List returns all templates belonging to the given team. If typeFilter is
|
||||
// non-empty, only templates of that type ("base" or "snapshot") are returned.
|
||||
func (s *TemplateService) List(ctx context.Context, teamID, typeFilter string) ([]db.Template, error) {
|
||||
func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) {
|
||||
if typeFilter != "" {
|
||||
return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{
|
||||
TeamID: teamID,
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
package snapshot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@ -172,6 +173,99 @@ func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHe
|
||||
return header, nil
|
||||
}
|
||||
|
||||
// MergeDiffs consolidates multiple generation diff files into a single diff
|
||||
// file and resets the generation counter to 0. This is a pure file-level
|
||||
// operation — no Firecracker involvement.
|
||||
//
|
||||
// It reads each non-nil block from the appropriate diff file (as mapped by
|
||||
// the header), writes them all sequentially into a single new diff file,
|
||||
// and produces a fresh header pointing only at that file.
|
||||
//
|
||||
// diffFiles maps build ID (string) → open file path for each generation's diff.
|
||||
func MergeDiffs(header *Header, diffFiles map[string]string, mergedDiffPath, headerPath string) (*Header, error) {
|
||||
blockSize := int64(header.Metadata.BlockSize)
|
||||
mergedBuildID := uuid.New()
|
||||
|
||||
// Open all source diff files.
|
||||
sources := make(map[string]*os.File, len(diffFiles))
|
||||
for id, path := range diffFiles {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
// Close already opened files.
|
||||
for _, sf := range sources {
|
||||
sf.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("open diff file for build %s: %w", id, err)
|
||||
}
|
||||
sources[id] = f
|
||||
}
|
||||
defer func() {
|
||||
for _, f := range sources {
|
||||
f.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
dst, err := os.Create(mergedDiffPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create merged diff file: %w", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
totalBlocks := TotalBlocks(int64(header.Metadata.Size), blockSize)
|
||||
dirty := make([]bool, totalBlocks)
|
||||
empty := make([]bool, totalBlocks)
|
||||
buf := make([]byte, blockSize)
|
||||
|
||||
for i := int64(0); i < totalBlocks; i++ {
|
||||
offset := i * blockSize
|
||||
mappedOffset, _, buildID, err := header.GetShiftedMapping(context.Background(), offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup block %d: %w", i, err)
|
||||
}
|
||||
|
||||
if *buildID == uuid.Nil {
|
||||
empty[i] = true
|
||||
continue
|
||||
}
|
||||
|
||||
src, ok := sources[buildID.String()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no diff file for build %s (block %d)", buildID, i)
|
||||
}
|
||||
|
||||
if _, err := src.ReadAt(buf, mappedOffset); err != nil {
|
||||
return nil, fmt.Errorf("read block %d from build %s: %w", i, buildID, err)
|
||||
}
|
||||
|
||||
dirty[i] = true
|
||||
if _, err := dst.Write(buf); err != nil {
|
||||
return nil, fmt.Errorf("write merged block %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build fresh header with generation 0.
|
||||
dirtyMappings := CreateMapping(mergedBuildID, dirty, blockSize)
|
||||
emptyMappings := CreateMapping(uuid.Nil, empty, blockSize)
|
||||
merged := MergeMappings(dirtyMappings, emptyMappings)
|
||||
normalized := NormalizeMappings(merged)
|
||||
|
||||
metadata := NewMetadata(mergedBuildID, uint64(blockSize), header.Metadata.Size)
|
||||
newHeader, err := NewHeader(metadata, normalized)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create merged header: %w", err)
|
||||
}
|
||||
|
||||
headerData, err := Serialize(metadata, normalized)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize merged header: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
|
||||
return nil, fmt.Errorf("write merged header: %w", err)
|
||||
}
|
||||
|
||||
return newHeader, nil
|
||||
}
|
||||
|
||||
// isZeroBlock checks if a block is entirely zero bytes.
|
||||
func isZeroBlock(block []byte) bool {
|
||||
// Fast path: compare 8 bytes at a time.
|
||||
|
||||
@ -11,7 +11,7 @@ func TestSafeName(t *testing.T) {
|
||||
{"simple", "minimal", false},
|
||||
{"with-dash", "template-abc123", false},
|
||||
{"with-dot", "my-snapshot.v2", false},
|
||||
{"sandbox-id", "sb-12345678", false},
|
||||
{"sandbox-id", "cl-12345678", false},
|
||||
{"single-char", "a", false},
|
||||
{"numbers", "123", false},
|
||||
{"max-length", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01", false},
|
||||
|
||||
@ -4,9 +4,13 @@ import "fmt"
|
||||
|
||||
// VMConfig holds the configuration for creating a Firecracker microVM.
|
||||
type VMConfig struct {
|
||||
// SandboxID is the unique identifier for this sandbox (e.g., "sb-a1b2c3d4").
|
||||
// SandboxID is the unique identifier for this sandbox (e.g., "cl-a1b2c3d4").
|
||||
SandboxID string
|
||||
|
||||
// TemplateID is the template UUID string used to populate MMDS metadata
|
||||
// so that envd can read WRENN_TEMPLATE_ID from inside the guest.
|
||||
TemplateID string
|
||||
|
||||
// KernelPath is the path to the uncompressed Linux kernel (vmlinux).
|
||||
KernelPath string
|
||||
|
||||
@ -91,7 +95,7 @@ func (c *VMConfig) kernelArgs() string {
|
||||
)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 init=%s %s",
|
||||
"console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 clocksource=kvm-clock init=%s %s",
|
||||
c.InitPath, ipArg,
|
||||
)
|
||||
}
|
||||
|
||||
@ -101,6 +101,31 @@ func (c *fcClient) setMachineConfig(ctx context.Context, vcpus, memMB int) error
|
||||
})
|
||||
}
|
||||
|
||||
// setMMDSConfig enables MMDS V2 token-based access on the given network interface.
|
||||
// Must be called before startVM.
|
||||
func (c *fcClient) setMMDSConfig(ctx context.Context, ifaceID string) error {
|
||||
return c.do(ctx, http.MethodPut, "/mmds/config", map[string]any{
|
||||
"version": "V2",
|
||||
"network_interfaces": []string{ifaceID},
|
||||
})
|
||||
}
|
||||
|
||||
// mmdsMetadata is the metadata payload written to the Firecracker MMDS store.
// envd reads this via PollForMMDSOpts to populate WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID.
// NOTE(review): the JSON field names ("instanceID", "envID") appear to be the
// guest-facing contract envd matches on — confirm against envd before renaming.
type mmdsMetadata struct {
	SandboxID  string `json:"instanceID"`
	TemplateID string `json:"envID"`
}
|
||||
|
||||
// setMMDS writes sandbox metadata to the Firecracker MMDS store.
|
||||
// Can be called after the VM has started.
|
||||
func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) error {
|
||||
return c.do(ctx, http.MethodPut, "/mmds", mmdsMetadata{
|
||||
SandboxID: sandboxID,
|
||||
TemplateID: templateID,
|
||||
})
|
||||
}
|
||||
|
||||
// startVM issues the InstanceStart action.
|
||||
func (c *fcClient) startVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPut, "/actions", map[string]string{
|
||||
|
||||
@ -71,6 +71,13 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
|
||||
return nil, fmt.Errorf("start VM: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Push sandbox metadata into MMDS so envd can read
|
||||
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
|
||||
if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("set MMDS metadata: %w", err)
|
||||
}
|
||||
|
||||
vm := &VM{
|
||||
Config: cfg,
|
||||
process: proc,
|
||||
@ -108,6 +115,12 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
|
||||
return fmt.Errorf("set machine config: %w", err)
|
||||
}
|
||||
|
||||
// MMDS config — enable V2 token access on eth0 so that envd can read
|
||||
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
|
||||
if err := client.setMMDSConfig(ctx, "eth0"); err != nil {
|
||||
return fmt.Errorf("set MMDS config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -238,6 +251,12 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath
|
||||
return nil, fmt.Errorf("resume VM: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Push sandbox metadata into MMDS.
|
||||
if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("set MMDS metadata: %w", err)
|
||||
}
|
||||
|
||||
vm := &VM{
|
||||
Config: cfg,
|
||||
process: proc,
|
||||
@ -250,6 +269,12 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath
|
||||
return vm, nil
|
||||
}
|
||||
|
||||
// PID returns the process ID of the unshare wrapper process.
|
||||
// The actual Firecracker process is a direct child of this PID.
|
||||
func (v *VM) PID() int {
|
||||
return v.process.cmd.Process.Pid
|
||||
}
|
||||
|
||||
// Get returns a running VM by sandbox ID.
|
||||
func (m *Manager) Get(sandboxID string) (*VM, bool) {
|
||||
vm, ok := m.vms[sandboxID]
|
||||
|
||||
Reference in New Issue
Block a user