forked from wrenn/wrenn
v0.1.0 (#17)
This commit is contained in:
@ -6,8 +6,8 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
|
||||
@ -17,10 +17,9 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by proxyTarget, used to map to HTTP status codes
|
||||
@ -44,7 +43,7 @@ func (e errProxySandboxNotRunning) Error() string {
|
||||
return fmt.Sprintf("sandbox is not running (status: %s)", e.status)
|
||||
}
|
||||
|
||||
// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair.
|
||||
// proxyCacheEntry caches the resolved agent URL for a sandbox.
|
||||
// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure
|
||||
// can capture the correct port without the cache key needing to include it.
|
||||
type proxyCacheEntry struct {
|
||||
@ -52,23 +51,13 @@ type proxyCacheEntry struct {
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation.
|
||||
type proxyCacheKey [32]byte
|
||||
|
||||
func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey {
|
||||
var k proxyCacheKey
|
||||
copy(k[:16], sandboxID.Bytes[:])
|
||||
copy(k[16:], teamID.Bytes[:])
|
||||
return k
|
||||
}
|
||||
|
||||
// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
|
||||
// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
|
||||
// requests are reverse-proxied through the host agent that owns the sandbox.
|
||||
// All other requests are passed through to the inner handler.
|
||||
//
|
||||
// Authentication is via X-API-Key header only (no JWT). The API key's team
|
||||
// must own the sandbox.
|
||||
// No authentication is required — sandbox URLs are unguessable and access is
|
||||
// scoped to the sandbox ID embedded in the hostname.
|
||||
type SandboxProxyWrapper struct {
|
||||
inner http.Handler
|
||||
db *db.Queries
|
||||
@ -76,7 +65,7 @@ type SandboxProxyWrapper struct {
|
||||
transport http.RoundTripper
|
||||
|
||||
cacheMu sync.Mutex
|
||||
cache map[proxyCacheKey]proxyCacheEntry
|
||||
cache map[pgtype.UUID]proxyCacheEntry
|
||||
}
|
||||
|
||||
// NewSandboxProxyWrapper creates a new proxy wrapper.
|
||||
@ -86,19 +75,15 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec
|
||||
db: queries,
|
||||
pool: pool,
|
||||
transport: pool.Transport(),
|
||||
cache: make(map[proxyCacheKey]proxyCacheEntry),
|
||||
cache: make(map[pgtype.UUID]proxyCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// proxyTarget looks up the cached agent URL for (sandboxID, teamID).
|
||||
// proxyTarget looks up the cached agent URL for sandboxID.
|
||||
// On a miss it queries the DB, resolves the address, and populates the cache.
|
||||
// The *httputil.ReverseProxy is built by the caller so the Director closure
|
||||
// captures the correct port without the cache key needing to include it.
|
||||
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) {
|
||||
cacheKey := makeProxyCacheKey(sandboxID, teamID)
|
||||
|
||||
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID pgtype.UUID) (*url.URL, error) {
|
||||
h.cacheMu.Lock()
|
||||
entry, ok := h.cache[cacheKey]
|
||||
entry, ok := h.cache[sandboxID]
|
||||
h.cacheMu.Unlock()
|
||||
|
||||
if ok && time.Now().Before(entry.expiresAt) {
|
||||
@ -106,10 +91,7 @@ func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID
|
||||
}
|
||||
|
||||
// Cache miss or expired — query DB.
|
||||
target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{
|
||||
ID: sandboxID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
target, err := h.db.GetSandboxProxyTarget(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return nil, errProxySandboxNotFound
|
||||
}
|
||||
@ -126,7 +108,7 @@ func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID
|
||||
}
|
||||
|
||||
h.cacheMu.Lock()
|
||||
h.cache[cacheKey] = proxyCacheEntry{
|
||||
h.cache[sandboxID] = proxyCacheEntry{
|
||||
agentURL: agentURL,
|
||||
expiresAt: time.Now().Add(proxyCacheTTL),
|
||||
}
|
||||
@ -135,11 +117,11 @@ func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID
|
||||
return agentURL, nil
|
||||
}
|
||||
|
||||
// evictProxyCache removes the cached entry for a (sandbox, team) pair.
|
||||
// evictProxyCache removes the cached entry for a sandbox.
|
||||
// Called on 502 so a stopped/moved sandbox is re-resolved on the next request.
|
||||
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) {
|
||||
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID pgtype.UUID) {
|
||||
h.cacheMu.Lock()
|
||||
delete(h.cache, makeProxyCacheKey(sandboxID, teamID))
|
||||
delete(h.cache, sandboxID)
|
||||
h.cacheMu.Unlock()
|
||||
}
|
||||
|
||||
@ -166,20 +148,13 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate: require API key or JWT, extract team ID.
|
||||
teamID, err := h.authenticateRequest(r)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID)
|
||||
agentURL, err := h.proxyTarget(r.Context(), sandboxID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errProxySandboxNotFound):
|
||||
@ -206,25 +181,9 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
"port", port,
|
||||
"error", err,
|
||||
)
|
||||
h.evictProxyCache(sandboxID, teamID)
|
||||
h.evictProxyCache(sandboxID)
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// authenticateRequest validates the request's API key and returns the team ID.
|
||||
// Only API key authentication is supported for sandbox proxy requests (not JWT).
|
||||
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("X-API-Key header required")
|
||||
}
|
||||
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := h.db.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid API key")
|
||||
}
|
||||
return row.TeamID, nil
|
||||
}
|
||||
|
||||
248
internal/api/handlers_admin_capsules.go
Normal file
248
internal/api/handlers_admin_capsules.go
Normal file
@ -0,0 +1,248 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type adminCapsuleHandler struct {
|
||||
svc *service.SandboxService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newAdminCapsuleHandler(svc *service.SandboxService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *adminCapsuleHandler {
|
||||
return &adminCapsuleHandler{svc: svc, db: db, pool: pool, audit: al}
|
||||
}
|
||||
|
||||
// Create handles POST /v1/admin/capsules.
|
||||
func (h *adminCapsuleHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSandboxRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Create(r.Context(), service.SandboxCreateParams{
|
||||
TeamID: id.PlatformTeamID,
|
||||
Template: req.Template,
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/admin/capsules.
|
||||
func (h *adminCapsuleHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxes, err := h.svc.List(r.Context(), id.PlatformTeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list sandboxes")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]sandboxResponse, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
resp[i] = sandboxToResponse(sb)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/admin/capsules/{id}.
|
||||
func (h *adminCapsuleHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Get(r.Context(), sandboxID, id.PlatformTeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Destroy handles DELETE /v1/admin/capsules/{id}.
|
||||
func (h *adminCapsuleHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
type adminSnapshotRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot.
|
||||
// Pauses the capsule, takes a snapshot as a platform template, then destroys the capsule.
|
||||
func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req adminSnapshotRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Verify sandbox exists and belongs to platform team BEFORE any
|
||||
// destructive operations (template overwrite).
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: id.PlatformTeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" && sb.Status != "paused" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name already exists as a platform template.
|
||||
if existing, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
// Delete old snapshot files from all hosts before removing the DB record.
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: id.PlatformTeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-mark sandbox as "paused" to prevent the reconciler from racing.
|
||||
if sb.Status == "running" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a detached context so the snapshot completes even if the client disconnects.
|
||||
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer snapCancel()
|
||||
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Name: req.Name,
|
||||
TeamId: formatUUIDForRPC(id.PlatformTeamID),
|
||||
TemplateId: formatUUIDForRPC(newTemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
// Snapshot failed — revert status.
|
||||
if sb.Status == "running" {
|
||||
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
}
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: id.PlatformTeamID,
|
||||
DefaultUser: "root",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: sb.Metadata,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
|
||||
return
|
||||
}
|
||||
|
||||
// Destroy the ephemeral capsule after successful snapshot.
|
||||
if err := h.svc.Destroy(snapCtx, sandboxID, id.PlatformTeamID); err != nil {
|
||||
slog.Error("failed to destroy capsule after snapshot", "sandbox_id", sandboxIDStr, "error", err)
|
||||
// Don't fail the response — the snapshot was created successfully.
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
}
|
||||
@ -6,11 +6,11 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type apiKeyHandler struct {
|
||||
@ -63,7 +63,7 @@ func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyRes
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: id.FormatUserID(k.CreatedBy),
|
||||
CreatorEmail: k.CreatorEmail,
|
||||
CreatorEmail: k.CreatorEmail.String,
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
|
||||
|
||||
@ -8,9 +8,9 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type auditHandler struct {
|
||||
|
||||
@ -2,19 +2,32 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const (
|
||||
activationKeyPrefix = "wrenn:activation:"
|
||||
activationTTL = 30 * time.Minute
|
||||
signupCooldown = 30 * time.Minute
|
||||
)
|
||||
|
||||
// loginTeam returns the team and role to stamp into a login JWT.
|
||||
@ -52,18 +65,89 @@ func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team,
|
||||
}, first.Role, nil
|
||||
}
|
||||
|
||||
// ensureDefaultTeam creates a default team for a user if they have none.
|
||||
// This happens on first login after activation or for edge cases where a user
|
||||
// has no teams. Returns the team, role, and whether the user was set as admin.
|
||||
func ensureDefaultTeam(ctx context.Context, qtx *db.Queries, pool *pgxpool.Pool, userID pgtype.UUID, userName string) (db.Team, string, bool, error) {
|
||||
// Try existing teams first.
|
||||
team, role, err := loginTeam(ctx, qtx, userID)
|
||||
if err == nil {
|
||||
return team, role, false, nil
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Team{}, "", false, err
|
||||
}
|
||||
|
||||
// No teams — create default team in a transaction.
|
||||
tx, err := pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx) //nolint:errcheck
|
||||
|
||||
txq := qtx.WithTx(tx)
|
||||
|
||||
// First active user to have a team becomes admin.
|
||||
activeCount, err := txq.CountActiveUsers(ctx)
|
||||
if err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("count active users: %w", err)
|
||||
}
|
||||
isFirstUser := activeCount == 1 // only this user is active
|
||||
|
||||
teamID := id.NewTeamID()
|
||||
teamRow, err := txq.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: userName + "'s Team",
|
||||
Slug: id.NewTeamSlug(),
|
||||
})
|
||||
if err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("insert team: %w", err)
|
||||
}
|
||||
|
||||
if err := txq.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
IsDefault: true,
|
||||
Role: "owner",
|
||||
}); err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("insert team member: %w", err)
|
||||
}
|
||||
|
||||
if isFirstUser {
|
||||
if err := txq.SetUserAdmin(ctx, db.SetUserAdminParams{ID: userID, IsAdmin: true}); err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("set admin: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
return db.Team{
|
||||
ID: teamRow.ID,
|
||||
Name: teamRow.Name,
|
||||
Slug: teamRow.Slug,
|
||||
IsByoc: teamRow.IsByoc,
|
||||
CreatedAt: teamRow.CreatedAt,
|
||||
DeletedAt: teamRow.DeletedAt,
|
||||
}, "owner", isFirstUser, nil
|
||||
}
|
||||
|
||||
type switchTeamRequest struct {
|
||||
TeamID string `json:"team_id"`
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
mailer email.Mailer
|
||||
rdb *redis.Client
|
||||
redirectURL string
|
||||
}
|
||||
|
||||
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authHandler {
|
||||
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, mailer email.Mailer, rdb *redis.Client, redirectURL string) *authHandler {
|
||||
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/")}
|
||||
}
|
||||
|
||||
type signupRequest struct {
|
||||
@ -77,6 +161,10 @@ type loginRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type activateRequest struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type authResponse struct {
|
||||
Token string `json:"token"`
|
||||
UserID string `json:"user_id"`
|
||||
@ -85,6 +173,10 @@ type authResponse struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type signupResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Signup handles POST /v1/auth/signup.
|
||||
func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
var req signupRequest
|
||||
@ -110,24 +202,41 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Check for existing user with this email.
|
||||
existing, err := h.db.GetUserByEmail(ctx, req.Email)
|
||||
if err == nil {
|
||||
// User exists — decide what to do based on status.
|
||||
switch existing.Status {
|
||||
case "inactive":
|
||||
// Unactivated user — allow re-signup after cooldown.
|
||||
if time.Since(existing.CreatedAt.Time) < signupCooldown {
|
||||
writeError(w, http.StatusConflict, "signup_cooldown",
|
||||
"an activation email was recently sent to this address — please check your inbox or try again later")
|
||||
return
|
||||
}
|
||||
// Cooldown passed — delete the old row and proceed with fresh signup.
|
||||
if err := h.db.HardDeleteUser(ctx, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to clean up previous signup")
|
||||
return
|
||||
}
|
||||
default:
|
||||
// active, disabled, deleted — email is taken.
|
||||
writeError(w, http.StatusConflict, "email_taken", "an account with this email already exists")
|
||||
return
|
||||
}
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user")
|
||||
return
|
||||
}
|
||||
|
||||
passwordHash, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
|
||||
return
|
||||
}
|
||||
|
||||
// Use a transaction to atomically create user + team + membership.
|
||||
tx, err := h.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to begin transaction")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(ctx) //nolint:errcheck
|
||||
|
||||
qtx := h.db.WithTx(tx)
|
||||
|
||||
userID := id.NewUserID()
|
||||
_, err = qtx.InsertUser(ctx, db.InsertUserParams{
|
||||
_, err = h.db.InsertUserInactive(ctx, db.InsertUserInactiveParams{
|
||||
ID: userID,
|
||||
Email: req.Email,
|
||||
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
|
||||
@ -143,44 +252,111 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create default team.
|
||||
teamID := id.NewTeamID()
|
||||
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: req.Name + "'s Team",
|
||||
Slug: id.NewTeamSlug(),
|
||||
// Generate activation token and store in Redis.
|
||||
rawToken := generateActivationToken()
|
||||
tokenHash := hashActivationToken(rawToken)
|
||||
redisKey := activationKeyPrefix + tokenHash
|
||||
|
||||
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(userID), activationTTL).Err(); err != nil {
|
||||
slog.Error("signup: failed to store activation token in redis", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create activation token")
|
||||
return
|
||||
}
|
||||
|
||||
activateURL := h.redirectURL + "/activate?token=" + rawToken
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := h.mailer.Send(sendCtx, req.Email, "Activate your Wrenn account", email.EmailData{
|
||||
RecipientName: req.Name,
|
||||
Message: "Welcome to Wrenn! Click the button below to activate your account. This link expires in 30 minutes.",
|
||||
Button: &email.Button{Text: "Activate Account", URL: activateURL},
|
||||
Closing: "If you didn't create this account, you can safely ignore this email.",
|
||||
}); err != nil {
|
||||
slog.Warn("signup: failed to send activation email", "email", req.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
writeJSON(w, http.StatusCreated, signupResponse{
|
||||
Message: "Account created. Please check your email to activate your account.",
|
||||
})
|
||||
}
|
||||
|
||||
// Activate handles POST /v1/auth/activate.
|
||||
func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
|
||||
var req activateRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tokenHash := hashActivationToken(req.Token)
|
||||
redisKey := activationKeyPrefix + tokenHash
|
||||
|
||||
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
writeError(w, http.StatusBadRequest, "invalid_token", "activation link is invalid or has expired")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to verify token")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := id.ParseUserID(userIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "invalid stored user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Status != "inactive" {
|
||||
writeError(w, http.StatusBadRequest, "already_activated", "this account has already been activated")
|
||||
return
|
||||
}
|
||||
|
||||
// Activate the user.
|
||||
if err := h.db.SetUserStatus(ctx, db.SetUserStatusParams{
|
||||
ID: userID,
|
||||
Status: "active",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to create team")
|
||||
slog.Error("activate: failed to set user status", "user_id", id.FormatUserID(userID), "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to activate user")
|
||||
return
|
||||
}
|
||||
|
||||
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
IsDefault: true,
|
||||
Role: "owner",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to add user to team")
|
||||
// Create default team and log them in.
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, userID, user.Name)
|
||||
if err != nil {
|
||||
slog.Error("activate: failed to create default team", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to set up account")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to commit signup")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner", false)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, authResponse{
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(userID),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
})
|
||||
}
|
||||
|
||||
@ -222,17 +398,36 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
switch user.Status {
|
||||
case "active":
|
||||
// OK — proceed.
|
||||
case "inactive":
|
||||
slog.Warn("login failed: account not activated", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusForbidden, "account_not_activated", "please check your email and activate your account before signing in")
|
||||
return
|
||||
case "disabled":
|
||||
slog.Warn("login failed: account disabled", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusForbidden, "account_disabled", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
case "deleted":
|
||||
slog.Warn("login failed: account deleted", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
default:
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure user has a default team (creates one on first login after activation).
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusForbidden, "no_team", "user is not a member of any team")
|
||||
return
|
||||
}
|
||||
slog.Error("login: failed to ensure default team", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
@ -322,3 +517,18 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
|
||||
Name: user.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
// generateActivationToken returns 16 bytes of cryptographically secure
// randomness encoded as a 32-character lowercase hex string. A failure of
// crypto/rand means the environment itself is broken, so it panics rather
// than returning a weaker token.
func generateActivationToken() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(fmt.Sprintf("crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(raw[:])
}
|
||||
|
||||
// hashActivationToken returns the hex-encoded SHA-256 digest of raw,
// letting callers compare or persist a digest instead of the raw token.
func hashActivationToken(raw string) string {
	digest := sha256.New()
	digest.Write([]byte(raw))
	return hex.EncodeToString(digest.Sum(nil))
}
|
||||
|
||||
@ -3,19 +3,21 @@ package api
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/validate"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -54,6 +56,8 @@ type buildResponse struct {
|
||||
Error *string `json:"error,omitempty"`
|
||||
SandboxID *string `json:"sandbox_id,omitempty"`
|
||||
HostID *string `json:"host_id,omitempty"`
|
||||
DefaultUser string `json:"default_user"`
|
||||
DefaultEnv json.RawMessage `json:"default_env"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
@ -71,6 +75,8 @@ func buildToResponse(b db.TemplateBuild) buildResponse {
|
||||
CurrentStep: b.CurrentStep,
|
||||
TotalSteps: b.TotalSteps,
|
||||
Logs: b.Logs,
|
||||
DefaultUser: b.DefaultUser,
|
||||
DefaultEnv: b.DefaultEnv,
|
||||
}
|
||||
if b.Healthcheck != "" {
|
||||
resp.Healthcheck = &b.Healthcheck
|
||||
@ -101,11 +107,54 @@ func buildToResponse(b db.TemplateBuild) buildResponse {
|
||||
}
|
||||
|
||||
// Create handles POST /v1/admin/builds.
|
||||
// Accepts either JSON body or multipart/form-data with a "config" JSON part
|
||||
// and an optional "archive" file part (tar/tar.gz/zip for COPY commands).
|
||||
func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createBuildRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
var archive []byte
|
||||
var archiveName string
|
||||
|
||||
ct := r.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(ct, "multipart/") {
|
||||
// 100 MB max for multipart (archive + JSON config).
|
||||
if err := r.ParseMultipartForm(100 << 20); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "failed to parse multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse JSON config from "config" field.
|
||||
configStr := r.FormValue("config")
|
||||
if configStr == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "multipart form requires a 'config' JSON field")
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal([]byte(configStr), &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid config JSON in multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
// Read optional archive file (max 100 MB).
|
||||
file, header, err := r.FormFile("archive")
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
const maxArchiveSize = 100 << 20 // 100 MB
|
||||
lr := io.LimitReader(file, maxArchiveSize+1)
|
||||
archive, err = io.ReadAll(lr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "failed to read archive file")
|
||||
return
|
||||
}
|
||||
if int64(len(archive)) > maxArchiveSize {
|
||||
writeError(w, http.StatusRequestEntityTooLarge, "invalid_request", "archive exceeds 100 MB limit")
|
||||
return
|
||||
}
|
||||
archiveName = header.Filename
|
||||
}
|
||||
} else {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
@ -129,6 +178,8 @@ func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
SkipPrePost: req.SkipPrePost,
|
||||
Archive: archive,
|
||||
ArchiveName: archiveName,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to create build", "error", err)
|
||||
|
||||
@ -8,11 +8,11 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
type channelHandler struct {
|
||||
|
||||
@ -12,10 +12,10 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -29,9 +29,13 @@ func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler
|
||||
}
|
||||
|
||||
type execRequest struct {
|
||||
Cmd string `json:"cmd"`
|
||||
Args []string `json:"args"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
Cmd string `json:"cmd"`
|
||||
Args []string `json:"args"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
Background bool `json:"background"`
|
||||
Tag string `json:"tag"`
|
||||
Envs map[string]string `json:"envs"`
|
||||
Cwd string `json:"cwd"`
|
||||
}
|
||||
|
||||
type execResponse struct {
|
||||
@ -45,7 +49,14 @@ type execResponse struct {
|
||||
Encoding string `json:"encoding"`
|
||||
}
|
||||
|
||||
// Exec handles POST /v1/sandboxes/{id}/exec.
|
||||
type backgroundExecResponse struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Cmd string `json:"cmd"`
|
||||
PID uint32 `json:"pid"`
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
// Exec handles POST /v1/capsules/{id}/exec.
|
||||
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
@ -78,14 +89,54 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Background mode: start process and return immediately.
|
||||
if req.Background {
|
||||
tag := req.Tag
|
||||
if tag == "" {
|
||||
tag = "proc-" + id.NewPtyTag()
|
||||
}
|
||||
|
||||
bgResp, err := agent.StartBackground(ctx, connect.NewRequest(&pb.StartBackgroundRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Tag: tag,
|
||||
Cmd: req.Cmd,
|
||||
Args: req.Args,
|
||||
Envs: req.Envs,
|
||||
Cwd: req.Cwd,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
PID: bgResp.Msg.Pid,
|
||||
Tag: bgResp.Msg.Tag,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
|
||||
@ -12,20 +12,21 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool}
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@ -47,11 +48,10 @@ type wsOutMsg struct {
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
||||
}
|
||||
|
||||
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
|
||||
// ExecStream handles WS /v1/capsules/{id}/exec/stream.
|
||||
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
@ -59,13 +59,31 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -76,6 +94,20 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
sendWSError(conn, "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
// Read the start message.
|
||||
var startMsg wsStartMsg
|
||||
if err := conn.ReadJSON(&startMsg); err != nil {
|
||||
|
||||
@ -9,10 +9,10 @@ import (
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -25,7 +25,7 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
|
||||
return &filesHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// Upload handles POST /v1/sandboxes/{id}/files/write.
|
||||
// Upload handles POST /v1/capsules/{id}/files/write.
|
||||
// Expects multipart/form-data with:
|
||||
// - "path" text field: absolute destination path inside the sandbox
|
||||
// - "file" file field: binary content to write
|
||||
@ -105,7 +105,7 @@ type readFileRequest struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// Download handles POST /v1/sandboxes/{id}/files/read.
|
||||
// Download handles POST /v1/capsules/{id}/files/read.
|
||||
// Accepts JSON body with path, returns raw file content with Content-Disposition.
|
||||
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
|
||||
@ -10,10 +10,10 @@ import (
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -26,7 +26,7 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
|
||||
return &filesStreamHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
||||
// StreamUpload handles POST /v1/capsules/{id}/files/stream/write.
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
@ -150,7 +150,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
|
||||
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
|
||||
// Accepts JSON body with path, streams file content back without buffering.
|
||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
|
||||
236
internal/api/handlers_fs.go
Normal file
236
internal/api/handlers_fs.go
Normal file
@ -0,0 +1,236 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// fsHandler serves sandbox filesystem operations (list, mkdir, remove)
// by forwarding validated requests to the sandbox's host agent.
type fsHandler struct {
	db   *db.Queries               // team-scoped sandbox lookups
	pool *lifecycle.HostClientPool // connections to host agents
}

// newFSHandler constructs an fsHandler with its database and host-pool
// dependencies.
func newFSHandler(db *db.Queries, pool *lifecycle.HostClientPool) *fsHandler {
	return &fsHandler{db: db, pool: pool}
}
|
||||
|
||||
// listDirRequest is the JSON body for POST /v1/capsules/{id}/files/list.
type listDirRequest struct {
	Path  string `json:"path"`  // directory path inside the sandbox
	Depth uint32 `json:"depth"` // listing depth, forwarded verbatim to the host agent
}

// fileEntryResponse is the JSON shape of a single filesystem entry,
// mirroring the host agent's pb.FileEntry message.
type fileEntryResponse struct {
	Name          string  `json:"name"`
	Path          string  `json:"path"`
	Type          string  `json:"type"`
	Size          int64   `json:"size"`
	Mode          uint32  `json:"mode"`
	Permissions   string  `json:"permissions"`
	Owner         string  `json:"owner"`
	Group         string  `json:"group"`
	ModifiedAt    int64   `json:"modified_at"`
	SymlinkTarget *string `json:"symlink_target,omitempty"` // set only when the entry is a symlink
}

// listDirResponse is the payload returned by ListDir.
type listDirResponse struct {
	Entries []fileEntryResponse `json:"entries"`
}

// makeDirRequest is the JSON body for POST /v1/capsules/{id}/files/mkdir.
type makeDirRequest struct {
	Path string `json:"path"`
}

// makeDirResponse returns the entry describing the newly created directory.
type makeDirResponse struct {
	Entry fileEntryResponse `json:"entry"`
}

// removeRequest is the JSON body for POST /v1/capsules/{id}/files/remove.
type removeRequest struct {
	Path string `json:"path"`
}
|
||||
|
||||
// ListDir handles POST /v1/capsules/{id}/files/list.
|
||||
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req listDirRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.ListDir(ctx, connect.NewRequest(&pb.ListDirRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
Depth: req.Depth,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
entries := make([]fileEntryResponse, 0, len(resp.Msg.Entries))
|
||||
for _, e := range resp.Msg.Entries {
|
||||
entries = append(entries, fileEntryFromPB(e))
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, listDirResponse{Entries: entries})
|
||||
}
|
||||
|
||||
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
|
||||
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req makeDirRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.MakeDir(ctx, connect.NewRequest(&pb.MakeDirRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, makeDirResponse{Entry: fileEntryFromPB(resp.Msg.Entry)})
|
||||
}
|
||||
|
||||
// Remove handles POST /v1/capsules/{id}/files/remove.
|
||||
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req removeRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := agent.RemovePath(ctx, connect.NewRequest(&pb.RemovePathRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func fileEntryFromPB(e *pb.FileEntry) fileEntryResponse {
|
||||
if e == nil {
|
||||
return fileEntryResponse{}
|
||||
}
|
||||
resp := fileEntryResponse{
|
||||
Name: e.Name,
|
||||
Path: e.Path,
|
||||
Type: e.Type,
|
||||
Size: e.Size,
|
||||
Mode: e.Mode,
|
||||
Permissions: e.Permissions,
|
||||
Owner: e.Owner,
|
||||
Group: e.Group,
|
||||
ModifiedAt: e.ModifiedAt,
|
||||
}
|
||||
if e.SymlinkTarget != nil {
|
||||
resp.SymlinkTarget = e.SymlinkTarget
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@ -10,11 +10,11 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type hostHandler struct {
|
||||
|
||||
585
internal/api/handlers_me.go
Normal file
585
internal/api/handlers_me.go
Normal file
@ -0,0 +1,585 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
const (
	// passwordResetKeyPrefix namespaces password-reset tokens stored in Redis.
	passwordResetKeyPrefix = "wrenn:password_reset:"
	// passwordResetTTL bounds how long a password-reset token stays valid.
	passwordResetTTL = 15 * time.Minute
)
|
||||
|
||||
// meHandler serves the /v1/me endpoints: profile reads, name and password
// updates, password resets, and account deletion.
type meHandler struct {
	db            *db.Queries
	pool          *pgxpool.Pool
	rdb           *redis.Client // holds password-reset tokens (see passwordResetKeyPrefix)
	jwtSecret     []byte        // key used to re-issue JWTs after profile changes
	mailer        email.Mailer  // sends security-notification emails
	oauthRegistry *oauth.Registry
	redirectURL   string // normalized by newMeHandler to have no trailing slash
	teamSvc       *service.TeamService
}
|
||||
|
||||
func newMeHandler(
|
||||
db *db.Queries,
|
||||
pool *pgxpool.Pool,
|
||||
rdb *redis.Client,
|
||||
jwtSecret []byte,
|
||||
mailer email.Mailer,
|
||||
registry *oauth.Registry,
|
||||
redirectURL string,
|
||||
teamSvc *service.TeamService,
|
||||
) *meHandler {
|
||||
return &meHandler{
|
||||
db: db,
|
||||
pool: pool,
|
||||
rdb: rdb,
|
||||
jwtSecret: jwtSecret,
|
||||
mailer: mailer,
|
||||
oauthRegistry: registry,
|
||||
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||
teamSvc: teamSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// meResponse is the payload for GET /v1/me.
type meResponse struct {
	Name        string   `json:"name"`
	Email       string   `json:"email"`
	HasPassword bool     `json:"has_password"` // true when a password hash is stored
	Providers   []string `json:"providers"`    // linked OAuth provider names
}

// updateNameRequest is the JSON body for PATCH /v1/me.
type updateNameRequest struct {
	Name string `json:"name"`
}

// changePasswordRequest is the JSON body for POST /v1/me/password.
// CurrentPassword is required when the user already has a password;
// ConfirmPassword is required when an OAuth-only user is adding one.
type changePasswordRequest struct {
	CurrentPassword string `json:"current_password"`
	NewPassword     string `json:"new_password"`
	ConfirmPassword string `json:"confirm_password"`
}

// requestPasswordResetRequest asks for a password-reset email to be sent.
type requestPasswordResetRequest struct {
	Email string `json:"email"`
}

// confirmPasswordResetRequest completes a reset using the emailed token.
type confirmPasswordResetRequest struct {
	Token       string `json:"token"`
	NewPassword string `json:"new_password"`
}

// deleteAccountRequest is the JSON body for account deletion. Confirmation
// is an explicit confirmation string; the expected value is checked by the
// handler (not shown in this excerpt).
type deleteAccountRequest struct {
	Confirmation string `json:"confirmation"`
}
|
||||
|
||||
// GetMe handles GET /v1/me.
|
||||
func (h *meHandler) GetMe(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.db.GetOAuthProvidersByUserID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get providers")
|
||||
return
|
||||
}
|
||||
|
||||
providerNames := make([]string, 0, len(providers))
|
||||
for _, p := range providers {
|
||||
providerNames = append(providerNames, p.Provider)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, meResponse{
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
HasPassword: user.PasswordHash.Valid,
|
||||
Providers: providerNames,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateName handles PATCH /v1/me — updates the user's name and re-issues a JWT.
|
||||
func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
|
||||
var req updateNameRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
req.Name = strings.TrimSpace(req.Name)
|
||||
if req.Name == "" || len(req.Name) > 100 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "name must be between 1 and 100 characters")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateUserName(ctx, db.UpdateUserNameParams{
|
||||
ID: ac.UserID,
|
||||
Name: req.Name,
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update name")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
team, role, err := loginTeam(ctx, h.db, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get team")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, team.ID, user.Email, req.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: req.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangePassword handles POST /v1/me/password.
|
||||
// For users with a password: requires current_password + new_password.
|
||||
// For OAuth-only users: requires new_password + confirm_password.
|
||||
func (h *meHandler) ChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
|
||||
var req changePasswordRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
if user.PasswordHash.Valid {
|
||||
// Changing existing password — verify current.
|
||||
if req.CurrentPassword == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "current_password is required")
|
||||
return
|
||||
}
|
||||
if err := auth.CheckPassword(user.PasswordHash.String, req.CurrentPassword); err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "wrong_password", "current password is incorrect")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// OAuth user adding a password — confirm must match.
|
||||
if req.ConfirmPassword == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "confirm_password is required")
|
||||
return
|
||||
}
|
||||
if req.NewPassword != req.ConfirmPassword {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "passwords do not match")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.NewPassword) < 8 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateUserPassword(ctx, db.UpdateUserPasswordParams{
|
||||
ID: ac.UserID,
|
||||
PasswordHash: pgtype.Text{String: hash, Valid: true},
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update password")
|
||||
return
|
||||
}
|
||||
|
||||
isAdding := !user.PasswordHash.Valid
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
subject, message := "Your Wrenn password was changed", "Your account password was successfully updated. If you did not make this change, reset your password immediately."
|
||||
if isAdding {
|
||||
subject = "Password added to your Wrenn account"
|
||||
message = "A password has been added to your Wrenn account. You can now sign in with your email and password in addition to any connected OAuth providers."
|
||||
}
|
||||
if err := h.mailer.Send(sendCtx, user.Email, subject, email.EmailData{
|
||||
RecipientName: user.Name,
|
||||
Message: message,
|
||||
Closing: "If you didn't make this change, contact support immediately.",
|
||||
}); err != nil {
|
||||
slog.Warn("change password: failed to send notification", "email", user.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RequestPasswordReset handles POST /v1/me/password/reset (unauthenticated).
|
||||
// Always returns 200 to avoid leaking account existence.
|
||||
func (h *meHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
|
||||
var req requestPasswordResetRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
||||
if req.Email == "" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
user, err := h.db.GetUserByEmail(ctx, req.Email)
|
||||
if err != nil {
|
||||
// Don't leak whether the email exists.
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
if user.Status != "active" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
rawToken := generateResetToken()
|
||||
tokenHash := hashResetToken(rawToken)
|
||||
redisKey := passwordResetKeyPrefix + tokenHash
|
||||
|
||||
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(user.ID), passwordResetTTL).Err(); err != nil {
|
||||
slog.Error("password reset: failed to store token in redis", "error", err)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
resetURL := h.redirectURL + "/reset-password?token=" + rawToken
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := h.mailer.Send(sendCtx, user.Email, "Reset your Wrenn password", email.EmailData{
|
||||
RecipientName: user.Name,
|
||||
Message: "We received a request to reset your password. Click the button below to set a new password. This link expires in 15 minutes.",
|
||||
Button: &email.Button{Text: "Reset Password", URL: resetURL},
|
||||
Closing: "If you didn't request a password reset, you can safely ignore this email.",
|
||||
}); err != nil {
|
||||
slog.Error("password reset: failed to send email", "email", user.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ConfirmPasswordReset handles POST /v1/me/password/reset/confirm (unauthenticated).
|
||||
func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request) {
|
||||
var req confirmPasswordResetRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
|
||||
return
|
||||
}
|
||||
if len(req.NewPassword) < 8 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tokenHash := hashResetToken(req.Token)
|
||||
redisKey := passwordResetKeyPrefix + tokenHash
|
||||
|
||||
// GetDel atomically retrieves and removes the token in a single round-trip,
|
||||
// preventing concurrent requests from both consuming the same token.
|
||||
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
writeError(w, http.StatusBadRequest, "invalid_token", "reset token is invalid or has expired")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to verify token")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := id.ParseUserID(userIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "invalid stored user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateUserPassword(ctx, db.UpdateUserPasswordParams{
|
||||
ID: userID,
|
||||
PasswordHash: pgtype.Text{String: hash, Valid: true},
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update password")
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := h.mailer.Send(sendCtx, user.Email, "Your Wrenn password was reset", email.EmailData{
|
||||
RecipientName: user.Name,
|
||||
Message: "Your password has been successfully reset. You can now sign in with your new password.",
|
||||
Closing: "If you didn't request this change, contact support immediately.",
|
||||
}); err != nil {
|
||||
slog.Warn("confirm password reset: failed to send notification", "email", user.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ConnectProvider handles GET /v1/me/providers/{provider}/connect.
|
||||
// Sets OAuth state + link cookies and returns the provider auth URL.
|
||||
func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
provider := chi.URLParam(r, "provider")
|
||||
|
||||
p, ok := h.oauthRegistry.Get(provider)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
|
||||
return
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate state")
|
||||
return
|
||||
}
|
||||
|
||||
mac := computeHMAC(h.jwtSecret, state)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state + ":" + mac,
|
||||
Path: "/",
|
||||
MaxAge: 600,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
userIDStr := id.FormatUserID(ac.UserID)
|
||||
linkMac := computeHMAC(h.jwtSecret, userIDStr)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_link_user_id",
|
||||
Value: userIDStr + ":" + linkMac,
|
||||
Path: "/",
|
||||
MaxAge: 600,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{"auth_url": p.AuthCodeURL(state)})
|
||||
}
|
||||
|
||||
// DisconnectProvider handles DELETE /v1/me/providers/{provider}.
|
||||
func (h *meHandler) DisconnectProvider(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
provider := chi.URLParam(r, "provider")
|
||||
ctx := r.Context()
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.db.GetOAuthProvidersByUserID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get providers")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the user will still have at least one login method after disconnecting.
|
||||
if !user.PasswordHash.Valid && len(providers) <= 1 {
|
||||
writeError(w, http.StatusBadRequest, "last_login_method", "cannot disconnect your only login method — add a password first")
|
||||
return
|
||||
}
|
||||
|
||||
// Check the provider is actually linked to this user.
|
||||
found := false
|
||||
for _, p := range providers {
|
||||
if p.Provider == provider {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
writeError(w, http.StatusNotFound, "not_found", "provider not connected")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteOAuthProvider(ctx, db.DeleteOAuthProviderParams{
|
||||
UserID: ac.UserID,
|
||||
Provider: provider,
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to disconnect provider")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// DeleteAccount handles DELETE /v1/me — soft-deletes the user's account.
//
// Order of operations matters: team deletion involves non-transactional RPC
// side effects (sandbox destruction), so sole-owned teams are removed as a
// best-effort first pass, and only the DB-local cleanup (API keys + the
// soft-delete itself) runs inside a transaction.
func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
	ac := auth.MustFromContext(r.Context())
	ctx := r.Context()

	var req deleteAccountRequest
	if err := decodeJSON(r, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	user, err := h.db.GetUserByID(ctx, ac.UserID)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
		return
	}

	// The caller must re-type their own email address as explicit confirmation
	// (case-insensitive, surrounding whitespace ignored).
	if !strings.EqualFold(strings.TrimSpace(req.Confirmation), user.Email) {
		writeError(w, http.StatusBadRequest, "invalid_request", "confirmation does not match your email address")
		return
	}

	// Refuse while the user owns any team that still has other members —
	// ownership must be transferred (or members removed) first.
	teamsBlocking, err := h.db.CountUserOwnedTeamsWithOtherMembers(ctx, ac.UserID)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to check team ownership")
		return
	}
	if teamsBlocking > 0 {
		writeError(w, http.StatusConflict, "owns_team_with_members",
			fmt.Sprintf("you own %d team(s) with other members — transfer ownership or remove members before deleting your account", teamsBlocking))
		return
	}

	// Delete all teams the user solely owns (no other members).
	// Team deletion involves RPC calls (sandbox destruction) that cannot be
	// transactional, so we do those first as best-effort, then wrap the
	// DB-only cleanup in a transaction.
	soleTeams, err := h.db.ListSoleOwnedTeams(ctx, ac.UserID)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to list owned teams")
		return
	}
	for _, teamID := range soleTeams {
		if err := h.teamSvc.DeleteTeamInternal(ctx, teamID); err != nil {
			// Abort on the first failure; already-deleted teams stay deleted.
			writeError(w, http.StatusInternalServerError, "db_error",
				fmt.Sprintf("failed to delete sole-owned team %s", id.FormatTeamID(teamID)))
			return
		}
	}

	tx, err := h.pool.Begin(ctx)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to start transaction")
		return
	}
	// Rollback after a successful Commit is a no-op; this covers all early returns.
	defer tx.Rollback(ctx)

	qtx := h.db.WithTx(tx)

	// API keys are credentials — remove them immediately rather than leaving
	// them usable during the soft-delete grace period.
	if err := qtx.DeleteAPIKeysByCreator(ctx, ac.UserID); err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to delete user's API keys")
		return
	}

	if err := qtx.SoftDeleteUser(ctx, ac.UserID); err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to delete account")
		return
	}

	if err := tx.Commit(ctx); err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to commit account deletion")
		return
	}

	slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)

	// Farewell notification is best-effort and must not block or fail the request.
	go func() {
		sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		if err := h.mailer.Send(sendCtx, user.Email, "Your Wrenn account has been deleted", email.EmailData{
			RecipientName: user.Name,
			Message:       "Your Wrenn account has been deactivated and is scheduled for permanent deletion in 15 days. If this was a mistake, contact support before then to recover your account.",
			Closing:       "Thank you for using Wrenn.",
		}); err != nil {
			slog.Warn("delete account: failed to send notification", "email", user.Email, "error", err)
		}
	}()

	w.WriteHeader(http.StatusNoContent)
}
|
||||
|
||||
// --- helpers ---

// generateResetToken returns a 32-character hex string backed by 16 bytes
// (128 bits) of cryptographically secure randomness.
func generateResetToken() string {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err != nil {
		// crypto/rand failing means the platform RNG is broken; there is no
		// sensible recovery, so fail loudly.
		panic(fmt.Sprintf("crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(buf[:])
}
|
||||
|
||||
// hashResetToken returns the hex-encoded SHA-256 digest of raw. Only the
// digest is stored server-side, so a leaked token store cannot reveal the
// live tokens that were emailed to users.
func hashResetToken(raw string) string {
	sum := sha256.Sum256([]byte(raw))
	return hex.EncodeToString(sum[:])
}
|
||||
@ -9,10 +9,10 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -38,7 +38,7 @@ type metricsResponse struct {
|
||||
Points []metricPointResponse `json:"points"`
|
||||
}
|
||||
|
||||
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h.
|
||||
// GetMetrics handles GET /v1/capsules/{id}/metrics?range=10m|2h|24h.
|
||||
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
@ -16,10 +16,10 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
type oauthHandler struct {
|
||||
@ -137,6 +137,73 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(profile.Email))
|
||||
|
||||
// Check for a link operation initiated from the settings page.
|
||||
if linkCookie, err := r.Cookie("oauth_link_user_id"); err == nil && linkCookie.Value != "" {
|
||||
// Clear the link cookie immediately.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_link_user_id",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
settingsBase := h.redirectURL + "/dashboard/settings"
|
||||
|
||||
// Verify the HMAC to prevent cookie forgery.
|
||||
linkParts := strings.SplitN(linkCookie.Value, ":", 2)
|
||||
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.jwtSecret, linkParts[0])), []byte(linkParts[1])) {
|
||||
slog.Warn("oauth link: invalid or tampered link cookie")
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
userID, parseErr := id.ParseUserID(linkParts[0])
|
||||
if parseErr != nil {
|
||||
slog.Error("oauth link: invalid user ID in cookie", "error", parseErr)
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the GitHub account isn't already linked to a different user.
|
||||
existing, lookupErr := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
})
|
||||
if lookupErr == nil && existing.UserID != userID {
|
||||
slog.Warn("oauth link: provider already linked to another account", "provider", provider)
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=already_linked", http.StatusFound)
|
||||
return
|
||||
}
|
||||
if lookupErr == nil && existing.UserID == userID {
|
||||
// Already linked to this user — treat as success.
|
||||
http.Redirect(w, r, settingsBase+"?connected="+provider, http.StatusFound)
|
||||
return
|
||||
}
|
||||
if !errors.Is(lookupErr, pgx.ErrNoRows) {
|
||||
slog.Error("oauth link: db lookup failed", "error", lookupErr)
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=db_error", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
if insertErr := h.db.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}); insertErr != nil {
|
||||
slog.Error("oauth link: failed to insert provider", "error", insertErr)
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=db_error", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("oauth link: provider linked", "provider", provider, "user_id", id.FormatUserID(userID))
|
||||
http.Redirect(w, r, settingsBase+"?connected="+provider, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this OAuth identity already exists.
|
||||
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||
Provider: provider,
|
||||
@ -145,18 +212,29 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
if err == nil {
|
||||
// Existing OAuth user — log them in.
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("oauth login: user no longer exists", "user_id", existing.UserID)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if user.Status != "active" {
|
||||
slog.Warn("oauth login: account not active", "email", user.Email, "status", user.Status)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get team", "error", err)
|
||||
slog.Error("oauth login: failed to ensure team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
@ -172,13 +250,21 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// New OAuth identity — check for email collision.
|
||||
_, err = h.db.GetUserByEmail(ctx, email)
|
||||
existingUser, err := h.db.GetUserByEmail(ctx, email)
|
||||
if err == nil {
|
||||
// Email already taken by another account.
|
||||
redirectWithError(w, r, redirectBase, "email_taken")
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
if existingUser.Status == "inactive" {
|
||||
// Unactivated email signup — delete and let OAuth take over.
|
||||
if delErr := h.db.HardDeleteUser(ctx, existingUser.ID); delErr != nil {
|
||||
slog.Error("oauth: failed to delete inactive user", "error", delErr)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Email already taken by an active/disabled/deleted account.
|
||||
redirectWithError(w, r, redirectBase, "email_taken")
|
||||
return
|
||||
}
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Error("oauth: email check failed", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
@ -195,6 +281,15 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
qtx := h.db.WithTx(tx)
|
||||
|
||||
// The first user to sign up becomes a platform admin.
|
||||
userCount, err := qtx.CountUsers(ctx)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to count users", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
isFirstUser := userCount == 0
|
||||
|
||||
userID := id.NewUserID()
|
||||
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
|
||||
ID: userID,
|
||||
@ -238,6 +333,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if isFirstUser {
|
||||
if err := qtx.SetUserAdmin(ctx, db.SetUserAdminParams{ID: userID, IsAdmin: true}); err != nil {
|
||||
slog.Error("oauth: failed to set admin status", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := qtx.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
@ -255,7 +358,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", false)
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", isFirstUser)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
@ -279,18 +382,29 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
||||
return
|
||||
}
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("oauth: retry login: user no longer exists", "user_id", existing.UserID)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if user.Status != "active" {
|
||||
slog.Warn("oauth: retry login: account not active", "email", user.Email, "status", user.Status)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get team", "error", err)
|
||||
slog.Error("oauth: retry login: failed to ensure team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
|
||||
298
internal/api/handlers_process.go
Normal file
298
internal/api/handlers_process.go
Normal file
@ -0,0 +1,298 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// processHandler serves the sandbox process-management endpoints
// (list, kill, and stream) under /v1/capsules/{id}/processes.
type processHandler struct {
	db        *db.Queries               // query layer for sandbox/team lookups
	pool      *lifecycle.HostClientPool // per-host agent client pool
	jwtSecret []byte                    // secret for verifying WebSocket JWT auth
}

// newProcessHandler wires a processHandler with its database, host client
// pool, and the JWT secret used for WebSocket authentication.
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
	return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
|
||||
|
||||
// processResponse is a single entry in the process list.
type processResponse struct {
	PID  uint32   `json:"pid"`            // process ID inside the sandbox
	Tag  string   `json:"tag,omitempty"`  // optional caller-assigned tag
	Cmd  string   `json:"cmd"`            // executable / command name
	Args []string `json:"args,omitempty"` // command arguments, if any
}

// processListResponse wraps the list of processes.
type processListResponse struct {
	Processes []processResponse `json:"processes"`
}
|
||||
|
||||
// ListProcesses handles GET /v1/capsules/{id}/processes.
|
||||
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.ListProcesses(ctx, connect.NewRequest(&pb.ListProcessesRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
procs := make([]processResponse, 0, len(resp.Msg.Processes))
|
||||
for _, p := range resp.Msg.Processes {
|
||||
procs = append(procs, processResponse{
|
||||
PID: p.Pid,
|
||||
Tag: p.Tag,
|
||||
Cmd: p.Cmd,
|
||||
Args: p.Args,
|
||||
})
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, processListResponse{Processes: procs})
|
||||
}
|
||||
|
||||
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
|
||||
// The selector can be a numeric PID or a string tag.
|
||||
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
selectorStr := chi.URLParam(r, "selector")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Build the kill request with PID or tag selector.
|
||||
killReq := &pb.KillProcessRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Signal: "SIGKILL",
|
||||
}
|
||||
if sig := r.URL.Query().Get("signal"); sig == "SIGTERM" {
|
||||
killReq.Signal = "SIGTERM"
|
||||
}
|
||||
|
||||
if pid, err := strconv.ParseUint(selectorStr, 10, 32); err == nil {
|
||||
killReq.Selector = &pb.KillProcessRequest_Pid{Pid: uint32(pid)}
|
||||
} else {
|
||||
killReq.Selector = &pb.KillProcessRequest_Tag{Tag: selectorStr}
|
||||
}
|
||||
|
||||
if _, err := agent.KillProcess(ctx, connect.NewRequest(killReq)); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// wsProcessOut is the JSON message sent to the WebSocket client.
// Exactly one message Type is set per frame; the remaining fields are
// populated only for the types noted beside them.
type wsProcessOut struct {
	Type     string `json:"type"`                // "start", "stdout", "stderr", "exit", "error"
	PID      uint32 `json:"pid,omitempty"`       // only for "start"
	Data     string `json:"data,omitempty"`      // only for "stdout", "stderr", "error"
	ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
}
|
||||
|
||||
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
|
||||
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
selectorStr := chi.URLParam(r, "selector")
|
||||
ctx := r.Context()
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendProcessWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
|
||||
}
|
||||
|
||||
// runConnectProcess attaches to an existing process inside a sandbox and
// relays its output stream over an already-upgraded WebSocket connection.
//
// Flow: authorize the sandbox against the caller's team, require it to be
// "running", resolve a host-agent client, open a server stream, then forward
// events until the stream ends or the client disconnects. All errors after
// the upgrade are delivered as "error" messages via sendProcessWSError.
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
	// Scope the lookup to the caller's team so a valid sandbox ID from
	// another team reads as "not found".
	sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
	if err != nil {
		sendProcessWSError(conn, "sandbox not found")
		return
	}
	if sb.Status != "running" {
		sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
		return
	}

	agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
	if err != nil {
		sendProcessWSError(conn, "sandbox host is not reachable")
		return
	}

	// Build the connect request with PID or tag selector.
	connectReq := &pb.ConnectProcessRequest{
		SandboxId: sandboxIDStr,
	}
	// A selector that parses as a uint32 is treated as a PID; anything else
	// (including numbers out of range) is treated as a tag.
	if pid, err := strconv.ParseUint(selectorStr, 10, 32); err == nil {
		connectReq.Selector = &pb.ConnectProcessRequest_Pid{Pid: uint32(pid)}
	} else {
		connectReq.Selector = &pb.ConnectProcessRequest_Tag{Tag: selectorStr}
	}

	// streamCtx lets the disconnect watcher below tear down the agent stream
	// when the client goes away.
	streamCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
	if err != nil {
		sendProcessWSError(conn, "failed to connect to process: "+err.Error())
		return
	}
	defer stream.Close()

	// Listen for client disconnect in a goroutine.
	// ReadMessage returns an error once the WebSocket closes, which cancels
	// streamCtx and unblocks the forwarding loop below.
	go func() {
		for {
			_, _, err := conn.ReadMessage()
			if err != nil {
				cancel()
				return
			}
		}
	}()

	// Forward stream events to WebSocket.
	for stream.Receive() {
		resp := stream.Msg()
		switch ev := resp.Event.(type) {
		case *pb.ConnectProcessResponse_Start:
			writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})

		case *pb.ConnectProcessResponse_Data:
			switch o := ev.Data.Output.(type) {
			case *pb.ExecStreamData_Stdout:
				writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
			case *pb.ExecStreamData_Stderr:
				writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
			}

		case *pb.ConnectProcessResponse_End:
			exitCode := ev.End.ExitCode
			writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
		}
	}

	// Only surface stream errors that were not caused by our own
	// cancellation (i.e. the client disconnecting).
	if err := stream.Err(); err != nil {
		if streamCtx.Err() == nil {
			sendProcessWSError(conn, err.Error())
		}
	}

	// Update last active using a fresh context.
	// ctx may already be cancelled at this point, so a detached context with
	// its own timeout is required for the bookkeeping write.
	updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer updateCancel()
	if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
		ID: sandboxID,
		LastActiveAt: pgtype.Timestamptz{
			Time:  time.Now(),
			Valid: true,
		},
	}); err != nil {
		slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
	}
}
|
||||
|
||||
func sendProcessWSError(conn *websocket.Conn, msg string) {
|
||||
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
|
||||
}
|
||||
400
internal/api/handlers_pty.go
Normal file
400
internal/api/handlers_pty.go
Normal file
@ -0,0 +1,400 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// Defaults and tuning knobs for PTY WebSocket sessions.
const (
	// ptyKeepaliveInterval is how often the keepalive pump sends an
	// application-level "ping" message to prevent idle WS closure.
	ptyKeepaliveInterval = 30 * time.Second
	// ptyDefaultCmd is the command started when a "start" message omits cmd.
	ptyDefaultCmd = "/bin/bash"
	// ptyDefaultCols/ptyDefaultRows are the terminal dimensions used when a
	// "start" message omits them (classic 80x24 terminal).
	ptyDefaultCols = 80
	ptyDefaultRows = 24
)
|
||||
|
||||
// ptyHandler serves interactive PTY sessions over WebSocket, bridging the
// browser/CLI client to the host agent's PTY RPCs.
type ptyHandler struct {
	db        *db.Queries              // sandbox/team lookups and last-active bookkeeping
	pool      *lifecycle.HostClientPool // cached Connect clients for host agents
	jwtSecret []byte                    // secret for JWT auth performed over the WS first message
}
|
||||
|
||||
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
|
||||
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
// --- WebSocket message types ---
|
||||
|
||||
// wsPtyIn is the inbound message from the client. The first message of a
// session must be "start" (spawn a new PTY) or "connect" (reattach to an
// existing one by tag); subsequent messages are "input", "resize", or "kill".
// Fields are interpreted only for the message types noted per field.
type wsPtyIn struct {
	Type string            `json:"type"`           // "start", "connect", "input", "resize", "kill"
	Cmd  string            `json:"cmd,omitempty"`  // for "start"; defaults to ptyDefaultCmd when empty
	Args []string          `json:"args,omitempty"` // for "start"
	Cols uint32            `json:"cols,omitempty"` // for "start", "resize"; "start" defaults to ptyDefaultCols
	Rows uint32            `json:"rows,omitempty"` // for "start", "resize"; "start" defaults to ptyDefaultRows
	Envs map[string]string `json:"envs,omitempty"` // for "start"
	Cwd  string            `json:"cwd,omitempty"`  // for "start"
	User string            `json:"user,omitempty"` // for "start"
	Tag  string            `json:"tag,omitempty"`  // for "connect"; required there
	Data string            `json:"data,omitempty"` // for "input" (base64-encoded bytes)
}
|
||||
|
||||
// wsPtyOut is the outbound message to the client. "ping" is a periodic
// keepalive with no other fields set (see runPtyLoop).
type wsPtyOut struct {
	Type     string `json:"type"`                // "started", "output", "exit", "error", "ping"
	Tag      string `json:"tag,omitempty"`       // for "started"
	PID      uint32 `json:"pid,omitempty"`       // for "started"
	Data     string `json:"data,omitempty"`      // for "output" (base64), "error"
	ExitCode *int32 `json:"exit_code,omitempty"` // for "exit"; pointer so exit code 0 is still serialized
	Fatal    bool   `json:"fatal,omitempty"`     // for "error"; true when the session cannot continue
}
|
||||
|
||||
// wsWriter wraps a websocket.Conn with a mutex for concurrent writes.
// gorilla/websocket allows only one concurrent writer per connection, and
// runPtyLoop writes from three goroutines, so every write must go through
// writeJSON. Contains a Mutex: must not be copied; always pass *wsWriter.
type wsWriter struct {
	conn *websocket.Conn
	mu   sync.Mutex // serializes all writes to conn
}
|
||||
|
||||
func (w *wsWriter) writeJSON(v any) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if err := w.conn.WriteJSON(v); err != nil {
|
||||
slog.Debug("pty websocket write error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// PtySession handles WS /v1/capsules/{id}/pty.
|
||||
func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
// API key auth is handled by middleware (sets context).
|
||||
// For browser JWT auth, we authenticate after upgrade via first WS message.
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ws := &wsWriter{conn: conn}
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ws := &wsWriter{conn: conn}
|
||||
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
// runPtySession runs a PTY session over an already-upgraded, authenticated
// WebSocket: it authorizes the sandbox, reads the first client message to
// choose between starting a new PTY and reattaching to an existing one, and
// records last-active time once the session ends. All errors after the
// upgrade are reported as "error" messages with Fatal set.
func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
	// Scope the lookup to the caller's team so a valid sandbox ID from
	// another team reads as "not found".
	sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
	if err != nil {
		ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox not found", Fatal: true})
		return
	}
	if sb.Status != "running" {
		ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox is not running (status: " + sb.Status + ")", Fatal: true})
		return
	}

	// Read the first message to determine start vs connect.
	// Safe to read directly here: the concurrent pumps have not started yet.
	var firstMsg wsPtyIn
	if err := conn.ReadJSON(&firstMsg); err != nil {
		ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to read first message: " + err.Error(), Fatal: true})
		return
	}

	agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
	if err != nil {
		ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox host is not reachable", Fatal: true})
		return
	}

	// streamCtx/cancel are handed to the chosen handler so any of the pumps
	// in runPtyLoop can tear the whole session down.
	streamCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	switch firstMsg.Type {
	case "start":
		h.handleStart(streamCtx, cancel, ws, agent, sandboxIDStr, firstMsg)
	case "connect":
		h.handleConnect(streamCtx, cancel, ws, agent, sandboxIDStr, firstMsg)
	default:
		ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
	}

	// Update last active using a fresh context.
	// ctx may already be cancelled here, so the bookkeeping write gets a
	// detached context with its own timeout.
	updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer updateCancel()
	if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
		ID: sandboxID,
		LastActiveAt: pgtype.Timestamptz{
			Time:  time.Now(),
			Valid: true,
		},
	}); err != nil {
		slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
	}
}
|
||||
|
||||
// handleStart spawns a new PTY in the sandbox and streams it to the client.
// Missing cmd/cols/rows in msg fall back to the pty defaults. It waits for
// the agent's Started event, forwards it as a "started" message, then hands
// the stream to runPtyLoop for bidirectional pumping.
func (h *ptyHandler) handleStart(
	ctx context.Context,
	cancel context.CancelFunc,
	ws *wsWriter,
	agent hostagentv1connect.HostAgentServiceClient,
	sandboxIDStr string,
	msg wsPtyIn,
) {
	// Apply defaults for anything the client omitted.
	cmd := msg.Cmd
	if cmd == "" {
		cmd = ptyDefaultCmd
	}
	cols := msg.Cols
	if cols == 0 {
		cols = ptyDefaultCols
	}
	rows := msg.Rows
	if rows == 0 {
		rows = ptyDefaultRows
	}

	// Tag generated server-side; clients use it later for "connect".
	tag := newPtyTag()

	stream, err := agent.PtyAttach(ctx, connect.NewRequest(&pb.PtyAttachRequest{
		SandboxId: sandboxIDStr,
		Tag:       tag,
		Cmd:       cmd,
		Args:      msg.Args,
		Cols:      cols,
		Rows:      rows,
		Envs:      msg.Envs,
		Cwd:       msg.Cwd,
		User:      msg.User,
	}))
	if err != nil {
		ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to start pty: " + err.Error(), Fatal: true})
		return
	}
	defer stream.Close()

	// Wait for the started event and forward it.
	if !stream.Receive() {
		if err := stream.Err(); err != nil {
			ws.writeJSON(wsPtyOut{Type: "error", Data: "pty stream failed: " + err.Error(), Fatal: true})
		}
		return
	}
	resp := stream.Msg()
	started, ok := resp.Event.(*pb.PtyAttachResponse_Started)
	if !ok {
		ws.writeJSON(wsPtyOut{Type: "error", Data: "expected started event from host agent", Fatal: true})
		return
	}
	ws.writeJSON(wsPtyOut{Type: "started", Tag: started.Started.Tag, PID: started.Started.Pid})

	// NOTE(review): subsequent input/resize/kill RPCs use the locally
	// generated tag, not started.Started.Tag — presumably the agent echoes
	// the request tag unchanged; confirm against the host agent.
	runPtyLoop(ctx, cancel, ws, stream, agent, sandboxIDStr, tag)
}
|
||||
|
||||
func (h *ptyHandler) handleConnect(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
ws *wsWriter,
|
||||
agent hostagentv1connect.HostAgentServiceClient,
|
||||
sandboxIDStr string,
|
||||
msg wsPtyIn,
|
||||
) {
|
||||
if msg.Tag == "" {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "connect requires a 'tag' field", Fatal: true})
|
||||
return
|
||||
}
|
||||
|
||||
stream, err := agent.PtyAttach(ctx, connect.NewRequest(&pb.PtyAttachRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Tag: msg.Tag,
|
||||
}))
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to connect to pty: " + err.Error(), Fatal: true})
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
runPtyLoop(ctx, cancel, ws, stream, agent, sandboxIDStr, msg.Tag)
|
||||
}
|
||||
|
||||
// runPtyLoop drives the bidirectional communication between the WebSocket
// and the host agent PTY stream. It runs three goroutines — an output pump,
// an input pump, and a keepalive pump — and returns once all three exit.
// The output pump cancels ctx when the stream ends; the keepalive pump
// stops on ctx.Done.
//
// NOTE(review): wg.Wait also waits on the input pump, which only unblocks
// when ReadMessage fails (i.e. the client closes the socket). After the PTY
// exits the server therefore holds the connection open until the client
// disconnects — presumably intentional (client closes on "exit"); confirm.
func runPtyLoop(
	ctx context.Context,
	cancel context.CancelFunc,
	ws *wsWriter,
	stream *connect.ServerStreamForClient[pb.PtyAttachResponse],
	agent hostagentv1connect.HostAgentServiceClient,
	sandboxID string,
	tag string,
) {
	var wg sync.WaitGroup

	// Output pump: read from Connect stream, write to WebSocket.
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer cancel()

		for stream.Receive() {
			resp := stream.Msg()
			switch ev := resp.Event.(type) {
			case *pb.PtyAttachResponse_Started:
				// Already handled before the loop for "start" mode.
				// For "connect" mode this won't appear.
				ws.writeJSON(wsPtyOut{Type: "started", Tag: ev.Started.Tag, PID: ev.Started.Pid})

			case *pb.PtyAttachResponse_Output:
				// Raw terminal bytes are base64-encoded for JSON transport.
				ws.writeJSON(wsPtyOut{
					Type: "output",
					Data: base64.StdEncoding.EncodeToString(ev.Output.Data),
				})

			case *pb.PtyAttachResponse_Exited:
				exitCode := ev.Exited.ExitCode
				ws.writeJSON(wsPtyOut{Type: "exit", ExitCode: &exitCode})
				return
			}
		}

		// Suppress errors caused by our own cancellation (client gone).
		if err := stream.Err(); err != nil && ctx.Err() == nil {
			ws.writeJSON(wsPtyOut{Type: "error", Data: err.Error()})
		}
	}()

	// Input pump: read from WebSocket, dispatch to host agent.
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer cancel()

		for {
			_, raw, err := ws.conn.ReadMessage()
			if err != nil {
				// Client closed (or network error): end the session.
				return
			}

			var msg wsPtyIn
			if json.Unmarshal(raw, &msg) != nil {
				// Malformed frames are silently dropped.
				continue
			}

			// Use a background context for unary RPCs so they complete
			// even if the stream context is being cancelled.
			rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)

			switch msg.Type {
			case "input":
				data, err := base64.StdEncoding.DecodeString(msg.Data)
				if err != nil {
					rpcCancel()
					continue
				}
				if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{
					SandboxId: sandboxID,
					Tag:       tag,
					Data:      data,
				})); err != nil {
					slog.Debug("pty send input error", "error", err)
				}

			case "resize":
				cols := msg.Cols
				rows := msg.Rows
				// Ignore degenerate sizes rather than forwarding them.
				if cols > 0 && rows > 0 {
					if _, err := agent.PtyResize(rpcCtx, connect.NewRequest(&pb.PtyResizeRequest{
						SandboxId: sandboxID,
						Tag:       tag,
						Cols:      cols,
						Rows:      rows,
					})); err != nil {
						slog.Debug("pty resize error", "error", err)
					}
				}

			case "kill":
				if _, err := agent.PtyKill(rpcCtx, connect.NewRequest(&pb.PtyKillRequest{
					SandboxId: sandboxID,
					Tag:       tag,
				})); err != nil {
					slog.Debug("pty kill error", "error", err)
				}
			}

			// Release the per-message RPC context promptly (defer in a loop
			// would pile up until the goroutine returns).
			rpcCancel()
		}
	}()

	// Keepalive pump: send periodic pings to prevent idle WS closure.
	wg.Add(1)
	go func() {
		defer wg.Done()
		ticker := time.NewTicker(ptyKeepaliveInterval)
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
				// Application-level ping (JSON message), not a WS control frame.
				ws.writeJSON(wsPtyOut{Type: "ping"})
			case <-ctx.Done():
				return
			}
		}
	}()

	wg.Wait()
}
|
||||
|
||||
// newPtyTag returns a PTY session tag: "pty-" + 8 random hex chars.
|
||||
func newPtyTag() string {
|
||||
return "pty-" + id.NewPtyTag()
|
||||
}
|
||||
@ -7,11 +7,11 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type sandboxHandler struct {
|
||||
@ -31,18 +31,19 @@ type createSandboxRequest struct {
|
||||
}
|
||||
|
||||
type sandboxResponse struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Template string `json:"template"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
GuestIP string `json:"guest_ip,omitempty"`
|
||||
HostIP string `json:"host_ip,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
LastActiveAt *string `json:"last_active_at,omitempty"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Template string `json:"template"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
GuestIP string `json:"guest_ip,omitempty"`
|
||||
HostIP string `json:"host_ip,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
LastActiveAt *string `json:"last_active_at,omitempty"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
@ -56,6 +57,12 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
GuestIP: sb.GuestIp,
|
||||
HostIP: sb.HostIp,
|
||||
}
|
||||
if len(sb.Metadata) > 0 {
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(sb.Metadata, &meta); err == nil && len(meta) > 0 {
|
||||
resp.Metadata = meta
|
||||
}
|
||||
}
|
||||
if sb.CreatedAt.Valid {
|
||||
resp.CreatedAt = sb.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
@ -73,7 +80,7 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/sandboxes.
|
||||
// Create handles POST /v1/capsules.
|
||||
func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSandboxRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
@ -104,7 +111,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/sandboxes.
|
||||
// List handles GET /v1/capsules.
|
||||
func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
sandboxes, err := h.svc.List(r.Context(), ac.TeamID)
|
||||
@ -121,7 +128,7 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/sandboxes/{id}.
|
||||
// Get handles GET /v1/capsules/{id}.
|
||||
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
@ -141,7 +148,7 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Pause handles POST /v1/sandboxes/{id}/pause.
|
||||
// Pause handles POST /v1/capsules/{id}/pause.
|
||||
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
@ -163,7 +170,7 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Resume handles POST /v1/sandboxes/{id}/resume.
|
||||
// Resume handles POST /v1/capsules/{id}/resume.
|
||||
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
@ -185,7 +192,7 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Ping handles POST /v1/sandboxes/{id}/ping.
|
||||
// Ping handles POST /v1/capsules/{id}/ping.
|
||||
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
@ -205,7 +212,7 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Destroy handles DELETE /v1/sandboxes/{id}.
|
||||
// Destroy handles DELETE /v1/capsules/{id}.
|
||||
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
@ -13,14 +13,14 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/validate"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -38,8 +38,8 @@ func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *life
|
||||
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
|
||||
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
|
||||
// and ignore NotFound errors.
|
||||
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error {
|
||||
hosts, err := h.db.ListActiveHosts(ctx)
|
||||
func deleteSnapshotBroadcast(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
|
||||
hosts, err := queries.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
@ -47,7 +47,7 @@ func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, t
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
}
|
||||
agent, err := h.pool.GetForHost(host)
|
||||
agent, err := pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@ -69,13 +69,14 @@ type createSnapshotRequest struct {
|
||||
}
|
||||
|
||||
type snapshotResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
VCPUs *int32 `json:"vcpus,omitempty"`
|
||||
MemoryMB *int32 `json:"memory_mb,omitempty"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Platform bool `json:"platform"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
VCPUs *int32 `json:"vcpus,omitempty"`
|
||||
MemoryMB *int32 `json:"memory_mb,omitempty"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Platform bool `json:"platform"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func templateToResponse(t db.Template) snapshotResponse {
|
||||
@ -94,6 +95,12 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
if t.CreatedAt.Valid {
|
||||
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if len(t.Metadata) > 0 {
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(t.Metadata, &meta); err == nil && len(meta) > 0 {
|
||||
resp.Metadata = meta
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
@ -126,7 +133,6 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
overwrite := r.URL.Query().Get("overwrite") == "true"
|
||||
|
||||
// Check for global name collision.
|
||||
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
@ -135,20 +141,10 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Check if name already exists for this team.
|
||||
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
if !overwrite {
|
||||
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
|
||||
return
|
||||
}
|
||||
// Delete old snapshot files from all hosts before removing the DB record.
|
||||
if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
|
||||
return
|
||||
}
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
writeError(w, http.StatusConflict, "template_name_taken",
|
||||
"snapshot name already exists; delete the existing snapshot first to reuse this name")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify sandbox exists, belongs to team, and is running or paused.
|
||||
@ -210,13 +206,16 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
DefaultUser: "root",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: sb.Metadata,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
@ -277,7 +276,7 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
|
||||
return
|
||||
}
|
||||
|
||||
@ -5,8 +5,8 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type statsHandler struct {
|
||||
@ -43,7 +43,7 @@ type statsResponse struct {
|
||||
Series statsSeriesResponse `json:"series"`
|
||||
}
|
||||
|
||||
// GetStats handles GET /v1/sandboxes/stats?range=5m|1h|6h|24h|30d
|
||||
// GetStats handles GET /v1/capsules/stats?range=5m|1h|6h|24h|30d
|
||||
func (h *statsHandler) GetStats(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -9,20 +11,22 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type teamHandler struct {
|
||||
svc *service.TeamService
|
||||
audit *audit.AuditLogger
|
||||
svc *service.TeamService
|
||||
audit *audit.AuditLogger
|
||||
mailer email.Mailer
|
||||
}
|
||||
|
||||
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger) *teamHandler {
|
||||
return &teamHandler{svc: svc, audit: al}
|
||||
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer) *teamHandler {
|
||||
return &teamHandler{svc: svc, audit: al, mailer: mailer}
|
||||
}
|
||||
|
||||
// teamResponse is the JSON shape for a team.
|
||||
@ -131,6 +135,15 @@ func (h *teamHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := h.mailer.Send(context.Background(), ac.Email, "Your team has been created", email.EmailData{
|
||||
RecipientName: ac.Name,
|
||||
Message: fmt.Sprintf("Your team \"%s\" has been created on Wrenn. You can now invite members and start creating sandboxes under this team.", req.Name),
|
||||
}); err != nil {
|
||||
slog.Warn("failed to send team created email", "email", ac.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
writeJSON(w, http.StatusCreated, teamWithRoleResponse{
|
||||
teamResponse: teamToResponse(team.Team),
|
||||
Role: team.Role,
|
||||
@ -279,6 +292,21 @@ func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
|
||||
if parseErr == nil {
|
||||
h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
|
||||
}
|
||||
|
||||
go func() {
|
||||
team, err := h.svc.GetTeam(context.Background(), teamID)
|
||||
teamName := "a team"
|
||||
if err == nil {
|
||||
teamName = team.Name
|
||||
}
|
||||
if err := h.mailer.Send(context.Background(), member.Email, "You've been added to a team on Wrenn", email.EmailData{
|
||||
RecipientName: member.Name,
|
||||
Message: fmt.Sprintf("%s has added you to the team \"%s\" on Wrenn.", ac.Name, teamName),
|
||||
}); err != nil {
|
||||
slog.Warn("failed to send team invitation email", "email", member.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
|
||||
}
|
||||
|
||||
@ -388,3 +416,87 @@ func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// AdminListTeams handles GET /v1/admin/teams?page=1
|
||||
// Returns a paginated list of all teams with member counts, owner info, and active sandbox counts.
|
||||
func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
|
||||
page := 1
|
||||
if p := r.URL.Query().Get("page"); p != "" {
|
||||
if _, err := fmt.Sscanf(p, "%d", &page); err != nil || page < 1 {
|
||||
page = 1
|
||||
}
|
||||
}
|
||||
const perPage = 100
|
||||
offset := int32((page - 1) * perPage)
|
||||
|
||||
teams, total, err := h.svc.AdminListTeams(r.Context(), perPage, offset)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
type adminTeamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
IsByoc bool `json:"is_byoc"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
DeletedAt *string `json:"deleted_at"`
|
||||
MemberCount int32 `json:"member_count"`
|
||||
OwnerName string `json:"owner_name"`
|
||||
OwnerEmail string `json:"owner_email"`
|
||||
ActiveSandboxCount int32 `json:"active_sandbox_count"`
|
||||
ChannelCount int32 `json:"channel_count"`
|
||||
}
|
||||
|
||||
resp := make([]adminTeamResponse, len(teams))
|
||||
for i, t := range teams {
|
||||
r := adminTeamResponse{
|
||||
ID: id.FormatTeamID(t.ID),
|
||||
Name: t.Name,
|
||||
Slug: t.Slug,
|
||||
IsByoc: t.IsByoc,
|
||||
CreatedAt: t.CreatedAt.Format(time.RFC3339),
|
||||
MemberCount: t.MemberCount,
|
||||
OwnerName: t.OwnerName,
|
||||
OwnerEmail: t.OwnerEmail,
|
||||
ActiveSandboxCount: t.ActiveSandboxCount,
|
||||
ChannelCount: t.ChannelCount,
|
||||
}
|
||||
if t.DeletedAt != nil {
|
||||
s := t.DeletedAt.Format(time.RFC3339)
|
||||
r.DeletedAt = &s
|
||||
}
|
||||
resp[i] = r
|
||||
}
|
||||
|
||||
totalPages := (total + perPage - 1) / perPage
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"teams": resp,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": perPage,
|
||||
"total_pages": totalPages,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminDeleteTeam handles DELETE /v1/admin/teams/{id}
|
||||
// Soft-deletes a team and destroys all its active sandboxes.
|
||||
func (h *teamHandler) AdminDeleteTeam(w http.ResponseWriter, r *http.Request) {
|
||||
teamIDStr := chi.URLParam(r, "id")
|
||||
|
||||
teamID, err := id.ParseTeamID(teamIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.AdminDeleteTeam(r.Context(), teamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@ -1,22 +1,27 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type usersHandler struct {
|
||||
db *db.Queries
|
||||
db *db.Queries
|
||||
svc *service.UserService
|
||||
}
|
||||
|
||||
func newUsersHandler(db *db.Queries) *usersHandler {
|
||||
return &usersHandler{db: db}
|
||||
func newUsersHandler(db *db.Queries, svc *service.UserService) *usersHandler {
|
||||
return &usersHandler{db: db, svc: svc}
|
||||
}
|
||||
|
||||
// Search handles GET /v1/users/search?email=<prefix>
|
||||
@ -50,3 +55,96 @@ func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// AdminListUsers handles GET /v1/admin/users?page=1
|
||||
// Returns a paginated list of all users with team counts.
|
||||
func (h *usersHandler) AdminListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
page := 1
|
||||
if p := r.URL.Query().Get("page"); p != "" {
|
||||
if _, err := fmt.Sscanf(p, "%d", &page); err != nil || page < 1 {
|
||||
page = 1
|
||||
}
|
||||
}
|
||||
const perPage = 100
|
||||
offset := int32((page - 1) * perPage)
|
||||
|
||||
users, total, err := h.svc.AdminListUsers(r.Context(), perPage, offset)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
type adminUserResponse struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
TeamsJoined int32 `json:"teams_joined"`
|
||||
TeamsOwned int32 `json:"teams_owned"`
|
||||
}
|
||||
|
||||
resp := make([]adminUserResponse, len(users))
|
||||
for i, u := range users {
|
||||
resp[i] = adminUserResponse{
|
||||
ID: id.FormatUserID(u.ID),
|
||||
Email: u.Email,
|
||||
Name: u.Name,
|
||||
IsAdmin: u.IsAdmin,
|
||||
Status: u.Status,
|
||||
CreatedAt: u.CreatedAt.Format(time.RFC3339),
|
||||
TeamsJoined: u.TeamsJoined,
|
||||
TeamsOwned: u.TeamsOwned,
|
||||
}
|
||||
}
|
||||
|
||||
totalPages := (total + perPage - 1) / perPage
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"users": resp,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": perPage,
|
||||
"total_pages": totalPages,
|
||||
})
|
||||
}
|
||||
|
||||
// SetUserActive handles PUT /v1/admin/users/{id}/active
|
||||
// Enables or disables a user account. Admins cannot deactivate themselves.
|
||||
func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
userIDStr := chi.URLParam(r, "id")
|
||||
|
||||
userID, err := id.ParseUserID(userIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if ac.UserID == userID && !req.Active {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "cannot deactivate your own account")
|
||||
return
|
||||
}
|
||||
|
||||
newStatus := "active"
|
||||
if !req.Active {
|
||||
newStatus = "disabled"
|
||||
}
|
||||
|
||||
if err := h.svc.SetUserStatus(r.Context(), userID, newStatus); err != nil {
|
||||
httpStatus, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, httpStatus, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
109
internal/api/helpers_ws.go
Normal file
109
internal/api/helpers_ws.go
Normal file
@ -0,0 +1,109 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// isWebSocketUpgrade returns true if the request is a WebSocket upgrade.
|
||||
func isWebSocketUpgrade(r *http.Request) bool {
|
||||
return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
|
||||
}
|
||||
|
||||
// ctxKeyAdminWS is a context key for flagging admin WS routes.
|
||||
type ctxKeyAdminWS struct{}
|
||||
|
||||
// setAdminWSFlag marks the context as an admin WebSocket route.
|
||||
func setAdminWSFlag(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
|
||||
}
|
||||
|
||||
// isAdminWSRoute checks if the request context was marked as admin WS.
|
||||
func isAdminWSRoute(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// wsAuthMsg is the first message a browser WS client sends to authenticate.
|
||||
type wsAuthMsg struct {
|
||||
Type string `json:"type"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
|
||||
// authenticated context. The caller must send this as the first message after
|
||||
// connecting.
|
||||
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
var msg wsAuthMsg
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Time{}) // clear deadline
|
||||
|
||||
if msg.Type != "auth" || msg.Token == "" {
|
||||
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
|
||||
}
|
||||
|
||||
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
|
||||
}
|
||||
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if user.Status != "active" {
|
||||
return auth.AuthContext{}, fmt.Errorf("account deactivated")
|
||||
}
|
||||
|
||||
return auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
|
||||
// returning an AuthContext with the platform team ID.
|
||||
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, err
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if !user.IsAdmin {
|
||||
return auth.AuthContext{}, fmt.Errorf("admin access required")
|
||||
}
|
||||
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
return ac, nil
|
||||
}
|
||||
@ -8,10 +8,10 @@ import (
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
)
|
||||
|
||||
// MetricsSampler records per-team sandbox resource usage to
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
type errorResponse struct {
|
||||
@ -50,8 +50,12 @@ func agentErrToHTTP(err error) (int, string, string) {
|
||||
return http.StatusNotFound, "not_found", err.Error()
|
||||
case connect.CodeInvalidArgument:
|
||||
return http.StatusBadRequest, "invalid_request", err.Error()
|
||||
case connect.CodeFailedPrecondition:
|
||||
case connect.CodeFailedPrecondition, connect.CodeAlreadyExists:
|
||||
return http.StatusConflict, "conflict", err.Error()
|
||||
case connect.CodePermissionDenied:
|
||||
return http.StatusForbidden, "forbidden", err.Error()
|
||||
case connect.CodeUnimplemented:
|
||||
return http.StatusNotImplemented, "agent_error", err.Error()
|
||||
default:
|
||||
return http.StatusBadGateway, "agent_error", err.Error()
|
||||
}
|
||||
@ -90,21 +94,25 @@ func serviceErrToHTTP(err error) (int, string, string) {
|
||||
}
|
||||
|
||||
// Map well-known service error patterns.
|
||||
// Return generic messages for most cases to avoid leaking internal details.
|
||||
switch {
|
||||
case strings.Contains(msg, "not found"):
|
||||
return http.StatusNotFound, "not_found", msg
|
||||
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", msg
|
||||
return http.StatusNotFound, "not_found", "resource not found"
|
||||
case strings.Contains(msg, "not running"):
|
||||
return http.StatusConflict, "invalid_state", "resource is not running"
|
||||
case strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", "resource is not paused"
|
||||
case strings.Contains(msg, "conflict:"):
|
||||
return http.StatusConflict, "conflict", msg
|
||||
return http.StatusConflict, "conflict", strings.TrimPrefix(msg, "conflict: ")
|
||||
case strings.Contains(msg, "forbidden"):
|
||||
return http.StatusForbidden, "forbidden", msg
|
||||
return http.StatusForbidden, "forbidden", "forbidden"
|
||||
case strings.Contains(msg, "invalid or expired"):
|
||||
return http.StatusUnauthorized, "unauthorized", msg
|
||||
return http.StatusUnauthorized, "unauthorized", "invalid or expired credentials"
|
||||
case strings.Contains(msg, "invalid"):
|
||||
return http.StatusBadRequest, "invalid_request", msg
|
||||
return http.StatusBadRequest, "invalid_request", "invalid request"
|
||||
default:
|
||||
return http.StatusInternalServerError, "internal_error", msg
|
||||
slog.Error("unhandled service error", "error", err)
|
||||
return http.StatusInternalServerError, "internal_error", "an internal error occurred"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -3,19 +3,47 @@ package api
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// injectPlatformTeam overwrites the AuthContext's TeamID with the platform
|
||||
// sentinel UUID. This lets existing team-scoped handlers (exec, files, pty,
|
||||
// metrics) work unchanged under admin routes. Must run after requireAdmin.
|
||||
func injectPlatformTeam() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := auth.FromContext(r.Context()); !ok {
|
||||
// No auth context yet (WS upgrade); handler will inject platform team after WS auth.
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
ctx := auth.WithAuthContext(r.Context(), ac)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// requireAdmin validates that the authenticated user is a platform admin.
|
||||
// Must run after requireJWT (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the JWT is_admin claim is for UI only;
|
||||
// the DB is the source of truth for admin access.
|
||||
// WebSocket upgrade requests without auth context are passed through —
|
||||
// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin.
|
||||
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
if isWebSocketUpgrade(r) {
|
||||
ctx := r.Context()
|
||||
ctx = setAdminWSFlag(ctx)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
@ -5,9 +5,9 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
|
||||
@ -38,9 +38,12 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
return
|
||||
}
|
||||
|
||||
// Try JWT bearer token.
|
||||
// Try JWT bearer token from Authorization header.
|
||||
tokenStr := ""
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
}
|
||||
if tokenStr != "" {
|
||||
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
|
||||
@ -59,6 +62,18 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user is still active in the database.
|
||||
user, err := queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
|
||||
return
|
||||
}
|
||||
if user.Status != "active" {
|
||||
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
@ -70,7 +85,15 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
return
|
||||
}
|
||||
|
||||
// WebSocket upgrade requests may not carry auth headers (browsers
|
||||
// cannot set custom headers on WS connections). Pass through —
|
||||
// the WS handler authenticates via the first message after upgrade.
|
||||
if isWebSocketUpgrade(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3,8 +3,8 @@ package api
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireHostToken validates the X-Host-Token header containing a host JWT,
|
||||
|
||||
@ -1,25 +1,37 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
|
||||
// signature and expiry, and stamps UserID + TeamID + Email into the request context.
|
||||
func requireJWT(secret []byte) func(http.Handler) http.Handler {
|
||||
// requireJWT validates a JWT from the Authorization: Bearer header.
|
||||
// It also verifies the user is still active in the database.
|
||||
// WebSocket upgrade requests without an Authorization header are passed through
|
||||
// — WS handlers authenticate via the first message after upgrade.
|
||||
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(header, "Bearer ") {
|
||||
var tokenStr string
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
}
|
||||
if tokenStr == "" {
|
||||
// WebSocket upgrade requests may not have an Authorization header
|
||||
// (browsers cannot set custom headers on WS connections). Let them
|
||||
// through — the handler authenticates via the first WS message.
|
||||
if isWebSocketUpgrade(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := auth.VerifyJWT(secret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
@ -37,6 +49,18 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user is still active in the database.
|
||||
user, err := queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
|
||||
return
|
||||
}
|
||||
if user.Status != "active" {
|
||||
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -9,14 +9,16 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/scheduler"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
//go:embed openapi.yaml
|
||||
@ -26,9 +28,12 @@ var openapiYAML []byte
|
||||
type Server struct {
|
||||
router chi.Router
|
||||
BuildSvc *service.BuildService
|
||||
version string
|
||||
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
|
||||
// Extensions are called after core routes are registered, allowing enterprise
|
||||
// or third-party code to add routes and middleware.
|
||||
func New(
|
||||
queries *db.Queries,
|
||||
pool *lifecycle.HostClientPool,
|
||||
@ -41,6 +46,10 @@ func New(
|
||||
ca *auth.CA,
|
||||
al *audit.AuditLogger,
|
||||
channelSvc *channels.Service,
|
||||
mailer email.Mailer,
|
||||
extensions []cpextension.Extension,
|
||||
sctx cpextension.ServerContext,
|
||||
version string,
|
||||
) *Server {
|
||||
r := chi.NewRouter()
|
||||
r.Use(requestLogger())
|
||||
@ -51,27 +60,39 @@ func New(
|
||||
templateSvc := &service.TemplateService{DB: queries}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
||||
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
||||
userSvc := &service.UserService{DB: queries, SandboxSvc: sandboxSvc}
|
||||
auditSvc := &service.AuditService{DB: queries}
|
||||
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
||||
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
|
||||
|
||||
sandbox := newSandboxHandler(sandboxSvc, al)
|
||||
exec := newExecHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool, jwtSecret)
|
||||
files := newFilesHandler(queries, pool)
|
||||
filesStream := newFilesStreamHandler(queries, pool)
|
||||
fsH := newFSHandler(queries, pool)
|
||||
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
|
||||
authH := newAuthHandler(queries, pgPool, jwtSecret)
|
||||
authH := newAuthHandler(queries, pgPool, jwtSecret, mailer, rdb, oauthRedirectURL)
|
||||
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc, al)
|
||||
hostH := newHostHandler(hostSvc, queries, al)
|
||||
teamH := newTeamHandler(teamSvc, al)
|
||||
usersH := newUsersHandler(queries)
|
||||
teamH := newTeamHandler(teamSvc, al, mailer)
|
||||
usersH := newUsersHandler(queries, userSvc)
|
||||
auditH := newAuditHandler(auditSvc)
|
||||
statsH := newStatsHandler(statsSvc)
|
||||
metricsH := newSandboxMetricsHandler(queries, pool)
|
||||
buildH := newBuildHandler(buildSvc, queries, pool)
|
||||
channelH := newChannelHandler(channelSvc, al)
|
||||
ptyH := newPtyHandler(queries, pool, jwtSecret)
|
||||
processH := newProcessHandler(queries, pool, jwtSecret)
|
||||
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
|
||||
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
|
||||
|
||||
// Health check.
|
||||
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintf(w, `{"status":"ok","version":%q}`, version)
|
||||
})
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
@ -80,15 +101,31 @@ func New(
|
||||
// Unauthenticated auth endpoints.
|
||||
r.Post("/v1/auth/signup", authH.Signup)
|
||||
r.Post("/v1/auth/login", authH.Login)
|
||||
r.Post("/v1/auth/activate", authH.Activate)
|
||||
r.Get("/auth/oauth/{provider}", oauthH.Redirect)
|
||||
r.Get("/auth/oauth/{provider}/callback", oauthH.Callback)
|
||||
|
||||
// Unauthenticated: password reset request and confirmation.
|
||||
r.Post("/v1/me/password/reset", meH.RequestPasswordReset)
|
||||
r.Post("/v1/me/password/reset/confirm", meH.ConfirmPasswordReset)
|
||||
|
||||
// JWT-authenticated: self-service account management.
|
||||
r.Route("/v1/me", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Get("/", meH.GetMe)
|
||||
r.Patch("/", meH.UpdateName)
|
||||
r.Post("/password", meH.ChangePassword)
|
||||
r.Get("/providers/{provider}/connect", meH.ConnectProvider)
|
||||
r.Delete("/providers/{provider}", meH.DisconnectProvider)
|
||||
r.Delete("/", meH.DeleteAccount)
|
||||
})
|
||||
|
||||
// JWT-authenticated: switch active team.
|
||||
r.With(requireJWT(jwtSecret)).Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
r.With(requireJWT(jwtSecret, queries)).Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
|
||||
// JWT-authenticated: API key management.
|
||||
r.Route("/v1/api-keys", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Post("/", apiKeys.Create)
|
||||
r.Get("/", apiKeys.List)
|
||||
r.Delete("/{id}", apiKeys.Delete)
|
||||
@ -96,7 +133,7 @@ func New(
|
||||
|
||||
// JWT-authenticated: team management.
|
||||
r.Route("/v1/teams", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Get("/", teamH.List)
|
||||
r.Post("/", teamH.Create)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
@ -112,10 +149,12 @@ func New(
|
||||
})
|
||||
|
||||
// JWT-authenticated: user search (for add-member UI).
|
||||
r.With(requireJWT(jwtSecret)).Get("/v1/users/search", usersH.Search)
|
||||
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
|
||||
|
||||
// Sandbox lifecycle: accepts API key or JWT bearer token.
|
||||
r.Route("/v1/sandboxes", func(r chi.Router) {
|
||||
// Capsule lifecycle: accepts API key or JWT bearer token.
|
||||
// WebSocket upgrade requests without auth headers are passed through by
|
||||
// requireAPIKeyOrJWT — the WS handlers authenticate via first message.
|
||||
r.Route("/v1/capsules", func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Post("/", sandbox.Create)
|
||||
r.Get("/", sandbox.List)
|
||||
@ -133,7 +172,14 @@ func New(
|
||||
r.Post("/files/read", files.Download)
|
||||
r.Post("/files/stream/write", filesStream.StreamUpload)
|
||||
r.Post("/files/stream/read", filesStream.StreamDownload)
|
||||
r.Post("/files/list", fsH.ListDir)
|
||||
r.Post("/files/mkdir", fsH.MakeDir)
|
||||
r.Post("/files/remove", fsH.Remove)
|
||||
r.Get("/metrics", metricsH.GetMetrics)
|
||||
r.Get("/pty", ptyH.PtySession)
|
||||
r.Get("/processes", processH.ListProcesses)
|
||||
r.Delete("/processes/{selector}", processH.KillProcess)
|
||||
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
|
||||
})
|
||||
})
|
||||
|
||||
@ -158,7 +204,7 @@ func New(
|
||||
|
||||
// JWT-authenticated: host CRUD and tags.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Post("/", hostH.Create)
|
||||
r.Get("/", hostH.List)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
@ -175,7 +221,7 @@ func New(
|
||||
|
||||
// JWT-authenticated: notification channels.
|
||||
r.Route("/v1/channels", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Post("/", channelH.Create)
|
||||
r.Get("/", channelH.List)
|
||||
r.Post("/test", channelH.Test)
|
||||
@ -188,22 +234,51 @@ func New(
|
||||
})
|
||||
|
||||
// JWT-authenticated: audit log.
|
||||
r.With(requireJWT(jwtSecret)).Get("/v1/audit-logs", auditH.List)
|
||||
r.With(requireJWT(jwtSecret, queries)).Get("/v1/audit-logs", auditH.List)
|
||||
|
||||
// Platform admin routes — require JWT + DB-validated admin status.
|
||||
r.Route("/v1/admin", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Get("/teams", teamH.AdminListTeams)
|
||||
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
|
||||
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
|
||||
r.Get("/users", usersH.AdminListUsers)
|
||||
r.Put("/users/{id}/active", usersH.SetUserActive)
|
||||
r.Get("/templates", buildH.ListTemplates)
|
||||
r.Delete("/templates/{name}", buildH.DeleteTemplate)
|
||||
r.Post("/builds", buildH.Create)
|
||||
r.Get("/builds", buildH.List)
|
||||
r.Get("/builds/{id}", buildH.Get)
|
||||
r.Post("/builds/{id}/cancel", buildH.Cancel)
|
||||
r.Post("/capsules", adminCapsules.Create)
|
||||
r.Get("/capsules", adminCapsules.List)
|
||||
r.Route("/capsules/{id}", func(r chi.Router) {
|
||||
r.Use(injectPlatformTeam())
|
||||
r.Get("/", adminCapsules.Get)
|
||||
r.Delete("/", adminCapsules.Destroy)
|
||||
r.Post("/snapshot", adminCapsules.Snapshot)
|
||||
r.Post("/exec", exec.Exec)
|
||||
r.Get("/exec/stream", execStream.ExecStream)
|
||||
r.Post("/files/write", files.Upload)
|
||||
r.Post("/files/read", files.Download)
|
||||
r.Post("/files/list", fsH.ListDir)
|
||||
r.Post("/files/mkdir", fsH.MakeDir)
|
||||
r.Post("/files/remove", fsH.Remove)
|
||||
r.Get("/metrics", metricsH.GetMetrics)
|
||||
r.Get("/pty", ptyH.PtySession)
|
||||
r.Get("/processes", processH.ListProcesses)
|
||||
r.Delete("/processes/{selector}", processH.KillProcess)
|
||||
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
|
||||
})
|
||||
})
|
||||
|
||||
return &Server{router: r, BuildSvc: buildSvc}
|
||||
// Let extensions register their routes after all core routes.
|
||||
for _, ext := range extensions {
|
||||
ext.RegisterRoutes(r, sctx)
|
||||
}
|
||||
|
||||
return &Server{router: r, BuildSvc: buildSvc, version: version}
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler.
|
||||
@ -211,6 +286,11 @@ func (s *Server) Handler() http.Handler {
|
||||
return s.router
|
||||
}
|
||||
|
||||
// Router returns the underlying chi.Router for direct access.
|
||||
func (s *Server) Router() chi.Router {
|
||||
return s.router
|
||||
}
|
||||
|
||||
func serveOpenAPI(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/yaml")
|
||||
_, _ = w.Write(openapiYAML)
|
||||
@ -223,7 +303,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrenn Sandbox API</title>
|
||||
<title>Wrenn API</title>
|
||||
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui.css" integrity="sha384-rcbEi6xgdPk0iWkAQzT2F3FeBJXdG+ydrawGlfHAFIZG7wU6aKbQaRewysYpmrlW" crossorigin="anonymous">
|
||||
<style>
|
||||
body { margin: 0; background: #fafafa; }
|
||||
|
||||
@ -1,569 +0,0 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/events"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
)
|
||||
|
||||
// AuditLogger writes audit log entries for user-initiated and system events.
|
||||
// All methods are fire-and-forget: failures are logged via slog and never
|
||||
// propagated to the caller.
|
||||
type AuditLogger struct {
|
||||
db *db.Queries
|
||||
pub events.EventPublisher // optional — nil disables event publishing
|
||||
}
|
||||
|
||||
// New constructs an AuditLogger without event publishing.
|
||||
func New(queries *db.Queries) *AuditLogger {
|
||||
return &AuditLogger{db: queries}
|
||||
}
|
||||
|
||||
// NewWithPublisher constructs an AuditLogger that also publishes channel events.
|
||||
func NewWithPublisher(queries *db.Queries, pub events.EventPublisher) *AuditLogger {
|
||||
return &AuditLogger{db: queries, pub: pub}
|
||||
}
|
||||
|
||||
// publish sends an event to the notification stream if a publisher is configured.
|
||||
func (l *AuditLogger) publish(ctx context.Context, e events.Event) {
|
||||
if l.pub != nil {
|
||||
l.pub.Publish(ctx, e)
|
||||
}
|
||||
}
|
||||
|
||||
// actorToEvent converts auth context fields to an events.Actor.
|
||||
func actorToEvent(ac auth.AuthContext) events.Actor {
|
||||
at, aid, aname := actorFields(ac)
|
||||
return events.Actor{Type: events.ActorKind(at), ID: aid, Name: aname}
|
||||
}
|
||||
|
||||
// systemActor returns an events.Actor for system-initiated events.
|
||||
func systemActor() events.Actor {
|
||||
return events.Actor{Type: events.ActorSystem}
|
||||
}
|
||||
|
||||
// actorFields extracts actor_type, actor_id, and actor_name from an AuthContext.
|
||||
// actor_id is stored as a prefixed string in the TEXT column.
|
||||
func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) {
|
||||
if ac.UserID.Valid {
|
||||
return "user", id.FormatUserID(ac.UserID), ac.Name
|
||||
}
|
||||
if ac.APIKeyID.Valid {
|
||||
return "api_key", id.FormatAPIKeyID(ac.APIKeyID), ac.APIKeyName
|
||||
}
|
||||
return "system", "", ""
|
||||
}
|
||||
|
||||
func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) {
|
||||
if err := l.db.InsertAuditLog(ctx, p); err != nil {
|
||||
slog.Warn("audit: failed to write log entry",
|
||||
"action", p.Action,
|
||||
"resource_type", p.ResourceType,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func marshalMeta(meta map[string]any) []byte {
|
||||
if len(meta) == 0 {
|
||||
return []byte("{}")
|
||||
}
|
||||
b, err := json.Marshal(meta)
|
||||
if err != nil {
|
||||
return []byte("{}")
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// optText returns a valid pgtype.Text if s is non-empty, otherwise an invalid (NULL) one.
|
||||
func optText(s string) pgtype.Text {
|
||||
if s == "" {
|
||||
return pgtype.Text{}
|
||||
}
|
||||
return pgtype.Text{String: s, Valid: true}
|
||||
}
|
||||
|
||||
// --- Sandbox events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "create",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: marshalMeta(map[string]any{"template": template}),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleCreated,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "pause",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsulePaused,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
// LogSandboxAutoPause records a system-initiated auto-pause (TTL or host reconciler).
|
||||
func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) {
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: teamID,
|
||||
ActorType: "system",
|
||||
ActorID: pgtype.Text{},
|
||||
ActorName: "",
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "pause",
|
||||
Scope: "team",
|
||||
Status: "info",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsulePaused,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: systemActor(),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "resume",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleRunning,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: optText(id.FormatSandboxID(sandboxID)),
|
||||
Action: "destroy",
|
||||
Scope: "team",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleDestroyed,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
})
|
||||
}
|
||||
|
||||
// --- Snapshot events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext, name string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "snapshot",
|
||||
ResourceID: optText(name),
|
||||
Action: "create",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.SnapshotCreated,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: name, Type: "snapshot"},
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext, name string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "snapshot",
|
||||
ResourceID: optText(name),
|
||||
Action: "delete",
|
||||
Scope: "team",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.SnapshotDeleted,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: name, Type: "snapshot"},
|
||||
})
|
||||
}
|
||||
|
||||
// --- Team events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, oldName, newName string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "team",
|
||||
ResourceID: optText(id.FormatTeamID(teamID)),
|
||||
Action: "rename",
|
||||
Scope: "team",
|
||||
Status: "info",
|
||||
Metadata: marshalMeta(map[string]any{"old_name": oldName, "new_name": newName}),
|
||||
})
|
||||
}
|
||||
|
||||
// --- Channel events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogChannelCreate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID, name, provider string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "channel",
|
||||
ResourceID: optText(id.FormatChannelID(channelID)),
|
||||
Action: "create",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: marshalMeta(map[string]any{"name": name, "provider": provider}),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogChannelUpdate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "channel",
|
||||
ResourceID: optText(id.FormatChannelID(channelID)),
|
||||
Action: "update",
|
||||
Scope: "team",
|
||||
Status: "info",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogChannelRotateConfig(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "channel",
|
||||
ResourceID: optText(id.FormatChannelID(channelID)),
|
||||
Action: "rotate_config",
|
||||
Scope: "team",
|
||||
Status: "info",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogChannelDelete(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "channel",
|
||||
ResourceID: optText(id.FormatChannelID(channelID)),
|
||||
Action: "delete",
|
||||
Scope: "team",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
// --- API key events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID, keyName string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "api_key",
|
||||
ResourceID: optText(id.FormatAPIKeyID(keyID)),
|
||||
Action: "create",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: marshalMeta(map[string]any{"name": keyName}),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "api_key",
|
||||
ResourceID: optText(id.FormatAPIKeyID(keyID)),
|
||||
Action: "revoke",
|
||||
Scope: "team",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
// --- Member events (scope: admin) ---
|
||||
|
||||
func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, targetEmail, role string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "member",
|
||||
ResourceID: optText(id.FormatUserID(targetUserID)),
|
||||
Action: "add",
|
||||
Scope: "admin",
|
||||
Status: "success",
|
||||
Metadata: marshalMeta(map[string]any{"email": targetEmail, "role": role}),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "member",
|
||||
ResourceID: optText(id.FormatUserID(targetUserID)),
|
||||
Action: "remove",
|
||||
Scope: "admin",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
resourceID := ""
|
||||
if ac.UserID.Valid {
|
||||
resourceID = id.FormatUserID(ac.UserID)
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "member",
|
||||
ResourceID: optText(resourceID),
|
||||
Action: "leave",
|
||||
Scope: "admin",
|
||||
Status: "info",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, newRole string) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: ac.TeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "member",
|
||||
ResourceID: optText(id.FormatUserID(targetUserID)),
|
||||
Action: "role_update",
|
||||
Scope: "admin",
|
||||
Status: "info",
|
||||
Metadata: marshalMeta(map[string]any{"new_role": newRole}),
|
||||
})
|
||||
}
|
||||
|
||||
// --- Host events (scope: admin) ---
|
||||
|
||||
func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
// For shared hosts with no owning team, use the caller's team.
|
||||
logTeamID := teamID
|
||||
if !logTeamID.Valid {
|
||||
logTeamID = ac.TeamID
|
||||
}
|
||||
if !logTeamID.Valid {
|
||||
return
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: logTeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "host",
|
||||
ResourceID: optText(id.FormatHostID(hostID)),
|
||||
Action: "create",
|
||||
Scope: "admin",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
|
||||
actorType, actorID, actorName := actorFields(ac)
|
||||
logTeamID := teamID
|
||||
if !logTeamID.Valid {
|
||||
logTeamID = ac.TeamID
|
||||
}
|
||||
if !logTeamID.Valid {
|
||||
return
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: logTeamID,
|
||||
ActorType: actorType,
|
||||
ActorID: optText(actorID),
|
||||
ActorName: actorName,
|
||||
ResourceType: "host",
|
||||
ResourceID: optText(id.FormatHostID(hostID)),
|
||||
Action: "delete",
|
||||
Scope: "admin",
|
||||
Status: "warning",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
}
|
||||
|
||||
// LogHostMarkedDown records a system-initiated host status transition to unreachable.
|
||||
// Scoped to "team" so BYOC team members can see when their hosts go down.
|
||||
func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID pgtype.UUID) {
|
||||
if !teamID.Valid {
|
||||
return
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: teamID,
|
||||
ActorType: "system",
|
||||
ActorID: pgtype.Text{},
|
||||
ActorName: "",
|
||||
ResourceType: "host",
|
||||
ResourceID: optText(id.FormatHostID(hostID)),
|
||||
Action: "marked_down",
|
||||
Scope: "team",
|
||||
Status: "error",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.HostDown,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: systemActor(),
|
||||
Resource: events.Resource{ID: id.FormatHostID(hostID), Type: "host"},
|
||||
})
|
||||
}
|
||||
|
||||
// LogHostMarkedUp records a system-initiated host status transition back to online.
|
||||
// Scoped to "team" so BYOC team members can see when their hosts recover.
|
||||
func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID pgtype.UUID) {
|
||||
if !teamID.Valid {
|
||||
return
|
||||
}
|
||||
l.write(ctx, db.InsertAuditLogParams{
|
||||
ID: id.NewAuditLogID(),
|
||||
TeamID: teamID,
|
||||
ActorType: "system",
|
||||
ActorID: pgtype.Text{},
|
||||
ActorName: "",
|
||||
ResourceType: "host",
|
||||
ResourceID: optText(id.FormatHostID(hostID)),
|
||||
Action: "marked_up",
|
||||
Scope: "team",
|
||||
Status: "success",
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.HostUp,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: systemActor(),
|
||||
Resource: events.Resource{ID: id.FormatHostID(hostID), Type: "host"},
|
||||
})
|
||||
}
|
||||
@ -1,35 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GenerateAPIKey returns a plaintext key in the form "wrn_" + 32 random hex chars
|
||||
// and its SHA-256 hash. The caller must show the plaintext to the user exactly once;
|
||||
// only the hash is stored.
|
||||
func GenerateAPIKey() (plaintext, hash string, err error) {
|
||||
b := make([]byte, 16) // 16 bytes → 32 hex chars
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return "", "", fmt.Errorf("generate api key: %w", err)
|
||||
}
|
||||
plaintext = "wrn_" + hex.EncodeToString(b)
|
||||
hash = HashAPIKey(plaintext)
|
||||
return plaintext, hash, nil
|
||||
}
|
||||
|
||||
// HashAPIKey returns the hex-encoded SHA-256 hash of a plaintext API key.
|
||||
func HashAPIKey(plaintext string) string {
|
||||
sum := sha256.Sum256([]byte(plaintext))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// APIKeyPrefix returns the first 8 characters of a plaintext API key (e.g. "wrn_ab12").
|
||||
func APIKeyPrefix(plaintext string) string {
|
||||
if len(plaintext) > 10 {
|
||||
return plaintext[:10]
|
||||
}
|
||||
return plaintext
|
||||
}
|
||||
@ -1,251 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CPCertRenewInterval is how often the control plane should renew its client
|
||||
// certificate. It is set to half the cert TTL so there is always a wide safety
|
||||
// margin before expiry.
|
||||
const CPCertRenewInterval = cpCertTTL / 2
|
||||
|
||||
const (
|
||||
hostCertTTL = 7 * 24 * time.Hour
|
||||
cpCertTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// CA holds a parsed certificate authority ready to issue leaf certificates.
|
||||
type CA struct {
|
||||
Cert *x509.Certificate
|
||||
Key *ecdsa.PrivateKey
|
||||
PEM string // PEM-encoded certificate for embedding in register/refresh responses
|
||||
}
|
||||
|
||||
// ParseCA parses PEM-encoded CA certificate and private key strings.
|
||||
// The cert and key are expected to be ECDSA P-256.
|
||||
func ParseCA(certPEM, keyPEM string) (*CA, error) {
|
||||
certBlock, _ := pem.Decode([]byte(certPEM))
|
||||
if certBlock == nil {
|
||||
return nil, fmt.Errorf("failed to decode CA certificate PEM")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(certBlock.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse CA certificate: %w", err)
|
||||
}
|
||||
|
||||
keyBlock, _ := pem.Decode([]byte(keyPEM))
|
||||
if keyBlock == nil {
|
||||
return nil, fmt.Errorf("failed to decode CA key PEM")
|
||||
}
|
||||
keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse CA private key: %w", err)
|
||||
}
|
||||
|
||||
return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil
|
||||
}
|
||||
|
||||
// HostCert holds all material returned when issuing a leaf cert for a host agent.
|
||||
type HostCert struct {
|
||||
CertPEM string
|
||||
KeyPEM string
|
||||
Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
|
||||
ExpiresAt time.Time // stored in hosts.cert_expires_at
|
||||
TLSCert tls.Certificate
|
||||
}
|
||||
|
||||
// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server
|
||||
// certificate for the host agent. hostID becomes the common name; the host's
|
||||
// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS
|
||||
// stack can verify the connection without disabling hostname checking.
|
||||
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("generate host key: %w", err)
|
||||
}
|
||||
|
||||
serial, err := randomSerial()
|
||||
if err != nil {
|
||||
return HostCert{}, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expires := now.Add(hostCertTTL)
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: hostID},
|
||||
NotBefore: now.Add(-time.Minute), // small clock-skew tolerance
|
||||
NotAfter: expires,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
|
||||
// Extract IP from "ip:port" address; fall back to DNS SAN if not parseable.
|
||||
host, _, err := net.SplitHostPort(hostAddr)
|
||||
if err != nil {
|
||||
host = hostAddr
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
tmpl.IPAddresses = []net.IP{ip}
|
||||
} else {
|
||||
tmpl.DNSNames = []string{host}
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("create host certificate: %w", err)
|
||||
}
|
||||
|
||||
certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("marshal host key: %w", err)
|
||||
}
|
||||
keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
|
||||
|
||||
tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
|
||||
}
|
||||
|
||||
fp := fmt.Sprintf("%x", sha256.Sum256(derBytes))
|
||||
|
||||
return HostCert{
|
||||
CertPEM: certPEM,
|
||||
KeyPEM: keyPEM,
|
||||
Fingerprint: fp,
|
||||
ExpiresAt: expires,
|
||||
TLSCert: tlsCert,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for
|
||||
// the control plane to present during mTLS handshakes with host agents.
|
||||
// Called once at CP startup; the result is embedded into the shared HTTP client.
|
||||
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
|
||||
}
|
||||
|
||||
serial, err := randomSerial()
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: "wrenn-cp"},
|
||||
NotBefore: now.Add(-time.Minute),
|
||||
NotAfter: now.Add(cpCertTTL),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
|
||||
return tls.X509KeyPair(certPEM, keyPEM)
|
||||
}
|
||||
|
||||
// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the
|
||||
// PEM-encoded CA certificate. This is used on the agent side where only the
|
||||
// CA certificate (not the private key) is available.
|
||||
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM([]byte(caCertPEM)) {
|
||||
return nil
|
||||
}
|
||||
return &tls.Config{
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: pool,
|
||||
GetCertificate: getCert,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
}
|
||||
|
||||
// CPCertStore provides lock-free read/write access to the control plane's
|
||||
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
|
||||
// to enable hot-swap without restarting the HTTP client.
|
||||
//
|
||||
// The zero value is not usable; use NewCPCertStore to create one.
|
||||
type CPCertStore struct {
|
||||
ptr atomic.Pointer[tls.Certificate]
|
||||
ca *CA
|
||||
}
|
||||
|
||||
// NewCPCertStore issues an initial CP client certificate from ca and returns a
|
||||
// store that can renew it in place. Returns an error if the initial issuance fails.
|
||||
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
|
||||
s := &CPCertStore{ca: ca}
|
||||
if err := s.Refresh(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Refresh issues a fresh CP client certificate and atomically stores it.
|
||||
// If issuance fails the existing cert is unchanged.
|
||||
func (s *CPCertStore) Refresh() error {
|
||||
cert, err := IssueCPClientCert(s.ca)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renew CP client certificate: %w", err)
|
||||
}
|
||||
s.ptr.Store(&cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called
|
||||
// per-handshake and always returns the most recently stored certificate.
|
||||
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
cert := s.ptr.Load()
|
||||
if cert == nil {
|
||||
return nil, fmt.Errorf("no CP client certificate available")
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client.
|
||||
// It uses certStore.GetClientCertificate so the certificate can be renewed
|
||||
// without replacing the config or transport.
|
||||
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(ca.Cert)
|
||||
return &tls.Config{
|
||||
RootCAs: pool,
|
||||
GetClientCertificate: certStore.GetClientCertificate,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
}
|
||||
|
||||
// randomSerial returns a random 128-bit certificate serial number.
func randomSerial() (*big.Int, error) {
	// Serials are drawn uniformly from [0, 2^128).
	limit := new(big.Int).Lsh(big.NewInt(1), 128)
	n, err := rand.Int(rand.Reader, limit)
	if err != nil {
		return nil, fmt.Errorf("generate serial number: %w", err)
	}
	return n, nil
}
@ -1,72 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
// contextKey is a private key type for context values, so keys declared in
// this package cannot collide with keys from other packages.
type contextKey int

// authCtxKey is the context key under which auth middleware stores an AuthContext.
const authCtxKey contextKey = 0
// AuthContext is stamped into request context by auth middleware.
// Exactly one of the two identity styles is populated: JWT auth fills the
// user fields, API-key auth fills the APIKey fields.
type AuthContext struct {
	TeamID     pgtype.UUID
	UserID     pgtype.UUID // zero value (Valid=false) when authenticated via API key
	Email      string      // empty when authenticated via API key
	Name       string      // empty when authenticated via API key
	Role       string      // owner, admin, or member; empty when authenticated via API key
	IsAdmin    bool        // platform-level admin; always false when authenticated via API key
	APIKeyID   pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
	APIKeyName string      // display name of the key, snapshotted at auth time; empty for JWT auth
}
// WithAuthContext returns a new context with the given AuthContext.
|
||||
func WithAuthContext(ctx context.Context, a AuthContext) context.Context {
|
||||
return context.WithValue(ctx, authCtxKey, a)
|
||||
}
|
||||
|
||||
// FromContext retrieves the AuthContext. Returns zero value and false if absent.
|
||||
func FromContext(ctx context.Context) (AuthContext, bool) {
|
||||
a, ok := ctx.Value(authCtxKey).(AuthContext)
|
||||
return a, ok
|
||||
}
|
||||
|
||||
// MustFromContext retrieves the AuthContext. Panics if absent — only call
|
||||
// inside handlers behind auth middleware.
|
||||
func MustFromContext(ctx context.Context) AuthContext {
|
||||
a, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
panic("auth: MustFromContext called on unauthenticated request")
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// hostCtxKey is the context key under which host token middleware stores a HostContext.
const hostCtxKey contextKey = 1
// HostContext is stamped into request context by host token middleware.
type HostContext struct {
	HostID pgtype.UUID // the authenticated host agent's ID
}
// WithHostContext returns a new context with the given HostContext.
|
||||
func WithHostContext(ctx context.Context, h HostContext) context.Context {
|
||||
return context.WithValue(ctx, hostCtxKey, h)
|
||||
}
|
||||
|
||||
// HostFromContext retrieves the HostContext. Returns zero value and false if absent.
|
||||
func HostFromContext(ctx context.Context) (HostContext, bool) {
|
||||
h, ok := ctx.Value(hostCtxKey).(HostContext)
|
||||
return h, ok
|
||||
}
|
||||
|
||||
// MustHostFromContext retrieves the HostContext. Panics if absent — only call
|
||||
// inside handlers behind host token middleware.
|
||||
func MustHostFromContext(ctx context.Context) HostContext {
|
||||
h, ok := HostFromContext(ctx)
|
||||
if !ok {
|
||||
panic("auth: MustHostFromContext called on unauthenticated request")
|
||||
}
|
||||
return h
|
||||
}
|
||||
@ -1,113 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
)
|
||||
|
||||
// Token lifetimes. User JWTs are short-lived; host JWTs last longer and are
// renewed out-of-band via a refresh token.
const (
	jwtExpiry              = 6 * time.Hour
	hostJWTExpiry          = 7 * 24 * time.Hour  // 7 days; host refreshes via refresh token
	HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
)
// Claims are the JWT payload for user tokens.
type Claims struct {
	Type    string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
	TeamID  string `json:"team_id"`
	Role    string `json:"role"` // owner, admin, or member within TeamID
	Email   string `json:"email"`
	Name    string `json:"name"`
	IsAdmin bool   `json:"is_admin,omitempty"` // platform-level admin flag
	jwt.RegisteredClaims
}
// SignJWT signs a new 6-hour JWT for the given user.
|
||||
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Role: role,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: id.FormatUserID(userID),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyJWT parses and validates a user JWT, returning the claims on success.
|
||||
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
|
||||
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return Claims{}, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
c, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return Claims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type == "host" {
|
||||
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
// HostClaims are the JWT payload for host agent tokens.
type HostClaims struct {
	Type   string `json:"typ"` // always "host"
	HostID string `json:"host_id"`
	jwt.RegisteredClaims
}
// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent.
|
||||
func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) {
|
||||
formatted := id.FormatHostID(hostID)
|
||||
now := time.Now()
|
||||
claims := HostClaims{
|
||||
Type: "host",
|
||||
HostID: formatted,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: formatted,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyHostJWT parses and validates a host JWT, returning the claims on success.
|
||||
// It rejects user JWTs by checking the "typ" claim.
|
||||
func VerifyHostJWT(secret []byte, tokenStr string) (HostClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &HostClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return HostClaims{}, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
c, ok := token.Claims.(*HostClaims)
|
||||
if !ok || !token.Valid {
|
||||
return HostClaims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type != "host" {
|
||||
return HostClaims{}, fmt.Errorf("invalid token type: expected host")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
@ -1,127 +0,0 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/endpoints"
|
||||
)
|
||||
|
||||
// GitHubProvider implements Provider for GitHub OAuth.
type GitHubProvider struct {
	cfg *oauth2.Config // OAuth2 client config (endpoint, scopes, redirect)
}
// NewGitHubProvider creates a GitHub OAuth provider.
|
||||
func NewGitHubProvider(clientID, clientSecret, callbackURL string) *GitHubProvider {
|
||||
return &GitHubProvider{
|
||||
cfg: &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: endpoints.GitHub,
|
||||
Scopes: []string{"user:email"},
|
||||
RedirectURL: callbackURL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) Name() string { return "github" }
|
||||
|
||||
func (p *GitHubProvider) AuthCodeURL(state string) string {
|
||||
return p.cfg.AuthCodeURL(state, oauth2.AccessTypeOnline)
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) Exchange(ctx context.Context, code string) (UserProfile, error) {
|
||||
token, err := p.cfg.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return UserProfile{}, fmt.Errorf("exchange code: %w", err)
|
||||
}
|
||||
|
||||
client := p.cfg.Client(ctx, token)
|
||||
|
||||
profile, err := fetchGitHubUser(client)
|
||||
if err != nil {
|
||||
return UserProfile{}, err
|
||||
}
|
||||
|
||||
// GitHub may not include email if the user's email is private.
|
||||
if profile.Email == "" {
|
||||
email, err := fetchGitHubPrimaryEmail(client)
|
||||
if err != nil {
|
||||
return UserProfile{}, err
|
||||
}
|
||||
profile.Email = email
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// githubUser is the subset of GitHub's /user response this package reads.
type githubUser struct {
	ID    int64  `json:"id"`    // numeric account ID, stable across renames
	Login string `json:"login"` // username; fallback display name
	Email string `json:"email"` // may be empty when the email is private
	Name  string `json:"name"`  // display name; may be empty
}
func fetchGitHubUser(client *http.Client) (UserProfile, error) {
|
||||
resp, err := client.Get("https://api.github.com/user")
|
||||
if err != nil {
|
||||
return UserProfile{}, fmt.Errorf("fetch github user: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return UserProfile{}, fmt.Errorf("github /user returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var u githubUser
|
||||
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
|
||||
return UserProfile{}, fmt.Errorf("decode github user: %w", err)
|
||||
}
|
||||
|
||||
name := u.Name
|
||||
if name == "" {
|
||||
name = u.Login
|
||||
}
|
||||
|
||||
return UserProfile{
|
||||
ProviderID: strconv.FormatInt(u.ID, 10),
|
||||
Email: u.Email,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// githubEmail is one entry of GitHub's /user/emails response.
type githubEmail struct {
	Email    string `json:"email"`
	Primary  bool   `json:"primary"`  // the account's primary address
	Verified bool   `json:"verified"` // only verified addresses are trusted
}
func fetchGitHubPrimaryEmail(client *http.Client) (string, error) {
|
||||
resp, err := client.Get("https://api.github.com/user/emails")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch github emails: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("github /user/emails returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var emails []githubEmail
|
||||
if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
|
||||
return "", fmt.Errorf("decode github emails: %w", err)
|
||||
}
|
||||
|
||||
for _, e := range emails {
|
||||
if e.Primary && e.Verified {
|
||||
return e.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("github account has no verified primary email")
|
||||
}
|
||||
@ -1,41 +0,0 @@
|
||||
package oauth
|
||||
|
||||
import "context"
|
||||
|
||||
// UserProfile is the normalized user info returned by an OAuth provider.
type UserProfile struct {
	ProviderID string // provider-scoped stable account ID
	Email      string
	Name       string
}
// Provider abstracts an OAuth 2.0 identity provider.
type Provider interface {
	// Name returns the provider identifier (e.g. "github", "google").
	Name() string
	// AuthCodeURL returns the URL to redirect the user to for authorization.
	AuthCodeURL(state string) string
	// Exchange trades an authorization code for a user profile.
	Exchange(ctx context.Context, code string) (UserProfile, error)
}
// Registry maps provider names to Provider implementations.
// It is not safe for concurrent mutation; register all providers at startup.
type Registry struct {
	providers map[string]Provider // keyed by Provider.Name()
}
// NewRegistry creates an empty provider registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{providers: make(map[string]Provider)}
|
||||
}
|
||||
|
||||
// Register adds a provider to the registry.
|
||||
func (r *Registry) Register(p Provider) {
|
||||
r.providers[p.Name()] = p
|
||||
}
|
||||
|
||||
// Get looks up a provider by name.
|
||||
func (r *Registry) Get(name string) (Provider, bool) {
|
||||
p, ok := r.providers[name]
|
||||
return p, ok
|
||||
}
|
||||
@ -1,16 +0,0 @@
|
||||
package auth
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
// bcryptCost is the bcrypt work factor (2^12 rounds) for password hashing.
const bcryptCost = 12
// HashPassword returns the bcrypt hash of a plaintext password.
|
||||
func HashPassword(plaintext string) (string, error) {
|
||||
b, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcryptCost)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
// CheckPassword returns nil if plaintext matches the stored hash; otherwise it
// returns bcrypt's mismatch (or parse) error.
func CheckPassword(hash, plaintext string) error {
	return bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintext))
}
@ -1,63 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// EncryptSecret encrypts plaintext using AES-256-GCM with a random nonce.
|
||||
// Returns base64(nonce || ciphertext).
|
||||
func EncryptSecret(key [32]byte, plaintext string) (string, error) {
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("aes cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gcm: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptSecret decrypts a value produced by EncryptSecret.
|
||||
func DecryptSecret(key [32]byte, encoded string) (string, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("aes cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gcm: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
@ -1,36 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/containrrr/shoutrrr"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/events"
|
||||
)
|
||||
|
||||
// Deliver sends a notification to a single provider with the given config.
|
||||
// For webhooks it uses HMAC-signed HTTP POST; for all others it uses shoutrrr.
|
||||
func Deliver(ctx context.Context, provider string, config map[string]string, e events.Event) error {
|
||||
payload, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal event: %w", err)
|
||||
}
|
||||
|
||||
if provider == "webhook" {
|
||||
wh := NewWebhookDelivery()
|
||||
return wh.Deliver(ctx, config["url"], config["secret"], payload)
|
||||
}
|
||||
|
||||
shoutrrrURL, err := ShoutrrrURL(provider, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build shoutrrr URL: %w", err)
|
||||
}
|
||||
|
||||
msg := FormatMessage(e)
|
||||
if err := shoutrrr.Send(shoutrrrURL, msg); err != nil {
|
||||
return fmt.Errorf("shoutrrr send: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -1,183 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"log/slog"
	"strings"
	"time"

	"github.com/redis/go-redis/v9"

	"git.omukk.dev/wrenn/wrenn/internal/db"
	"git.omukk.dev/wrenn/wrenn/internal/events"
	"git.omukk.dev/wrenn/wrenn/internal/id"
)
|
||||
const (
	// groupName is the Redis consumer-group name; versioned so a future
	// incompatible consumer can start a fresh group.
	groupName = "wrenn-channels-v1"
	// consumerName identifies this control-plane instance within the group.
	consumerName = "cp-0"
)
// Dispatcher consumes events from the Redis stream and delivers them
// to matching notification channels.
type Dispatcher struct {
	rdb     *redis.Client
	db      *db.Queries
	encKey  [32]byte         // AES-256 key used to decrypt per-channel config secrets
	webhook *WebhookDelivery // HMAC-signed HTTP delivery for the "webhook" provider
}
// NewDispatcher constructs an event dispatcher.
|
||||
func NewDispatcher(rdb *redis.Client, queries *db.Queries, encKey [32]byte) *Dispatcher {
|
||||
return &Dispatcher{
|
||||
rdb: rdb,
|
||||
db: queries,
|
||||
encKey: encKey,
|
||||
webhook: NewWebhookDelivery(),
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the consumer goroutine, which runs until ctx is cancelled.
// Start itself returns immediately.
//
// NOTE(review): there is no way to wait for the goroutine to drain on
// shutdown — confirm at the call site if in-flight deliveries matter.
func (d *Dispatcher) Start(ctx context.Context) {
	go d.run(ctx)
}
// run is the consumer loop: it ensures the consumer group exists, then reads
// batches from the stream and hands each message to handleMessage until ctx
// is cancelled.
func (d *Dispatcher) run(ctx context.Context) {
	// Create consumer group idempotently. "$" means only new messages.
	err := d.rdb.XGroupCreateMkStream(ctx, streamKey, groupName, "$").Err()
	if err != nil && !isGroupExistsError(err) {
		// Without a group the loop cannot read; give up rather than spin.
		slog.Error("channels: failed to create consumer group", "error", err)
		return
	}

	for {
		// Cheap non-blocking cancellation check before each blocking read.
		select {
		case <-ctx.Done():
			return
		default:
		}

		streams, err := d.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
			Group:    groupName,
			Consumer: consumerName,
			Streams:  []string{streamKey, ">"}, // ">" = messages never delivered to this group
			Count:    10,
			Block:    5 * time.Second, // bounded block so cancellation is observed
		}).Result()

		if err != nil {
			// redis.Nil: the blocking read timed out with nothing to deliver.
			// ctx.Err() != nil: cancellation surfaced as a read error; loop
			// back so the select above returns.
			if err == redis.Nil || ctx.Err() != nil {
				continue
			}
			// Transient Redis failure: back off briefly to avoid hot-looping.
			slog.Warn("channels: xreadgroup error", "error", err)
			time.Sleep(1 * time.Second)
			continue
		}

		for _, stream := range streams {
			for _, msg := range stream.Messages {
				d.handleMessage(ctx, msg)
			}
		}
	}
}
// handleMessage decodes one stream entry and fans it out to every channel
// subscribed to the event type.
//
// The message is XACK'd unconditionally (via defer), even when decoding or
// delivery fails — at-most-once semantics; malformed messages are logged and
// dropped rather than redelivered forever.
func (d *Dispatcher) handleMessage(ctx context.Context, msg redis.XMessage) {
	defer func() {
		if err := d.rdb.XAck(ctx, streamKey, groupName, msg.ID).Err(); err != nil {
			slog.Warn("channels: xack failed", "id", msg.ID, "error", err)
		}
	}()

	// Publisher stores the JSON event under the "payload" field.
	payload, ok := msg.Values["payload"].(string)
	if !ok {
		slog.Warn("channels: message missing payload", "id", msg.ID)
		return
	}

	var event events.Event
	if err := json.Unmarshal([]byte(payload), &event); err != nil {
		slog.Warn("channels: failed to unmarshal event", "id", msg.ID, "error", err)
		return
	}

	teamID, err := id.ParseTeamID(event.TeamID)
	if err != nil {
		slog.Warn("channels: invalid team ID in event", "team_id", event.TeamID, "error", err)
		return
	}

	// Only channels of the event's team that subscribe to this event type.
	channels, err := d.db.ListChannelsForEvent(ctx, db.ListChannelsForEventParams{
		TeamID:    teamID,
		EventType: event.Event,
	})
	if err != nil {
		slog.Warn("channels: failed to list channels for event", "event", event.Event, "error", err)
		return
	}

	for _, ch := range channels {
		d.dispatch(ctx, ch, event)
	}
}
// retryDelays defines the wait durations before each retry attempt, so a
// failed delivery gets two more tries (10s and 30s after the failure).
var retryDelays = []time.Duration{10 * time.Second, 30 * time.Second}
func (d *Dispatcher) dispatch(ctx context.Context, ch db.Channel, e events.Event) {
|
||||
config, err := d.decryptConfig(ch.Config)
|
||||
if err != nil {
|
||||
slog.Warn("channels: failed to decrypt config",
|
||||
"channel_id", id.FormatChannelID(ch.ID), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
chID := id.FormatChannelID(ch.ID)
|
||||
|
||||
if err := Deliver(ctx, ch.Provider, config, e); err != nil {
|
||||
slog.Warn("channels: delivery failed, scheduling retries",
|
||||
"channel_id", chID, "provider", ch.Provider, "error", err)
|
||||
go d.retryDeliver(ctx, ch.Provider, config, e, chID)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) retryDeliver(ctx context.Context, provider string, config map[string]string, e events.Event, chID string) {
|
||||
for i, delay := range retryDelays {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
if err := Deliver(ctx, provider, config, e); err != nil {
|
||||
slog.Warn("channels: retry delivery failed",
|
||||
"channel_id", chID, "provider", provider,
|
||||
"attempt", i+2, "error", err)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
slog.Error("channels: delivery failed after all retries",
|
||||
"channel_id", chID, "provider", provider, "event", e.Event)
|
||||
}
|
||||
|
||||
func (d *Dispatcher) decryptConfig(configJSON []byte) (map[string]string, error) {
|
||||
var encrypted map[string]string
|
||||
if err := json.Unmarshal(configJSON, &encrypted); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decrypted := make(map[string]string, len(encrypted))
|
||||
for k, v := range encrypted {
|
||||
plaintext, err := DecryptSecret(d.encKey, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypted[k] = plaintext
|
||||
}
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
func isGroupExistsError(err error) bool {
|
||||
return err != nil && err.Error() == "BUSYGROUP Consumer Group name already exists"
|
||||
}
|
||||
@ -1,65 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/events"
|
||||
)
|
||||
|
||||
// FormatMessage produces a human-readable notification string containing
|
||||
// the event summary, resource details, actor, and timestamp.
|
||||
func FormatMessage(e events.Event) string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(formatSummary(e))
|
||||
fmt.Fprintf(&b, "\n\nEvent: %s", e.Event)
|
||||
fmt.Fprintf(&b, "\nResource: %s %s", e.Resource.Type, e.Resource.ID)
|
||||
fmt.Fprintf(&b, "\nActor: %s", formatActor(e.Actor))
|
||||
fmt.Fprintf(&b, "\nTeam: %s", e.TeamID)
|
||||
fmt.Fprintf(&b, "\nTime: %s", e.Timestamp)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// formatSummary returns a one-line headline for the event, specialized per
// known event type; unknown types fall back to "<type> <id>".
func formatSummary(e events.Event) string {
	switch e.Event {
	case events.CapsuleCreated:
		return fmt.Sprintf("Capsule %s created", e.Resource.ID)
	case events.CapsuleRunning:
		return fmt.Sprintf("Capsule %s is running", e.Resource.ID)
	case events.CapsulePaused:
		return fmt.Sprintf("Capsule %s paused", e.Resource.ID)
	case events.CapsuleDestroyed:
		return fmt.Sprintf("Capsule %s destroyed", e.Resource.ID)
	case events.SnapshotCreated:
		return fmt.Sprintf("Template snapshot %s created", e.Resource.ID)
	case events.SnapshotDeleted:
		return fmt.Sprintf("Template snapshot %s deleted", e.Resource.ID)
	case events.HostUp:
		return fmt.Sprintf("Host %s is up", e.Resource.ID)
	case events.HostDown:
		return fmt.Sprintf("Host %s is down", e.Resource.ID)
	default:
		return fmt.Sprintf("%s %s", e.Resource.Type, e.Resource.ID)
	}
}
func formatActor(a events.Actor) string {
|
||||
switch a.Type {
|
||||
case events.ActorSystem:
|
||||
return "system"
|
||||
case events.ActorUser:
|
||||
if a.Name != "" {
|
||||
return fmt.Sprintf("%s (%s)", a.Name, a.ID)
|
||||
}
|
||||
return a.ID
|
||||
case events.ActorAPIKey:
|
||||
if a.Name != "" {
|
||||
return fmt.Sprintf("api_key %s (%s)", a.Name, a.ID)
|
||||
}
|
||||
return fmt.Sprintf("api_key %s", a.ID)
|
||||
default:
|
||||
return string(a.Type)
|
||||
}
|
||||
}
|
||||
@ -1,44 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/events"
|
||||
)
|
||||
|
||||
// streamKey is the Redis stream shared by Publisher (XADD) and Dispatcher (XREADGROUP).
const streamKey = "wrenn:events"
// Publisher pushes events onto the Redis stream for the dispatcher to consume.
type Publisher struct {
	rdb *redis.Client
}
// NewPublisher constructs an event publisher.
|
||||
func NewPublisher(rdb *redis.Client) *Publisher {
|
||||
return &Publisher{rdb: rdb}
|
||||
}
|
||||
|
||||
// Publish serializes the event and appends it to the global stream.
|
||||
// Fire-and-forget: failures are logged, never propagated.
|
||||
func (p *Publisher) Publish(ctx context.Context, e events.Event) {
|
||||
payload, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
slog.Warn("channels: failed to marshal event", "event", e.Event, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.rdb.XAdd(ctx, &redis.XAddArgs{
|
||||
Stream: streamKey,
|
||||
MaxLen: 10000,
|
||||
Approx: true,
|
||||
Values: map[string]interface{}{
|
||||
"payload": string(payload),
|
||||
},
|
||||
}).Err(); err != nil {
|
||||
slog.Warn("channels: failed to publish event", "event", e.Event, "error", err)
|
||||
}
|
||||
}
|
||||
@ -1,298 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/events"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/validate"
|
||||
)
|
||||
|
||||
// validProviders is the set of notification providers Create accepts.
var validProviders = map[string]bool{
	"discord":    true,
	"slack":      true,
	"teams":      true,
	"googlechat": true,
	"telegram":   true,
	"matrix":     true,
	"webhook":    true,
}
// requiredFields lists the config keys that must be non-empty per provider;
// Create and RotateConfig validate against it.
var requiredFields = map[string][]string{
	"discord":    {"webhook_url"},
	"slack":      {"webhook_url"},
	"teams":      {"webhook_url"},
	"googlechat": {"webhook_url"},
	"telegram":   {"bot_token", "chat_id"},
	"matrix":     {"homeserver_url", "access_token", "room_id"},
	"webhook":    {"url"},
}
// validEvents maps event type strings to true for validation.
var validEvents map[string]bool

func init() {
	// Built once at startup from the canonical event list so channel
	// create/update can reject unknown event types with an O(1) lookup.
	validEvents = make(map[string]bool, len(events.AllEventTypes))
	for _, et := range events.AllEventTypes {
		validEvents[et] = true
	}
}
// Service handles channel CRUD operations.
type Service struct {
	DB     *db.Queries
	EncKey [32]byte // AES-256 key used to encrypt channel config secrets at rest
}
// CreateParams holds the parameters for creating a channel.
type CreateParams struct {
	TeamID   pgtype.UUID
	Name     string
	Provider string            // one of validProviders
	Config   map[string]string // provider secrets; validated against requiredFields
	Events   []string          // subscribed event types; validated against validEvents
}
// CreateResult holds the result of creating a channel.
type CreateResult struct {
	Channel         db.Channel
	PlaintextSecret string // non-empty only for webhook provider
}
// Create creates a new notification channel.
//
// Flow: validate name/provider/events and required config fields, generate a
// webhook signing secret when needed, encrypt every config value with the
// service key, then insert. Error messages are prefixed "invalid:" /
// "conflict:" — presumably mapped to HTTP status codes by the handler layer;
// confirm before changing the prefixes.
//
// NOTE(review): the webhook branch writes the generated secret back into
// p.Config, mutating the caller's map — confirm callers don't reuse it.
func (s *Service) Create(ctx context.Context, p CreateParams) (CreateResult, error) {
	clean, err := cleanName(p.Name)
	if err != nil {
		return CreateResult{}, err
	}
	p.Name = clean

	if !validProviders[p.Provider] {
		return CreateResult{}, fmt.Errorf("invalid: unsupported provider %q", p.Provider)
	}

	if len(p.Events) == 0 {
		return CreateResult{}, fmt.Errorf("invalid: at least one event type is required")
	}
	for _, et := range p.Events {
		if !validEvents[et] {
			return CreateResult{}, fmt.Errorf("invalid: unknown event type %q", et)
		}
	}

	// Validate required config fields.
	for _, field := range requiredFields[p.Provider] {
		if p.Config[field] == "" {
			return CreateResult{}, fmt.Errorf("invalid: %s is required for %s", field, p.Provider)
		}
	}

	// For webhooks, auto-generate secret if not provided. The plaintext is
	// returned once so the caller can surface it to the user.
	var plaintextSecret string
	if p.Provider == "webhook" {
		if p.Config["secret"] == "" {
			secret := generateSecret()
			p.Config["secret"] = secret
			plaintextSecret = secret
		} else {
			plaintextSecret = p.Config["secret"]
		}
	}

	// Encrypt config fields so secrets are never stored in plaintext.
	encrypted := make(map[string]string, len(p.Config))
	for k, v := range p.Config {
		enc, err := EncryptSecret(s.EncKey, v)
		if err != nil {
			return CreateResult{}, fmt.Errorf("encrypt config field %s: %w", k, err)
		}
		encrypted[k] = enc
	}

	configJSON, err := json.Marshal(encrypted)
	if err != nil {
		return CreateResult{}, fmt.Errorf("marshal config: %w", err)
	}

	ch, err := s.DB.InsertChannel(ctx, db.InsertChannelParams{
		ID:         id.NewChannelID(),
		TeamID:     p.TeamID,
		Name:       p.Name,
		Provider:   p.Provider,
		Config:     configJSON,
		EventTypes: p.Events,
	})
	if err != nil {
		// 23505 = unique_violation: the (team, name) uniqueness constraint.
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == "23505" {
			return CreateResult{}, fmt.Errorf("conflict: channel name %q already exists", p.Name)
		}
		return CreateResult{}, fmt.Errorf("insert channel: %w", err)
	}

	return CreateResult{Channel: ch, PlaintextSecret: plaintextSecret}, nil
}
// List returns all channels belonging to the given team.
|
||||
func (s *Service) List(ctx context.Context, teamID pgtype.UUID) ([]db.Channel, error) {
|
||||
return s.DB.ListChannelsByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// Get returns a single channel by ID, scoped to the given team.
|
||||
func (s *Service) Get(ctx context.Context, channelID, teamID pgtype.UUID) (db.Channel, error) {
|
||||
return s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
|
||||
}
|
||||
|
||||
// Update updates a channel's name and event types.
|
||||
func (s *Service) Update(ctx context.Context, channelID, teamID pgtype.UUID, name string, eventTypes []string) (db.Channel, error) {
|
||||
clean, err := cleanName(name)
|
||||
if err != nil {
|
||||
return db.Channel{}, err
|
||||
}
|
||||
name = clean
|
||||
|
||||
if len(eventTypes) == 0 {
|
||||
return db.Channel{}, fmt.Errorf("invalid: at least one event type is required")
|
||||
}
|
||||
for _, et := range eventTypes {
|
||||
if !validEvents[et] {
|
||||
return db.Channel{}, fmt.Errorf("invalid: unknown event type %q", et)
|
||||
}
|
||||
}
|
||||
|
||||
ch, err := s.DB.UpdateChannel(ctx, db.UpdateChannelParams{
|
||||
ID: channelID,
|
||||
TeamID: teamID,
|
||||
Name: name,
|
||||
EventTypes: eventTypes,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Channel{}, fmt.Errorf("channel not found")
|
||||
}
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return db.Channel{}, fmt.Errorf("conflict: channel name %q already exists", name)
|
||||
}
|
||||
return db.Channel{}, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// RotateConfig replaces a channel's config with new provider secrets.
|
||||
func (s *Service) RotateConfig(ctx context.Context, channelID, teamID pgtype.UUID, config map[string]string) (db.Channel, error) {
|
||||
// Look up the existing channel to get its provider for validation.
|
||||
ch, err := s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Channel{}, fmt.Errorf("channel not found")
|
||||
}
|
||||
return db.Channel{}, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
// Validate required config fields for this provider.
|
||||
for _, field := range requiredFields[ch.Provider] {
|
||||
if config[field] == "" {
|
||||
return db.Channel{}, fmt.Errorf("invalid: %s is required for %s", field, ch.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// For webhooks, auto-generate secret if not provided.
|
||||
if ch.Provider == "webhook" && config["secret"] == "" {
|
||||
config["secret"] = generateSecret()
|
||||
}
|
||||
|
||||
// Encrypt all config fields.
|
||||
encrypted := make(map[string]string, len(config))
|
||||
for k, v := range config {
|
||||
enc, err := EncryptSecret(s.EncKey, v)
|
||||
if err != nil {
|
||||
return db.Channel{}, fmt.Errorf("encrypt config field %s: %w", k, err)
|
||||
}
|
||||
encrypted[k] = enc
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(encrypted)
|
||||
if err != nil {
|
||||
return db.Channel{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
updated, err := s.DB.UpdateChannelConfig(ctx, db.UpdateChannelConfigParams{
|
||||
ID: channelID,
|
||||
TeamID: teamID,
|
||||
Config: configJSON,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Channel{}, fmt.Errorf("channel not found")
|
||||
}
|
||||
return db.Channel{}, fmt.Errorf("update channel config: %w", err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// Test validates config and sends a test notification without persisting anything.
|
||||
func (s *Service) Test(ctx context.Context, provider string, config map[string]string) error {
|
||||
if !validProviders[provider] {
|
||||
return fmt.Errorf("invalid: unsupported provider %q", provider)
|
||||
}
|
||||
|
||||
for _, field := range requiredFields[provider] {
|
||||
if config[field] == "" {
|
||||
return fmt.Errorf("invalid: %s is required for %s", field, provider)
|
||||
}
|
||||
}
|
||||
|
||||
// For webhooks, auto-generate a temporary secret if not provided.
|
||||
if provider == "webhook" && config["secret"] == "" {
|
||||
config["secret"] = generateSecret()
|
||||
}
|
||||
|
||||
testEvent := events.Event{
|
||||
Event: "channel.test",
|
||||
Timestamp: events.Now(),
|
||||
TeamID: "test",
|
||||
Actor: events.Actor{Type: events.ActorSystem},
|
||||
Resource: events.Resource{ID: "test", Type: "channel"},
|
||||
}
|
||||
|
||||
return Deliver(ctx, provider, config, testEvent)
|
||||
}
|
||||
|
||||
// Delete removes a channel by ID, scoped to the given team.
|
||||
func (s *Service) Delete(ctx context.Context, channelID, teamID pgtype.UUID) error {
|
||||
return s.DB.DeleteChannelByTeam(ctx, db.DeleteChannelByTeamParams{ID: channelID, TeamID: teamID})
|
||||
}
|
||||
|
||||
// cleanName normalises a channel name: trim whitespace, lowercase, replace
|
||||
// spaces with hyphens, then validate against SafeName rules.
|
||||
func cleanName(name string) (string, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
name = strings.ToLower(name)
|
||||
name = strings.ReplaceAll(name, " ", "-")
|
||||
if err := validate.SafeName(name); err != nil {
|
||||
return "", fmt.Errorf("invalid: %w", err)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func generateSecret() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
@ -1,119 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ShoutrrrURL builds a shoutrrr-compatible URL from structured provider config.
|
||||
func ShoutrrrURL(provider string, config map[string]string) (string, error) {
|
||||
switch provider {
|
||||
case "discord":
|
||||
return discordURL(config)
|
||||
case "slack":
|
||||
return slackURL(config)
|
||||
case "teams":
|
||||
return teamsURL(config)
|
||||
case "googlechat":
|
||||
return googlechatURL(config)
|
||||
case "telegram":
|
||||
return telegramURL(config)
|
||||
case "matrix":
|
||||
return matrixURL(config)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported shoutrrr provider: %s", provider)
|
||||
}
|
||||
}
|
||||
|
||||
// discordURL rewrites https://discord.com/api/webhooks/{id}/{token} into
// shoutrrr's discord://{token}@{id} form, with splitLines=No appended.
func discordURL(config map[string]string) (string, error) {
	parsed, err := url.Parse(config["webhook_url"])
	if err != nil {
		return "", fmt.Errorf("invalid discord webhook URL: %w", err)
	}
	// Expected path shape: /api/webhooks/{id}/{token}.
	segments := strings.Split(strings.TrimPrefix(parsed.Path, "/"), "/")
	if len(segments) < 4 || segments[0] != "api" || segments[1] != "webhooks" {
		return "", fmt.Errorf("unexpected discord webhook URL format")
	}
	return fmt.Sprintf("discord://%s@%s?splitLines=No", segments[3], segments[2]), nil
}
|
||||
|
||||
// slackURL rewrites https://hooks.slack.com/services/T…/B…/X… into
// shoutrrr's slack://hook:T…-B…-X…@webhook form.
func slackURL(config map[string]string) (string, error) {
	parsed, err := url.Parse(config["webhook_url"])
	if err != nil {
		return "", fmt.Errorf("invalid slack webhook URL: %w", err)
	}
	// Expected path shape: /services/TXXXXX/BXXXXX/XXXXXXXX.
	segments := strings.Split(strings.TrimPrefix(parsed.Path, "/"), "/")
	if len(segments) < 4 || segments[0] != "services" {
		return "", fmt.Errorf("unexpected slack webhook URL format")
	}
	token := segments[1] + "-" + segments[2] + "-" + segments[3]
	return "slack://hook:" + token + "@webhook", nil
}
|
||||
|
||||
// teamsWebhookRe pulls the four identifiers out of a Teams incoming
// webhook URL: {group}@{tenant}/<segment>/{altID}/{groupOwner}.
var teamsWebhookRe = regexp.MustCompile(`([0-9a-f-]{36})@([0-9a-f-]{36})/[^/]+/([0-9a-f]{32})/([0-9a-f-]{36})`)

// teamsURL rewrites a Teams webhook URL into shoutrrr's
// teams://Group@Tenant/AltID/GroupOwner form.
func teamsURL(config map[string]string) (string, error) {
	webhookURL, ok := config["webhook_url"]
	if !ok || webhookURL == "" {
		return "", fmt.Errorf("teams webhook_url is required")
	}
	m := teamsWebhookRe.FindStringSubmatch(webhookURL)
	if len(m) != 5 {
		return "", fmt.Errorf("unexpected teams webhook URL format")
	}
	// m[1..4] = group, tenant, altID, groupOwner.
	return fmt.Sprintf("teams://%s@%s/%s/%s", m[1], m[2], m[3], m[4]), nil
}
|
||||
|
||||
// googlechatURL rewrites a Google Chat webhook URL to the googlechat://
// scheme, keeping host, path, and query intact. Only the official
// chat.googleapis.com host is accepted.
func googlechatURL(config map[string]string) (string, error) {
	raw, ok := config["webhook_url"]
	if !ok || raw == "" {
		return "", fmt.Errorf("googlechat webhook_url is required")
	}
	parsed, err := url.Parse(raw)
	switch {
	case err != nil:
		return "", fmt.Errorf("invalid googlechat webhook URL: %w", err)
	case parsed.Host != "chat.googleapis.com":
		return "", fmt.Errorf("unexpected googlechat webhook URL host: %s", parsed.Host)
	}
	// Swap only the scheme; everything else is preserved verbatim.
	parsed.Scheme = "googlechat"
	return parsed.String(), nil
}
|
||||
|
||||
// telegramURL builds shoutrrr's telegram://{token}@telegram/?chats={chatID}.
func telegramURL(config map[string]string) (string, error) {
	token, chatID := config["bot_token"], config["chat_id"]
	if token == "" || chatID == "" {
		return "", fmt.Errorf("telegram bot_token and chat_id are required")
	}
	return "telegram://" + token + "@telegram/?chats=" + chatID, nil
}
|
||||
|
||||
// matrixURL builds shoutrrr's matrix://:{token}@{homeserver}/{room} form.
// The scheme is stripped from the homeserver URL; the access token and
// room ID (which often starts with '!') are path-escaped.
func matrixURL(config map[string]string) (string, error) {
	homeserver := config["homeserver_url"]
	token := config["access_token"]
	roomID := config["room_id"]
	if homeserver == "" || token == "" || roomID == "" {
		return "", fmt.Errorf("matrix homeserver_url, access_token, and room_id are required")
	}
	// Drop the protocol prefix from the homeserver URL, if any.
	host := strings.TrimPrefix(homeserver, "https://")
	host = strings.TrimPrefix(host, "http://")
	return fmt.Sprintf("matrix://:%s@%s/%s", url.PathEscape(token), host, url.PathEscape(roomID)), nil
}
|
||||
@ -1,62 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// WebhookDelivery delivers events to webhook URLs with HMAC signing.
|
||||
type WebhookDelivery struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewWebhookDelivery constructs a webhook delivery client.
|
||||
func NewWebhookDelivery() *WebhookDelivery {
|
||||
return &WebhookDelivery{
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
CheckRedirect: func(*http.Request, []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver signs and POSTs the event payload to the configured URL.
|
||||
func (d *WebhookDelivery) Deliver(ctx context.Context, targetURL, secret string, payload []byte) error {
|
||||
timestamp := time.Now().UTC().Format(time.RFC3339)
|
||||
deliveryID := uuid.New().String()
|
||||
|
||||
// Compute HMAC-SHA256: sign over "timestamp.body".
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write([]byte(timestamp + "." + string(payload)))
|
||||
signature := "sha256=" + hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, strings.NewReader(string(payload)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-WRENN-SIGNATURE", signature)
|
||||
req.Header.Set("X-Wrenn-Delivery", deliveryID)
|
||||
req.Header.Set("X-Wrenn-Timestamp", timestamp)
|
||||
|
||||
resp, err := d.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http post: %w", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("webhook returned %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -1,70 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"os"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
// Config holds the control plane configuration. Values are populated by
// Load from environment variables, with an optional .env fallback.
type Config struct {
	DatabaseURL string // DATABASE_URL — Postgres DSN
	RedisURL    string // REDIS_URL
	ListenAddr  string // WRENN_CP_LISTEN_ADDR — HTTP listen address
	JWTSecret   string // JWT_SECRET

	// mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either
	// disables cert issuance and leaves agent connections on plain HTTP (dev mode).
	CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate
	CAKey  string // WRENN_CA_KEY — PEM-encoded internal CA private key

	OAuthGitHubClientID     string // OAUTH_GITHUB_CLIENT_ID
	OAuthGitHubClientSecret string // OAUTH_GITHUB_CLIENT_SECRET
	OAuthRedirectURL        string // OAUTH_REDIRECT_URL
	CPPublicURL             string // CP_PUBLIC_URL

	// Channels — encryption for channel secrets (AES-256-GCM).
	EncryptionKeyHex string   // WRENN_ENCRYPTION_KEY raw hex string (for validation)
	EncryptionKey    [32]byte // parsed 32-byte key; left zero when the hex is absent or malformed
}
|
||||
|
||||
// Load reads configuration from a .env file (if present) and environment variables.
|
||||
// Real environment variables take precedence over .env values.
|
||||
func Load() Config {
|
||||
// Best-effort load — missing .env file is fine.
|
||||
_ = godotenv.Load()
|
||||
|
||||
cfg := Config{
|
||||
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
|
||||
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
|
||||
ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
|
||||
JWTSecret: os.Getenv("JWT_SECRET"),
|
||||
|
||||
CACert: os.Getenv("WRENN_CA_CERT"),
|
||||
CAKey: os.Getenv("WRENN_CA_KEY"),
|
||||
|
||||
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
|
||||
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
|
||||
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
|
||||
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
|
||||
|
||||
EncryptionKeyHex: os.Getenv("WRENN_ENCRYPTION_KEY"),
|
||||
}
|
||||
|
||||
if cfg.EncryptionKeyHex != "" {
|
||||
b, err := hex.DecodeString(cfg.EncryptionKeyHex)
|
||||
if err == nil && len(b) == 32 {
|
||||
copy(cfg.EncryptionKey[:], b)
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func envOrDefault(key, def string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
@ -1,177 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: api_keys.sql

// NOTE(review): generated file — change api_keys.sql and regenerate with
// sqlc instead of hand-editing.

package db

import (
	"context"

	"github.com/jackc/pgx/v5/pgtype"
)

const deleteAPIKey = `-- name: DeleteAPIKey :exec
DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2
`

type DeleteAPIKeyParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
}

func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error {
	_, err := q.db.Exec(ctx, deleteAPIKey, arg.ID, arg.TeamID)
	return err
}

const getAPIKeyByHash = `-- name: GetAPIKeyByHash :one
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE key_hash = $1
`

func (q *Queries) GetAPIKeyByHash(ctx context.Context, keyHash string) (TeamApiKey, error) {
	row := q.db.QueryRow(ctx, getAPIKeyByHash, keyHash)
	var i TeamApiKey
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.KeyHash,
		&i.KeyPrefix,
		&i.CreatedBy,
		&i.CreatedAt,
		&i.LastUsed,
	)
	return i, err
}

const insertAPIKey = `-- name: InsertAPIKey :one
INSERT INTO team_api_keys (id, team_id, name, key_hash, key_prefix, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used
`

type InsertAPIKeyParams struct {
	ID        pgtype.UUID `json:"id"`
	TeamID    pgtype.UUID `json:"team_id"`
	Name      string      `json:"name"`
	KeyHash   string      `json:"key_hash"`
	KeyPrefix string      `json:"key_prefix"`
	CreatedBy pgtype.UUID `json:"created_by"`
}

func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) {
	row := q.db.QueryRow(ctx, insertAPIKey,
		arg.ID,
		arg.TeamID,
		arg.Name,
		arg.KeyHash,
		arg.KeyPrefix,
		arg.CreatedBy,
	)
	var i TeamApiKey
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.KeyHash,
		&i.KeyPrefix,
		&i.CreatedBy,
		&i.CreatedAt,
		&i.LastUsed,
	)
	return i, err
}

const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC
`

func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID pgtype.UUID) ([]TeamApiKey, error) {
	rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []TeamApiKey
	for rows.Next() {
		var i TeamApiKey
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.Name,
			&i.KeyHash,
			&i.KeyPrefix,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.LastUsed,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listAPIKeysByTeamWithCreator = `-- name: ListAPIKeysByTeamWithCreator :many
SELECT k.id, k.team_id, k.name, k.key_hash, k.key_prefix, k.created_by, k.created_at, k.last_used,
       u.email AS creator_email
FROM team_api_keys k
JOIN users u ON u.id = k.created_by
WHERE k.team_id = $1
ORDER BY k.created_at DESC
`

type ListAPIKeysByTeamWithCreatorRow struct {
	ID           pgtype.UUID        `json:"id"`
	TeamID       pgtype.UUID        `json:"team_id"`
	Name         string             `json:"name"`
	KeyHash      string             `json:"key_hash"`
	KeyPrefix    string             `json:"key_prefix"`
	CreatedBy    pgtype.UUID        `json:"created_by"`
	CreatedAt    pgtype.Timestamptz `json:"created_at"`
	LastUsed     pgtype.Timestamptz `json:"last_used"`
	CreatorEmail string             `json:"creator_email"`
}

func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID pgtype.UUID) ([]ListAPIKeysByTeamWithCreatorRow, error) {
	rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []ListAPIKeysByTeamWithCreatorRow
	for rows.Next() {
		var i ListAPIKeysByTeamWithCreatorRow
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.Name,
			&i.KeyHash,
			&i.KeyPrefix,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.LastUsed,
			&i.CreatorEmail,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec
UPDATE team_api_keys SET last_used = NOW() WHERE id = $1
`

func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id pgtype.UUID) error {
	_, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id)
	return err
}
|
||||
@ -1,111 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: audit.sql

// NOTE(review): generated file — change audit.sql and regenerate with
// sqlc instead of hand-editing.

package db

import (
	"context"

	"github.com/jackc/pgx/v5/pgtype"
)

const insertAuditLog = `-- name: InsertAuditLog :exec
INSERT INTO audit_logs (id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`

type InsertAuditLogParams struct {
	ID           pgtype.UUID `json:"id"`
	TeamID       pgtype.UUID `json:"team_id"`
	ActorType    string      `json:"actor_type"`
	ActorID      pgtype.Text `json:"actor_id"`
	ActorName    string      `json:"actor_name"`
	ResourceType string      `json:"resource_type"`
	ResourceID   pgtype.Text `json:"resource_id"`
	Action       string      `json:"action"`
	Scope        string      `json:"scope"`
	Status       string      `json:"status"`
	Metadata     []byte      `json:"metadata"`
}

func (q *Queries) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) error {
	_, err := q.db.Exec(ctx, insertAuditLog,
		arg.ID,
		arg.TeamID,
		arg.ActorType,
		arg.ActorID,
		arg.ActorName,
		arg.ResourceType,
		arg.ResourceID,
		arg.Action,
		arg.Scope,
		arg.Status,
		arg.Metadata,
	)
	return err
}

const listAuditLogs = `-- name: ListAuditLogs :many
SELECT id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata, created_at FROM audit_logs
WHERE team_id = $1
  AND scope = ANY($2::text[])
  AND (cardinality($3::text[]) = 0 OR resource_type = ANY($3::text[]))
  AND (cardinality($4::text[]) = 0 OR action = ANY($4::text[]))
  AND ($5::timestamptz IS NULL OR created_at < $5
       OR (created_at = $5 AND id < $6))
ORDER BY created_at DESC, id DESC
LIMIT $7
`

type ListAuditLogsParams struct {
	TeamID  pgtype.UUID        `json:"team_id"`
	Column2 []string           `json:"column_2"`
	Column3 []string           `json:"column_3"`
	Column4 []string           `json:"column_4"`
	Column5 pgtype.Timestamptz `json:"column_5"`
	ID      pgtype.UUID        `json:"id"`
	Limit   int32              `json:"limit"`
}

func (q *Queries) ListAuditLogs(ctx context.Context, arg ListAuditLogsParams) ([]AuditLog, error) {
	rows, err := q.db.Query(ctx, listAuditLogs,
		arg.TeamID,
		arg.Column2,
		arg.Column3,
		arg.Column4,
		arg.Column5,
		arg.ID,
		arg.Limit,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []AuditLog
	for rows.Next() {
		var i AuditLog
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.ActorType,
			&i.ActorID,
			&i.ActorName,
			&i.ResourceType,
			&i.ResourceID,
			&i.Action,
			&i.Scope,
			&i.Status,
			&i.Metadata,
			&i.CreatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
@ -1,225 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: channels.sql

// NOTE(review): generated file — change channels.sql and regenerate with
// sqlc instead of hand-editing.

package db

import (
	"context"

	"github.com/jackc/pgx/v5/pgtype"
)

const deleteChannelByTeam = `-- name: DeleteChannelByTeam :exec
DELETE FROM channels WHERE id = $1 AND team_id = $2
`

type DeleteChannelByTeamParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
}

func (q *Queries) DeleteChannelByTeam(ctx context.Context, arg DeleteChannelByTeamParams) error {
	_, err := q.db.Exec(ctx, deleteChannelByTeam, arg.ID, arg.TeamID)
	return err
}

const getChannelByTeam = `-- name: GetChannelByTeam :one
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE id = $1 AND team_id = $2
`

type GetChannelByTeamParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
}

func (q *Queries) GetChannelByTeam(ctx context.Context, arg GetChannelByTeamParams) (Channel, error) {
	row := q.db.QueryRow(ctx, getChannelByTeam, arg.ID, arg.TeamID)
	var i Channel
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.Provider,
		&i.Config,
		&i.EventTypes,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}

const insertChannel = `-- name: InsertChannel :one
INSERT INTO channels (id, team_id, name, provider, config, event_types)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`

type InsertChannelParams struct {
	ID         pgtype.UUID `json:"id"`
	TeamID     pgtype.UUID `json:"team_id"`
	Name       string      `json:"name"`
	Provider   string      `json:"provider"`
	Config     []byte      `json:"config"`
	EventTypes []string    `json:"event_types"`
}

func (q *Queries) InsertChannel(ctx context.Context, arg InsertChannelParams) (Channel, error) {
	row := q.db.QueryRow(ctx, insertChannel,
		arg.ID,
		arg.TeamID,
		arg.Name,
		arg.Provider,
		arg.Config,
		arg.EventTypes,
	)
	var i Channel
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.Provider,
		&i.Config,
		&i.EventTypes,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}

const listChannelsByTeam = `-- name: ListChannelsByTeam :many
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE team_id = $1 ORDER BY created_at DESC
`

func (q *Queries) ListChannelsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Channel, error) {
	rows, err := q.db.Query(ctx, listChannelsByTeam, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Channel
	for rows.Next() {
		var i Channel
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.Name,
			&i.Provider,
			&i.Config,
			&i.EventTypes,
			&i.CreatedAt,
			&i.UpdatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listChannelsForEvent = `-- name: ListChannelsForEvent :many
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels
WHERE team_id = $1
  AND $2::text = ANY(event_types)
ORDER BY created_at
`

type ListChannelsForEventParams struct {
	TeamID    pgtype.UUID `json:"team_id"`
	EventType string      `json:"event_type"`
}

func (q *Queries) ListChannelsForEvent(ctx context.Context, arg ListChannelsForEventParams) ([]Channel, error) {
	rows, err := q.db.Query(ctx, listChannelsForEvent, arg.TeamID, arg.EventType)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Channel
	for rows.Next() {
		var i Channel
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.Name,
			&i.Provider,
			&i.Config,
			&i.EventTypes,
			&i.CreatedAt,
			&i.UpdatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const updateChannel = `-- name: UpdateChannel :one
UPDATE channels SET name = $3, event_types = $4, updated_at = NOW()
WHERE id = $1 AND team_id = $2
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`

type UpdateChannelParams struct {
	ID         pgtype.UUID `json:"id"`
	TeamID     pgtype.UUID `json:"team_id"`
	Name       string      `json:"name"`
	EventTypes []string    `json:"event_types"`
}

func (q *Queries) UpdateChannel(ctx context.Context, arg UpdateChannelParams) (Channel, error) {
	row := q.db.QueryRow(ctx, updateChannel,
		arg.ID,
		arg.TeamID,
		arg.Name,
		arg.EventTypes,
	)
	var i Channel
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.Provider,
		&i.Config,
		&i.EventTypes,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}

const updateChannelConfig = `-- name: UpdateChannelConfig :one
UPDATE channels SET config = $3, updated_at = NOW()
WHERE id = $1 AND team_id = $2
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`

type UpdateChannelConfigParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
	Config []byte      `json:"config"`
}

func (q *Queries) UpdateChannelConfig(ctx context.Context, arg UpdateChannelConfigParams) (Channel, error) {
	row := q.db.QueryRow(ctx, updateChannelConfig, arg.ID, arg.TeamID, arg.Config)
	var i Channel
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.Provider,
		&i.Config,
		&i.EventTypes,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}
|
||||
@ -1,32 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0

// NOTE(review): generated file — regenerate with sqlc instead of
// hand-editing.

package db

import (
	"context"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)

type DBTX interface {
	Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
	Query(context.Context, string, ...interface{}) (pgx.Rows, error)
	QueryRow(context.Context, string, ...interface{}) pgx.Row
}

func New(db DBTX) *Queries {
	return &Queries{db: db}
}

type Queries struct {
	db DBTX
}

func (q *Queries) WithTx(tx pgx.Tx) *Queries {
	return &Queries{
		db: tx,
	}
}
|
||||
@ -1,92 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: host_refresh_tokens.sql

// NOTE(review): generated file — change host_refresh_tokens.sql and
// regenerate with sqlc instead of hand-editing.

package db

import (
	"context"

	"github.com/jackc/pgx/v5/pgtype"
)

const deleteExpiredHostRefreshTokens = `-- name: DeleteExpiredHostRefreshTokens :exec
DELETE FROM host_refresh_tokens
WHERE expires_at < NOW() OR revoked_at IS NOT NULL
`

func (q *Queries) DeleteExpiredHostRefreshTokens(ctx context.Context) error {
	_, err := q.db.Exec(ctx, deleteExpiredHostRefreshTokens)
	return err
}

const getHostRefreshTokenByHash = `-- name: GetHostRefreshTokenByHash :one
SELECT id, host_id, token_hash, expires_at, created_at, revoked_at FROM host_refresh_tokens
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW()
`

func (q *Queries) GetHostRefreshTokenByHash(ctx context.Context, tokenHash string) (HostRefreshToken, error) {
	row := q.db.QueryRow(ctx, getHostRefreshTokenByHash, tokenHash)
	var i HostRefreshToken
	err := row.Scan(
		&i.ID,
		&i.HostID,
		&i.TokenHash,
		&i.ExpiresAt,
		&i.CreatedAt,
		&i.RevokedAt,
	)
	return i, err
}

const insertHostRefreshToken = `-- name: InsertHostRefreshToken :one
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
`

type InsertHostRefreshTokenParams struct {
	ID        pgtype.UUID        `json:"id"`
	HostID    pgtype.UUID        `json:"host_id"`
	TokenHash string             `json:"token_hash"`
	ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}

func (q *Queries) InsertHostRefreshToken(ctx context.Context, arg InsertHostRefreshTokenParams) (HostRefreshToken, error) {
	row := q.db.QueryRow(ctx, insertHostRefreshToken,
		arg.ID,
		arg.HostID,
		arg.TokenHash,
		arg.ExpiresAt,
	)
	var i HostRefreshToken
	err := row.Scan(
		&i.ID,
		&i.HostID,
		&i.TokenHash,
		&i.ExpiresAt,
		&i.CreatedAt,
		&i.RevokedAt,
	)
	return i, err
}

const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
`

func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id pgtype.UUID) error {
	_, err := q.db.Exec(ctx, revokeHostRefreshToken, id)
	return err
}
|
||||
|
||||
const revokeHostRefreshTokensByHost = `-- name: RevokeHostRefreshTokensByHost :exec
|
||||
UPDATE host_refresh_tokens SET revoked_at = NOW()
|
||||
WHERE host_id = $1 AND revoked_at IS NULL
|
||||
`
|
||||
|
||||
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID)
|
||||
return err
|
||||
}
|
||||
@ -1,632 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: hosts.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const addHostTag = `-- name: AddHostTag :exec
|
||||
INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING
|
||||
`
|
||||
|
||||
type AddHostTagParams struct {
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error {
|
||||
_, err := q.db.Exec(ctx, addHostTag, arg.HostID, arg.Tag)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteHost = `-- name: DeleteHost :exec
|
||||
DELETE FROM hosts WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteHost, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const getHost = `-- name: GetHost :one
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
|
||||
row := q.db.QueryRow(ctx, getHost, id)
|
||||
var i Host
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getHostByTeam = `-- name: GetHostByTeam :one
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetHostByTeamParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
|
||||
row := q.db.QueryRow(ctx, getHostByTeam, arg.ID, arg.TeamID)
|
||||
var i Host
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getHostTags = `-- name: GetHostTags :many
|
||||
SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag
|
||||
`
|
||||
|
||||
func (q *Queries) GetHostTags(ctx context.Context, hostID pgtype.UUID) ([]string, error) {
|
||||
rows, err := q.db.Query(ctx, getHostTags, hostID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []string
|
||||
for rows.Next() {
|
||||
var tag string
|
||||
if err := rows.Scan(&tag); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, tag)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getHostTokensByHost = `-- name: GetHostTokensByHost :many
|
||||
SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ([]HostToken, error) {
|
||||
rows, err := q.db.Query(ctx, getHostTokensByHost, hostID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []HostToken
|
||||
for rows.Next() {
|
||||
var i HostToken
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.ExpiresAt,
|
||||
&i.UsedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertHost = `-- name: InsertHost :one
|
||||
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at
|
||||
`
|
||||
|
||||
type InsertHostParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Provider string `json:"provider"`
|
||||
AvailabilityZone string `json:"availability_zone"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) {
|
||||
row := q.db.QueryRow(ctx, insertHost,
|
||||
arg.ID,
|
||||
arg.Type,
|
||||
arg.TeamID,
|
||||
arg.Provider,
|
||||
arg.AvailabilityZone,
|
||||
arg.CreatedBy,
|
||||
)
|
||||
var i Host
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertHostToken = `-- name: InsertHostToken :one
|
||||
INSERT INTO host_tokens (id, host_id, created_by, expires_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, host_id, created_by, created_at, expires_at, used_at
|
||||
`
|
||||
|
||||
type InsertHostTokenParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams) (HostToken, error) {
|
||||
row := q.db.QueryRow(ctx, insertHostToken,
|
||||
arg.ID,
|
||||
arg.HostID,
|
||||
arg.CreatedBy,
|
||||
arg.ExpiresAt,
|
||||
)
|
||||
var i HostToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.ExpiresAt,
|
||||
&i.UsedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listActiveHosts = `-- name: ListActiveHosts :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
|
||||
`
|
||||
|
||||
// Returns all hosts that have completed registration (not pending/offline).
|
||||
func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
|
||||
rows, err := q.db.Query(ctx, listActiveHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Host
|
||||
for rows.Next() {
|
||||
var i Host
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listHosts = `-- name: ListHosts :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
rows, err := q.db.Query(ctx, listHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Host
|
||||
for rows.Next() {
|
||||
var i Host
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listHostsByStatus = `-- name: ListHostsByStatus :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
|
||||
rows, err := q.db.Query(ctx, listHostsByStatus, status)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Host
|
||||
for rows.Next() {
|
||||
var i Host
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listHostsByTag = `-- name: ListHostsByTag :many
|
||||
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h
|
||||
JOIN host_tags ht ON ht.host_id = h.id
|
||||
WHERE ht.tag = $1
|
||||
ORDER BY h.created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error) {
|
||||
rows, err := q.db.Query(ctx, listHostsByTag, tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Host
|
||||
for rows.Next() {
|
||||
var i Host
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listHostsByTeam = `-- name: ListHostsByTeam :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
|
||||
rows, err := q.db.Query(ctx, listHostsByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Host
|
||||
for rows.Next() {
|
||||
var i Host
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listHostsByType = `-- name: ListHostsByType :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
|
||||
rows, err := q.db.Query(ctx, listHostsByType, type_)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Host
|
||||
for rows.Next() {
|
||||
var i Host
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec
|
||||
UPDATE host_tokens SET used_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, markHostTokenUsed, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const markHostUnreachable = `-- name: MarkHostUnreachable :exec
|
||||
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, markHostUnreachable, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const registerHost = `-- name: RegisterHost :execrows
|
||||
UPDATE hosts
|
||||
SET arch = $2,
|
||||
cpu_cores = $3,
|
||||
memory_mb = $4,
|
||||
disk_gb = $5,
|
||||
address = $6,
|
||||
cert_fingerprint = $7,
|
||||
cert_expires_at = $8,
|
||||
status = 'online',
|
||||
last_heartbeat_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND status = 'pending'
|
||||
`
|
||||
|
||||
type RegisterHostParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Arch string `json:"arch"`
|
||||
CpuCores int32 `json:"cpu_cores"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
DiskGb int32 `json:"disk_gb"`
|
||||
Address string `json:"address"`
|
||||
CertFingerprint string `json:"cert_fingerprint"`
|
||||
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
|
||||
result, err := q.db.Exec(ctx, registerHost,
|
||||
arg.ID,
|
||||
arg.Arch,
|
||||
arg.CpuCores,
|
||||
arg.MemoryMb,
|
||||
arg.DiskGb,
|
||||
arg.Address,
|
||||
arg.CertFingerprint,
|
||||
arg.CertExpiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
|
||||
const removeHostTag = `-- name: RemoveHostTag :exec
|
||||
DELETE FROM host_tags WHERE host_id = $1 AND tag = $2
|
||||
`
|
||||
|
||||
type RemoveHostTagParams struct {
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error {
|
||||
_, err := q.db.Exec(ctx, removeHostTag, arg.HostID, arg.Tag)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateHostCert = `-- name: UpdateHostCert :exec
|
||||
UPDATE hosts
|
||||
SET cert_fingerprint = $2,
|
||||
cert_expires_at = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateHostCertParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
CertFingerprint string `json:"cert_fingerprint"`
|
||||
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error {
|
||||
_, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
|
||||
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, updateHostHeartbeat, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :execrows
|
||||
UPDATE hosts
|
||||
SET last_heartbeat_at = NOW(),
|
||||
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
// Updates last_heartbeat_at and transitions unreachable hosts back to online.
|
||||
// Returns 0 if no host was found (deleted), which the caller treats as 404.
|
||||
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id pgtype.UUID) (int64, error) {
|
||||
result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
|
||||
const updateHostStatus = `-- name: UpdateHostStatus :exec
|
||||
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateHostStatusParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error {
|
||||
_, err := q.db.Exec(ctx, updateHostStatus, arg.ID, arg.Status)
|
||||
return err
|
||||
}
|
||||
@ -1,250 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: metrics.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteSandboxMetricPoints = `-- name: DeleteSandboxMetricPoints :exec
|
||||
DELETE FROM sandbox_metric_points
|
||||
WHERE sandbox_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteSandboxMetricPoints, sandboxID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteSandboxMetricPointsByTier = `-- name: DeleteSandboxMetricPointsByTier :exec
|
||||
DELETE FROM sandbox_metric_points
|
||||
WHERE sandbox_id = $1 AND tier = $2
|
||||
`
|
||||
|
||||
type DeleteSandboxMetricPointsByTierParams struct {
|
||||
SandboxID pgtype.UUID `json:"sandbox_id"`
|
||||
Tier string `json:"tier"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteSandboxMetricPointsByTier(ctx context.Context, arg DeleteSandboxMetricPointsByTierParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteSandboxMetricPointsByTier, arg.SandboxID, arg.Tier)
|
||||
return err
|
||||
}
|
||||
|
||||
const getLiveMetrics = `-- name: GetLiveMetrics :one
|
||||
SELECT
|
||||
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
|
||||
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
|
||||
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
|
||||
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
|
||||
FROM sandboxes
|
||||
WHERE team_id = $1
|
||||
`
|
||||
|
||||
type GetLiveMetricsRow struct {
|
||||
RunningCount int32 `json:"running_count"`
|
||||
VcpusReserved int32 `json:"vcpus_reserved"`
|
||||
MemoryMbReserved int32 `json:"memory_mb_reserved"`
|
||||
}
|
||||
|
||||
// Reads directly from sandboxes for accurate real-time current values.
|
||||
// CPU reserved = running + starting only (paused VMs release CPU).
|
||||
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
|
||||
func (q *Queries) GetLiveMetrics(ctx context.Context, teamID pgtype.UUID) (GetLiveMetricsRow, error) {
|
||||
row := q.db.QueryRow(ctx, getLiveMetrics, teamID)
|
||||
var i GetLiveMetricsRow
|
||||
err := row.Scan(&i.RunningCount, &i.VcpusReserved, &i.MemoryMbReserved)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getPeakMetrics = `-- name: GetPeakMetrics :one
|
||||
SELECT
|
||||
COALESCE(MAX(running_count), 0)::INTEGER AS peak_running_count,
|
||||
COALESCE(MAX(vcpus_reserved), 0)::INTEGER AS peak_vcpus,
|
||||
COALESCE(MAX(memory_mb_reserved), 0)::INTEGER AS peak_memory_mb
|
||||
FROM sandbox_metrics_snapshots
|
||||
WHERE team_id = $1
|
||||
AND sampled_at > NOW() - INTERVAL '30 days'
|
||||
`
|
||||
|
||||
type GetPeakMetricsRow struct {
|
||||
PeakRunningCount int32 `json:"peak_running_count"`
|
||||
PeakVcpus int32 `json:"peak_vcpus"`
|
||||
PeakMemoryMb int32 `json:"peak_memory_mb"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetPeakMetrics(ctx context.Context, teamID pgtype.UUID) (GetPeakMetricsRow, error) {
|
||||
row := q.db.QueryRow(ctx, getPeakMetrics, teamID)
|
||||
var i GetPeakMetricsRow
|
||||
err := row.Scan(&i.PeakRunningCount, &i.PeakVcpus, &i.PeakMemoryMb)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSandboxMetricPoints = `-- name: GetSandboxMetricPoints :many
|
||||
SELECT ts, cpu_pct, mem_bytes, disk_bytes
|
||||
FROM sandbox_metric_points
|
||||
WHERE sandbox_id = $1 AND tier = $2 AND ts >= $3
|
||||
ORDER BY ts ASC
|
||||
`
|
||||
|
||||
type GetSandboxMetricPointsParams struct {
|
||||
SandboxID pgtype.UUID `json:"sandbox_id"`
|
||||
Tier string `json:"tier"`
|
||||
Ts int64 `json:"ts"`
|
||||
}
|
||||
|
||||
type GetSandboxMetricPointsRow struct {
|
||||
Ts int64 `json:"ts"`
|
||||
CpuPct float64 `json:"cpu_pct"`
|
||||
MemBytes int64 `json:"mem_bytes"`
|
||||
DiskBytes int64 `json:"disk_bytes"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSandboxMetricPoints(ctx context.Context, arg GetSandboxMetricPointsParams) ([]GetSandboxMetricPointsRow, error) {
|
||||
rows, err := q.db.Query(ctx, getSandboxMetricPoints, arg.SandboxID, arg.Tier, arg.Ts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetSandboxMetricPointsRow
|
||||
for rows.Next() {
|
||||
var i GetSandboxMetricPointsRow
|
||||
if err := rows.Scan(
|
||||
&i.Ts,
|
||||
&i.CpuPct,
|
||||
&i.MemBytes,
|
||||
&i.DiskBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertMetricsSnapshot = `-- name: InsertMetricsSnapshot :exec
|
||||
INSERT INTO sandbox_metrics_snapshots (team_id, running_count, vcpus_reserved, memory_mb_reserved)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
|
||||
type InsertMetricsSnapshotParams struct {
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
RunningCount int32 `json:"running_count"`
|
||||
VcpusReserved int32 `json:"vcpus_reserved"`
|
||||
MemoryMbReserved int32 `json:"memory_mb_reserved"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertMetricsSnapshot(ctx context.Context, arg InsertMetricsSnapshotParams) error {
|
||||
_, err := q.db.Exec(ctx, insertMetricsSnapshot,
|
||||
arg.TeamID,
|
||||
arg.RunningCount,
|
||||
arg.VcpusReserved,
|
||||
arg.MemoryMbReserved,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const insertSandboxMetricPoint = `-- name: InsertSandboxMetricPoint :exec
|
||||
INSERT INTO sandbox_metric_points (sandbox_id, tier, ts, cpu_pct, mem_bytes, disk_bytes)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (sandbox_id, tier, ts) DO NOTHING
|
||||
`
|
||||
|
||||
type InsertSandboxMetricPointParams struct {
|
||||
SandboxID pgtype.UUID `json:"sandbox_id"`
|
||||
Tier string `json:"tier"`
|
||||
Ts int64 `json:"ts"`
|
||||
CpuPct float64 `json:"cpu_pct"`
|
||||
MemBytes int64 `json:"mem_bytes"`
|
||||
DiskBytes int64 `json:"disk_bytes"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertSandboxMetricPoint(ctx context.Context, arg InsertSandboxMetricPointParams) error {
|
||||
_, err := q.db.Exec(ctx, insertSandboxMetricPoint,
|
||||
arg.SandboxID,
|
||||
arg.Tier,
|
||||
arg.Ts,
|
||||
arg.CpuPct,
|
||||
arg.MemBytes,
|
||||
arg.DiskBytes,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const pruneOldMetrics = `-- name: PruneOldMetrics :exec
|
||||
DELETE FROM sandbox_metrics_snapshots
|
||||
WHERE sampled_at < NOW() - INTERVAL '60 days'
|
||||
`
|
||||
|
||||
func (q *Queries) PruneOldMetrics(ctx context.Context) error {
|
||||
_, err := q.db.Exec(ctx, pruneOldMetrics)
|
||||
return err
|
||||
}
|
||||
|
||||
const pruneSandboxMetricPoints = `-- name: PruneSandboxMetricPoints :exec
|
||||
DELETE FROM sandbox_metric_points
|
||||
WHERE ts < EXTRACT(EPOCH FROM NOW() - INTERVAL '30 days')::BIGINT
|
||||
`
|
||||
|
||||
// Remove metric points older than 30 days for destroyed sandboxes.
|
||||
func (q *Queries) PruneSandboxMetricPoints(ctx context.Context) error {
|
||||
_, err := q.db.Exec(ctx, pruneSandboxMetricPoints)
|
||||
return err
|
||||
}
|
||||
|
||||
const sampleSandboxMetrics = `-- name: SampleSandboxMetrics :many
|
||||
SELECT
|
||||
team_id,
|
||||
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
|
||||
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
|
||||
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
|
||||
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
|
||||
FROM sandboxes
|
||||
GROUP BY team_id
|
||||
`
|
||||
|
||||
type SampleSandboxMetricsRow struct {
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
RunningCount int32 `json:"running_count"`
|
||||
VcpusReserved int32 `json:"vcpus_reserved"`
|
||||
MemoryMbReserved int32 `json:"memory_mb_reserved"`
|
||||
}
|
||||
|
||||
// Aggregates per-team resource usage from the live sandboxes table.
|
||||
// Groups by all teams that have any sandbox row (including stopped) so that
|
||||
// zero-value snapshots are recorded when all capsules are stopped, keeping the
|
||||
// time-series charts continuous rather than trailing off into empty space.
|
||||
// CPU reserved = running + starting only (paused VMs release CPU).
|
||||
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
|
||||
func (q *Queries) SampleSandboxMetrics(ctx context.Context) ([]SampleSandboxMetricsRow, error) {
|
||||
rows, err := q.db.Query(ctx, sampleSandboxMetrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []SampleSandboxMetricsRow
|
||||
for rows.Next() {
|
||||
var i SampleSandboxMetricsRow
|
||||
if err := rows.Scan(
|
||||
&i.TeamID,
|
||||
&i.RunningCount,
|
||||
&i.VcpusReserved,
|
||||
&i.MemoryMbReserved,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
@ -1,204 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
// AdminPermission is the generated model for one admin-permission grant row.
type AdminPermission struct {
	ID         pgtype.UUID        `json:"id"`
	UserID     pgtype.UUID        `json:"user_id"`
	Permission string             `json:"permission"`
	CreatedAt  pgtype.Timestamptz `json:"created_at"`
}

// AuditLog is the generated model for one audit-log entry.
type AuditLog struct {
	ID           pgtype.UUID        `json:"id"`
	TeamID       pgtype.UUID        `json:"team_id"`
	ActorType    string             `json:"actor_type"`
	ActorID      pgtype.Text        `json:"actor_id"`
	ActorName    string             `json:"actor_name"`
	ResourceType string             `json:"resource_type"`
	ResourceID   pgtype.Text        `json:"resource_id"`
	Action       string             `json:"action"`
	Scope        string             `json:"scope"`
	Status       string             `json:"status"`
	Metadata     []byte             `json:"metadata"` // raw JSON payload
	CreatedAt    pgtype.Timestamptz `json:"created_at"`
}

// Channel is the generated model for a notification channel configuration.
type Channel struct {
	ID         pgtype.UUID        `json:"id"`
	TeamID     pgtype.UUID        `json:"team_id"`
	Name       string             `json:"name"`
	Provider   string             `json:"provider"`
	Config     []byte             `json:"config"` // provider-specific JSON config
	EventTypes []string           `json:"event_types"`
	CreatedAt  pgtype.Timestamptz `json:"created_at"`
	UpdatedAt  pgtype.Timestamptz `json:"updated_at"`
}

// Host is the generated model for a sandbox host machine.
type Host struct {
	ID               pgtype.UUID        `json:"id"`
	Type             string             `json:"type"`
	TeamID           pgtype.UUID        `json:"team_id"`
	Provider         string             `json:"provider"`
	AvailabilityZone string             `json:"availability_zone"`
	Arch             string             `json:"arch"`
	CpuCores         int32              `json:"cpu_cores"`
	MemoryMb         int32              `json:"memory_mb"`
	DiskGb           int32              `json:"disk_gb"`
	Address          string             `json:"address"`
	Status           string             `json:"status"`
	LastHeartbeatAt  pgtype.Timestamptz `json:"last_heartbeat_at"`
	Metadata         []byte             `json:"metadata"`
	CreatedBy        pgtype.UUID        `json:"created_by"`
	CreatedAt        pgtype.Timestamptz `json:"created_at"`
	UpdatedAt        pgtype.Timestamptz `json:"updated_at"`
	CertFingerprint  string             `json:"cert_fingerprint"`
	CertExpiresAt    pgtype.Timestamptz `json:"cert_expires_at"`
}

// HostRefreshToken is the generated model for a host's refresh-token record
// (only the hash is stored, never the token itself).
type HostRefreshToken struct {
	ID        pgtype.UUID        `json:"id"`
	HostID    pgtype.UUID        `json:"host_id"`
	TokenHash string             `json:"token_hash"`
	ExpiresAt pgtype.Timestamptz `json:"expires_at"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	RevokedAt pgtype.Timestamptz `json:"revoked_at"`
}

// HostTag is the generated model for one (host, tag) association row.
type HostTag struct {
	HostID pgtype.UUID `json:"host_id"`
	Tag    string      `json:"tag"`
}

// HostToken is the generated model for a one-time host enrollment token.
type HostToken struct {
	ID        pgtype.UUID        `json:"id"`
	HostID    pgtype.UUID        `json:"host_id"`
	CreatedBy pgtype.UUID        `json:"created_by"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	ExpiresAt pgtype.Timestamptz `json:"expires_at"`
	UsedAt    pgtype.Timestamptz `json:"used_at"`
}

// OauthProvider is the generated model for a row of the oauth_providers table,
// linking an external (provider, provider_id) identity to a local user.
type OauthProvider struct {
	Provider   string             `json:"provider"`
	ProviderID string             `json:"provider_id"`
	UserID     pgtype.UUID        `json:"user_id"`
	Email      string             `json:"email"`
	CreatedAt  pgtype.Timestamptz `json:"created_at"`
}

// Sandbox is the generated model for a row of the sandboxes table.
type Sandbox struct {
	ID             pgtype.UUID        `json:"id"`
	TeamID         pgtype.UUID        `json:"team_id"`
	HostID         pgtype.UUID        `json:"host_id"`
	Template       string             `json:"template"`
	Status         string             `json:"status"`
	Vcpus          int32              `json:"vcpus"`
	MemoryMb       int32              `json:"memory_mb"`
	TimeoutSec     int32              `json:"timeout_sec"`
	DiskSizeMb     int32              `json:"disk_size_mb"`
	GuestIp        string             `json:"guest_ip"`
	HostIp         string             `json:"host_ip"`
	CreatedAt      pgtype.Timestamptz `json:"created_at"`
	StartedAt      pgtype.Timestamptz `json:"started_at"`
	LastActiveAt   pgtype.Timestamptz `json:"last_active_at"`
	LastUpdated    pgtype.Timestamptz `json:"last_updated"`
	TemplateID     pgtype.UUID        `json:"template_id"`
	TemplateTeamID pgtype.UUID        `json:"template_team_id"`
}

// SandboxMetricPoint is the generated model for one per-sandbox metric sample.
type SandboxMetricPoint struct {
	SandboxID pgtype.UUID `json:"sandbox_id"`
	Tier      string      `json:"tier"`
	Ts        int64       `json:"ts"` // NOTE(review): presumably a unix timestamp — confirm against writer
	CpuPct    float64     `json:"cpu_pct"`
	MemBytes  int64       `json:"mem_bytes"`
	DiskBytes int64       `json:"disk_bytes"`
}

// SandboxMetricsSnapshot is the generated model for one per-team aggregate
// metrics snapshot row.
type SandboxMetricsSnapshot struct {
	ID               int64              `json:"id"`
	TeamID           pgtype.UUID        `json:"team_id"`
	SampledAt        pgtype.Timestamptz `json:"sampled_at"`
	RunningCount     int32              `json:"running_count"`
	VcpusReserved    int32              `json:"vcpus_reserved"`
	MemoryMbReserved int32              `json:"memory_mb_reserved"`
}

// Team is the generated model for a row of the teams table; DeletedAt is a
// soft-delete marker (see SoftDeleteTeam).
type Team struct {
	ID        pgtype.UUID        `json:"id"`
	Name      string             `json:"name"`
	Slug      string             `json:"slug"`
	IsByoc    bool               `json:"is_byoc"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	DeletedAt pgtype.Timestamptz `json:"deleted_at"`
}

// TeamApiKey is the generated model for a team API key (hash + display prefix
// only; the full key is not stored).
type TeamApiKey struct {
	ID        pgtype.UUID        `json:"id"`
	TeamID    pgtype.UUID        `json:"team_id"`
	Name      string             `json:"name"`
	KeyHash   string             `json:"key_hash"`
	KeyPrefix string             `json:"key_prefix"`
	CreatedBy pgtype.UUID        `json:"created_by"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	LastUsed  pgtype.Timestamptz `json:"last_used"`
}

// Template is the generated model for a sandbox template.
type Template struct {
	Name      string             `json:"name"`
	Type      string             `json:"type"`
	Vcpus     int32              `json:"vcpus"`
	MemoryMb  int32              `json:"memory_mb"`
	SizeBytes int64              `json:"size_bytes"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	TeamID    pgtype.UUID        `json:"team_id"`
	ID        pgtype.UUID        `json:"id"`
}

// TemplateBuild is the generated model for a row of the template_builds table,
// tracking one template build job's recipe, progress, and outcome.
type TemplateBuild struct {
	ID           pgtype.UUID        `json:"id"`
	Name         string             `json:"name"`
	BaseTemplate string             `json:"base_template"`
	Recipe       []byte             `json:"recipe"`
	Healthcheck  string             `json:"healthcheck"`
	Vcpus        int32              `json:"vcpus"`
	MemoryMb     int32              `json:"memory_mb"`
	Status       string             `json:"status"`
	CurrentStep  int32              `json:"current_step"`
	TotalSteps   int32              `json:"total_steps"`
	Logs         []byte             `json:"logs"`
	Error        string             `json:"error"`
	SandboxID    pgtype.UUID        `json:"sandbox_id"`
	HostID       pgtype.UUID        `json:"host_id"`
	CreatedAt    pgtype.Timestamptz `json:"created_at"`
	StartedAt    pgtype.Timestamptz `json:"started_at"`
	CompletedAt  pgtype.Timestamptz `json:"completed_at"`
	TemplateID   pgtype.UUID        `json:"template_id"`
	TeamID       pgtype.UUID        `json:"team_id"`
	SkipPrePost  bool               `json:"skip_pre_post"`
}

// User is the generated model for a row of the users table. PasswordHash is
// nullable (pgtype.Text) — e.g. for accounts linked via oauth_providers.
type User struct {
	ID           pgtype.UUID        `json:"id"`
	Email        string             `json:"email"`
	PasswordHash pgtype.Text        `json:"password_hash"`
	Name         string             `json:"name"`
	IsAdmin      bool               `json:"is_admin"`
	CreatedAt    pgtype.Timestamptz `json:"created_at"`
	UpdatedAt    pgtype.Timestamptz `json:"updated_at"`
}

// UsersTeam is the generated model for a row of the users_teams join table
// (team membership with role and default-team flag).
type UsersTeam struct {
	UserID    pgtype.UUID        `json:"user_id"`
	TeamID    pgtype.UUID        `json:"team_id"`
	IsDefault bool               `json:"is_default"`
	Role      string             `json:"role"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
}
|
||||
@ -1,57 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: oauth.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const getOAuthProvider = `-- name: GetOAuthProvider :one
SELECT provider, provider_id, user_id, email, created_at FROM oauth_providers
WHERE provider = $1 AND provider_id = $2
`

// GetOAuthProviderParams holds the composite lookup key for GetOAuthProvider.
type GetOAuthProviderParams struct {
	Provider   string `json:"provider"`
	ProviderID string `json:"provider_id"`
}

// GetOAuthProvider fetches the oauth_providers row keyed by
// (provider, provider_id). Scan's error (e.g. pgx.ErrNoRows) is returned as-is.
func (q *Queries) GetOAuthProvider(ctx context.Context, arg GetOAuthProviderParams) (OauthProvider, error) {
	row := q.db.QueryRow(ctx, getOAuthProvider, arg.Provider, arg.ProviderID)
	var i OauthProvider
	err := row.Scan(
		&i.Provider,
		&i.ProviderID,
		&i.UserID,
		&i.Email,
		&i.CreatedAt,
	)
	return i, err
}
|
||||
|
||||
const insertOAuthProvider = `-- name: InsertOAuthProvider :exec
INSERT INTO oauth_providers (provider, provider_id, user_id, email)
VALUES ($1, $2, $3, $4)
`

// InsertOAuthProviderParams holds the column values for InsertOAuthProvider.
type InsertOAuthProviderParams struct {
	Provider   string      `json:"provider"`
	ProviderID string      `json:"provider_id"`
	UserID     pgtype.UUID `json:"user_id"`
	Email      string      `json:"email"`
}

// InsertOAuthProvider links an external OAuth identity to a local user by
// inserting a row into oauth_providers. created_at is left to its DB default.
func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error {
	_, err := q.db.Exec(ctx, insertOAuthProvider,
		arg.Provider,
		arg.ProviderID,
		arg.UserID,
		arg.Email,
	)
	return err
}
|
||||
@ -1,487 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: sandboxes.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
UPDATE sandboxes
SET status = 'running',
    last_updated = NOW()
WHERE id = ANY($1::uuid[]) AND status = 'missing'
`

// Called by the reconciler when a host comes back online and its sandboxes are
// confirmed alive. Restores only sandboxes that are in 'missing' state.
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error {
	_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
	return err
}

const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
UPDATE sandboxes
SET status = $2,
    last_updated = NOW()
WHERE id = ANY($1::uuid[])
`

// BulkUpdateStatusByIDsParams holds the ID set and target status for
// BulkUpdateStatusByIDs. Column1 is the sqlc-generated name for the
// unnamed $1::uuid[] parameter.
type BulkUpdateStatusByIDsParams struct {
	Column1 []pgtype.UUID `json:"column_1"`
	Status  string        `json:"status"`
}

// BulkUpdateStatusByIDs sets Status on every sandbox whose id is in Column1,
// unconditionally (no current-status filter), bumping last_updated.
func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error {
	_, err := q.db.Exec(ctx, bulkUpdateStatusByIDs, arg.Column1, arg.Status)
	return err
}
|
||||
|
||||
const getSandbox = `-- name: GetSandbox :one
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1
`

// GetSandbox fetches a single sandbox row by primary key, with no team scoping
// (see GetSandboxByTeam for the team-scoped variant).
func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) {
	row := q.db.QueryRow(ctx, getSandbox, id)
	var i Sandbox
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.HostID,
		&i.Template,
		&i.Status,
		&i.Vcpus,
		&i.MemoryMb,
		&i.TimeoutSec,
		&i.DiskSizeMb,
		&i.GuestIp,
		&i.HostIp,
		&i.CreatedAt,
		&i.StartedAt,
		&i.LastActiveAt,
		&i.LastUpdated,
		&i.TemplateID,
		&i.TemplateTeamID,
	)
	return i, err
}

const getSandboxByTeam = `-- name: GetSandboxByTeam :one
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 AND team_id = $2
`

// GetSandboxByTeamParams holds the (id, team_id) lookup key for GetSandboxByTeam.
type GetSandboxByTeamParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
}

// GetSandboxByTeam fetches a sandbox by id, restricted to the given team —
// a mismatched team yields no row (pgx.ErrNoRows from Scan).
func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) {
	row := q.db.QueryRow(ctx, getSandboxByTeam, arg.ID, arg.TeamID)
	var i Sandbox
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.HostID,
		&i.Template,
		&i.Status,
		&i.Vcpus,
		&i.MemoryMb,
		&i.TimeoutSec,
		&i.DiskSizeMb,
		&i.GuestIp,
		&i.HostIp,
		&i.CreatedAt,
		&i.StartedAt,
		&i.LastActiveAt,
		&i.LastUpdated,
		&i.TemplateID,
		&i.TemplateTeamID,
	)
	return i, err
}
|
||||
|
||||
const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one
SELECT s.status, h.address AS host_address
FROM sandboxes s
JOIN hosts h ON h.id = s.host_id
WHERE s.id = $1 AND s.team_id = $2
`

// GetSandboxProxyTargetParams holds the (sandbox id, team_id) key for
// GetSandboxProxyTarget.
type GetSandboxProxyTargetParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
}

// GetSandboxProxyTargetRow is the two-column projection returned by
// GetSandboxProxyTarget.
type GetSandboxProxyTargetRow struct {
	Status      string `json:"status"`
	HostAddress string `json:"host_address"`
}

// Returns the sandbox status and its host's address in one query.
// Used by SandboxProxyWrapper to avoid two round-trips.
func (q *Queries) GetSandboxProxyTarget(ctx context.Context, arg GetSandboxProxyTargetParams) (GetSandboxProxyTargetRow, error) {
	row := q.db.QueryRow(ctx, getSandboxProxyTarget, arg.ID, arg.TeamID)
	var i GetSandboxProxyTargetRow
	err := row.Scan(&i.Status, &i.HostAddress)
	return i, err
}
|
||||
|
||||
const insertSandbox = `-- name: InsertSandbox :one
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`

// InsertSandboxParams holds the explicit column values for InsertSandbox.
// Columns not listed here (guest_ip, host_ip, timestamps) take DB defaults
// and are filled in later (see UpdateSandboxRunning).
type InsertSandboxParams struct {
	ID             pgtype.UUID `json:"id"`
	TeamID         pgtype.UUID `json:"team_id"`
	HostID         pgtype.UUID `json:"host_id"`
	Template       string      `json:"template"`
	Status         string      `json:"status"`
	Vcpus          int32       `json:"vcpus"`
	MemoryMb       int32       `json:"memory_mb"`
	TimeoutSec     int32       `json:"timeout_sec"`
	DiskSizeMb     int32       `json:"disk_size_mb"`
	TemplateID     pgtype.UUID `json:"template_id"`
	TemplateTeamID pgtype.UUID `json:"template_team_id"`
}

// InsertSandbox creates a sandbox row and returns the full inserted row
// (including DB-defaulted columns) via RETURNING.
func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) {
	row := q.db.QueryRow(ctx, insertSandbox,
		arg.ID,
		arg.TeamID,
		arg.HostID,
		arg.Template,
		arg.Status,
		arg.Vcpus,
		arg.MemoryMb,
		arg.TimeoutSec,
		arg.DiskSizeMb,
		arg.TemplateID,
		arg.TemplateTeamID,
	)
	var i Sandbox
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.HostID,
		&i.Template,
		&i.Status,
		&i.Vcpus,
		&i.MemoryMb,
		&i.TimeoutSec,
		&i.DiskSizeMb,
		&i.GuestIp,
		&i.HostIp,
		&i.CreatedAt,
		&i.StartedAt,
		&i.LastActiveAt,
		&i.LastUpdated,
		&i.TemplateID,
		&i.TemplateTeamID,
	)
	return i, err
}
|
||||
|
||||
const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
ORDER BY created_at DESC
`

// ListActiveSandboxesByTeam returns a team's sandboxes in 'running', 'paused',
// or 'starting' state, newest first. Returns a nil slice when none match.
func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
	rows, err := q.db.Query(ctx, listActiveSandboxesByTeam, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Sandbox
	for rows.Next() {
		var i Sandbox
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.HostID,
			&i.Template,
			&i.Status,
			&i.Vcpus,
			&i.MemoryMb,
			&i.TimeoutSec,
			&i.DiskSizeMb,
			&i.GuestIp,
			&i.HostIp,
			&i.CreatedAt,
			&i.StartedAt,
			&i.LastActiveAt,
			&i.LastUpdated,
			&i.TemplateID,
			&i.TemplateTeamID,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listSandboxes = `-- name: ListSandboxes :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes ORDER BY created_at DESC
`

// ListSandboxes returns every sandbox row, all teams and all statuses,
// newest first. Unscoped — callers are responsible for access control.
func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
	rows, err := q.db.Query(ctx, listSandboxes)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Sandbox
	for rows.Next() {
		var i Sandbox
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.HostID,
			&i.Template,
			&i.Status,
			&i.Vcpus,
			&i.MemoryMb,
			&i.TimeoutSec,
			&i.DiskSizeMb,
			&i.GuestIp,
			&i.HostIp,
			&i.CreatedAt,
			&i.StartedAt,
			&i.LastActiveAt,
			&i.LastUpdated,
			&i.TemplateID,
			&i.TemplateTeamID,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE host_id = $1 AND status = ANY($2::text[])
ORDER BY created_at DESC
`

// ListSandboxesByHostAndStatusParams holds the host id and the status set for
// ListSandboxesByHostAndStatus. Column2 is sqlc's generated name for the
// unnamed $2::text[] parameter.
type ListSandboxesByHostAndStatusParams struct {
	HostID  pgtype.UUID `json:"host_id"`
	Column2 []string    `json:"column_2"`
}

// ListSandboxesByHostAndStatus returns all sandboxes on a host whose status is
// in the caller-supplied set, newest first.
func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) {
	rows, err := q.db.Query(ctx, listSandboxesByHostAndStatus, arg.HostID, arg.Column2)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Sandbox
	for rows.Next() {
		var i Sandbox
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.HostID,
			&i.Template,
			&i.Status,
			&i.Vcpus,
			&i.MemoryMb,
			&i.TimeoutSec,
			&i.DiskSizeMb,
			&i.GuestIp,
			&i.HostIp,
			&i.CreatedAt,
			&i.StartedAt,
			&i.LastActiveAt,
			&i.LastUpdated,
			&i.TemplateID,
			&i.TemplateTeamID,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status NOT IN ('stopped', 'error')
ORDER BY created_at DESC
`

// ListSandboxesByTeam returns a team's sandboxes in any state except
// 'stopped' or 'error' (so this includes 'missing' and 'pending', unlike
// ListActiveSandboxesByTeam), newest first.
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
	rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Sandbox
	for rows.Next() {
		var i Sandbox
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.HostID,
			&i.Template,
			&i.Status,
			&i.Vcpus,
			&i.MemoryMb,
			&i.TimeoutSec,
			&i.DiskSizeMb,
			&i.GuestIp,
			&i.HostIp,
			&i.CreatedAt,
			&i.StartedAt,
			&i.LastActiveAt,
			&i.LastUpdated,
			&i.TemplateID,
			&i.TemplateTeamID,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec
UPDATE sandboxes
SET status = 'missing',
    last_updated = NOW()
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
`

// Called when the host monitor marks a host unreachable.
// Marks running/starting/pending sandboxes on that host as 'missing' so users see
// the sandbox is not currently reachable, without permanently losing the record.
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID pgtype.UUID) error {
	_, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID)
	return err
}

const updateLastActive = `-- name: UpdateLastActive :exec
UPDATE sandboxes
SET last_active_at = $2,
    last_updated = NOW()
WHERE id = $1
`

// UpdateLastActiveParams holds the sandbox id and the activity timestamp for
// UpdateLastActive.
type UpdateLastActiveParams struct {
	ID           pgtype.UUID        `json:"id"`
	LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
}

// UpdateLastActive records activity on a sandbox by setting last_active_at to
// the caller-supplied timestamp (not NOW()) and bumping last_updated.
func (q *Queries) UpdateLastActive(ctx context.Context, arg UpdateLastActiveParams) error {
	_, err := q.db.Exec(ctx, updateLastActive, arg.ID, arg.LastActiveAt)
	return err
}
|
||||
|
||||
const updateSandboxRunning = `-- name: UpdateSandboxRunning :one
UPDATE sandboxes
SET status = 'running',
    host_ip = $2,
    guest_ip = $3,
    started_at = $4,
    last_active_at = $4,
    last_updated = NOW()
WHERE id = $1
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`

// UpdateSandboxRunningParams holds the network addresses and start time
// recorded when a sandbox transitions to 'running'. StartedAt is also used
// as the initial last_active_at.
type UpdateSandboxRunningParams struct {
	ID        pgtype.UUID        `json:"id"`
	HostIp    string             `json:"host_ip"`
	GuestIp   string             `json:"guest_ip"`
	StartedAt pgtype.Timestamptz `json:"started_at"`
}

// UpdateSandboxRunning marks a sandbox 'running', stores its host/guest IPs
// and start time, and returns the full updated row.
func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRunningParams) (Sandbox, error) {
	row := q.db.QueryRow(ctx, updateSandboxRunning,
		arg.ID,
		arg.HostIp,
		arg.GuestIp,
		arg.StartedAt,
	)
	var i Sandbox
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.HostID,
		&i.Template,
		&i.Status,
		&i.Vcpus,
		&i.MemoryMb,
		&i.TimeoutSec,
		&i.DiskSizeMb,
		&i.GuestIp,
		&i.HostIp,
		&i.CreatedAt,
		&i.StartedAt,
		&i.LastActiveAt,
		&i.LastUpdated,
		&i.TemplateID,
		&i.TemplateTeamID,
	)
	return i, err
}

const updateSandboxStatus = `-- name: UpdateSandboxStatus :one
UPDATE sandboxes
SET status = $2,
    last_updated = NOW()
WHERE id = $1
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`

// UpdateSandboxStatusParams holds the sandbox id and new status for
// UpdateSandboxStatus.
type UpdateSandboxStatusParams struct {
	ID     pgtype.UUID `json:"id"`
	Status string      `json:"status"`
}

// UpdateSandboxStatus sets a sandbox's status unconditionally (no
// current-status guard) and returns the full updated row. Scan returns
// pgx.ErrNoRows if the id does not exist.
func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) {
	row := q.db.QueryRow(ctx, updateSandboxStatus, arg.ID, arg.Status)
	var i Sandbox
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.HostID,
		&i.Template,
		&i.Status,
		&i.Vcpus,
		&i.MemoryMb,
		&i.TimeoutSec,
		&i.DiskSizeMb,
		&i.GuestIp,
		&i.HostIp,
		&i.CreatedAt,
		&i.StartedAt,
		&i.LastActiveAt,
		&i.LastUpdated,
		&i.TemplateID,
		&i.TemplateTeamID,
	)
	return i, err
}
|
||||
@ -1,324 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: teams.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteTeamMember = `-- name: DeleteTeamMember :exec
DELETE FROM users_teams WHERE team_id = $1 AND user_id = $2
`

// DeleteTeamMemberParams holds the (team_id, user_id) key for DeleteTeamMember.
type DeleteTeamMemberParams struct {
	TeamID pgtype.UUID `json:"team_id"`
	UserID pgtype.UUID `json:"user_id"`
}

// DeleteTeamMember removes a user's membership row from users_teams.
// A no-op (nil error) if the membership does not exist.
func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberParams) error {
	_, err := q.db.Exec(ctx, deleteTeamMember, arg.TeamID, arg.UserID)
	return err
}
|
||||
|
||||
const getBYOCTeams = `-- name: GetBYOCTeams :many
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at
`

// GetBYOCTeams returns all non-deleted teams with the BYOC flag set,
// oldest first.
func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
	rows, err := q.db.Query(ctx, getBYOCTeams)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Team
	for rows.Next() {
		var i Team
		if err := rows.Scan(
			&i.ID,
			&i.Name,
			&i.Slug,
			&i.IsByoc,
			&i.CreatedAt,
			&i.DeletedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND ut.is_default = TRUE AND t.deleted_at IS NULL
LIMIT 1
`

// GetDefaultTeamForUser returns the user's is_default team, skipping
// soft-deleted teams. pgx.ErrNoRows from Scan if the user has no default team.
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID) (Team, error) {
	row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID)
	var i Team
	err := row.Scan(
		&i.ID,
		&i.Name,
		&i.Slug,
		&i.IsByoc,
		&i.CreatedAt,
		&i.DeletedAt,
	)
	return i, err
}

const getTeam = `-- name: GetTeam :one
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1
`

// GetTeam fetches a team by primary key. Note: unlike the slug lookup, this
// does NOT filter out soft-deleted teams — callers must check DeletedAt.
func (q *Queries) GetTeam(ctx context.Context, id pgtype.UUID) (Team, error) {
	row := q.db.QueryRow(ctx, getTeam, id)
	var i Team
	err := row.Scan(
		&i.ID,
		&i.Name,
		&i.Slug,
		&i.IsByoc,
		&i.CreatedAt,
		&i.DeletedAt,
	)
	return i, err
}
|
||||
|
||||
const getTeamBySlug = `-- name: GetTeamBySlug :one
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL
`

// GetTeamBySlug fetches a non-deleted team by its slug.
func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) {
	row := q.db.QueryRow(ctx, getTeamBySlug, slug)
	var i Team
	err := row.Scan(
		&i.ID,
		&i.Name,
		&i.Slug,
		&i.IsByoc,
		&i.CreatedAt,
		&i.DeletedAt,
	)
	return i, err
}
|
||||
|
||||
const getTeamMembers = `-- name: GetTeamMembers :many
SELECT u.id, u.name, u.email, ut.role, ut.created_at AS joined_at
FROM users_teams ut
JOIN users u ON u.id = ut.user_id
WHERE ut.team_id = $1
ORDER BY ut.created_at
`

// GetTeamMembersRow joins user identity with per-team role; JoinedAt is the
// membership row's created_at, not the user's.
type GetTeamMembersRow struct {
	ID       pgtype.UUID        `json:"id"`
	Name     string             `json:"name"`
	Email    string             `json:"email"`
	Role     string             `json:"role"`
	JoinedAt pgtype.Timestamptz `json:"joined_at"`
}

// GetTeamMembers lists a team's members with their roles, ordered by when
// they joined.
func (q *Queries) GetTeamMembers(ctx context.Context, teamID pgtype.UUID) ([]GetTeamMembersRow, error) {
	rows, err := q.db.Query(ctx, getTeamMembers, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []GetTeamMembersRow
	for rows.Next() {
		var i GetTeamMembersRow
		if err := rows.Scan(
			&i.ID,
			&i.Name,
			&i.Email,
			&i.Role,
			&i.JoinedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
const getTeamMembership = `-- name: GetTeamMembership :one
SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE user_id = $1 AND team_id = $2
`

// GetTeamMembershipParams holds the (user_id, team_id) key for GetTeamMembership.
type GetTeamMembershipParams struct {
	UserID pgtype.UUID `json:"user_id"`
	TeamID pgtype.UUID `json:"team_id"`
}

// GetTeamMembership fetches a single users_teams row; pgx.ErrNoRows from Scan
// means the user is not a member of the team.
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
	row := q.db.QueryRow(ctx, getTeamMembership, arg.UserID, arg.TeamID)
	var i UsersTeam
	err := row.Scan(
		&i.UserID,
		&i.TeamID,
		&i.IsDefault,
		&i.Role,
		&i.CreatedAt,
	)
	return i, err
}
|
||||
|
||||
const getTeamsForUser = `-- name: GetTeamsForUser :many
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at, ut.role
FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND t.deleted_at IS NULL
ORDER BY ut.created_at
`

// GetTeamsForUserRow is a Team plus the user's role within it.
type GetTeamsForUserRow struct {
	ID        pgtype.UUID        `json:"id"`
	Name      string             `json:"name"`
	Slug      string             `json:"slug"`
	IsByoc    bool               `json:"is_byoc"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	DeletedAt pgtype.Timestamptz `json:"deleted_at"`
	Role      string             `json:"role"`
}

// GetTeamsForUser lists every non-deleted team the user belongs to, together
// with the user's role, ordered by when the user joined each team.
func (q *Queries) GetTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]GetTeamsForUserRow, error) {
	rows, err := q.db.Query(ctx, getTeamsForUser, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []GetTeamsForUserRow
	for rows.Next() {
		var i GetTeamsForUserRow
		if err := rows.Scan(
			&i.ID,
			&i.Name,
			&i.Slug,
			&i.IsByoc,
			&i.CreatedAt,
			&i.DeletedAt,
			&i.Role,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
const insertTeam = `-- name: InsertTeam :one
INSERT INTO teams (id, name, slug)
VALUES ($1, $2, $3)
RETURNING id, name, slug, is_byoc, created_at, deleted_at
`

// InsertTeamParams holds the caller-chosen id, display name, and slug for
// InsertTeam.
type InsertTeamParams struct {
	ID   pgtype.UUID `json:"id"`
	Name string      `json:"name"`
	Slug string      `json:"slug"`
}

// InsertTeam creates a team and returns the full inserted row; is_byoc,
// created_at, and deleted_at take their DB defaults.
func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) {
	row := q.db.QueryRow(ctx, insertTeam, arg.ID, arg.Name, arg.Slug)
	var i Team
	err := row.Scan(
		&i.ID,
		&i.Name,
		&i.Slug,
		&i.IsByoc,
		&i.CreatedAt,
		&i.DeletedAt,
	)
	return i, err
}

const insertTeamMember = `-- name: InsertTeamMember :exec
INSERT INTO users_teams (user_id, team_id, is_default, role)
VALUES ($1, $2, $3, $4)
`

// InsertTeamMemberParams holds the membership attributes for InsertTeamMember.
type InsertTeamMemberParams struct {
	UserID    pgtype.UUID `json:"user_id"`
	TeamID    pgtype.UUID `json:"team_id"`
	IsDefault bool        `json:"is_default"`
	Role      string      `json:"role"`
}

// InsertTeamMember adds a user to a team with the given role and
// default-team flag.
func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error {
	_, err := q.db.Exec(ctx, insertTeamMember,
		arg.UserID,
		arg.TeamID,
		arg.IsDefault,
		arg.Role,
	)
	return err
}
|
||||
|
||||
const setTeamBYOC = `-- name: SetTeamBYOC :exec
UPDATE teams SET is_byoc = $2 WHERE id = $1
`

// SetTeamBYOCParams holds the team id and new BYOC flag for SetTeamBYOC.
type SetTeamBYOCParams struct {
	ID     pgtype.UUID `json:"id"`
	IsByoc bool        `json:"is_byoc"`
}

// SetTeamBYOC toggles a team's bring-your-own-cloud flag.
func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error {
	_, err := q.db.Exec(ctx, setTeamBYOC, arg.ID, arg.IsByoc)
	return err
}

const softDeleteTeam = `-- name: SoftDeleteTeam :exec
UPDATE teams SET deleted_at = NOW() WHERE id = $1
`

// SoftDeleteTeam marks a team deleted by stamping deleted_at; the row is kept
// and other queries filter on deleted_at IS NULL.
func (q *Queries) SoftDeleteTeam(ctx context.Context, id pgtype.UUID) error {
	_, err := q.db.Exec(ctx, softDeleteTeam, id)
	return err
}
|
||||
|
||||
const updateMemberRole = `-- name: UpdateMemberRole :exec
|
||||
UPDATE users_teams SET role = $3 WHERE team_id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
type UpdateMemberRoleParams struct {
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateMemberRole(ctx context.Context, arg UpdateMemberRoleParams) error {
|
||||
_, err := q.db.Exec(ctx, updateMemberRole, arg.TeamID, arg.UserID, arg.Role)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateTeamName = `-- name: UpdateTeamName :exec
|
||||
UPDATE teams SET name = $2 WHERE id = $1 AND deleted_at IS NULL
|
||||
`
|
||||
|
||||
type UpdateTeamNameParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateTeamName(ctx context.Context, arg UpdateTeamNameParams) error {
|
||||
_, err := q.db.Exec(ctx, updateTeamName, arg.ID, arg.Name)
|
||||
return err
|
||||
}
|
||||
@ -1,241 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: template_builds.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const getTemplateBuild = `-- name: GetTemplateBuild :one
|
||||
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) {
|
||||
row := q.db.QueryRow(ctx, getTemplateBuild, id)
|
||||
var i TemplateBuild
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.BaseTemplate,
|
||||
&i.Recipe,
|
||||
&i.Healthcheck,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.Status,
|
||||
&i.CurrentStep,
|
||||
&i.TotalSteps,
|
||||
&i.Logs,
|
||||
&i.Error,
|
||||
&i.SandboxID,
|
||||
&i.HostID,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.CompletedAt,
|
||||
&i.TemplateID,
|
||||
&i.TeamID,
|
||||
&i.SkipPrePost,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertTemplateBuild = `-- name: InsertTemplateBuild :one
|
||||
INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11)
|
||||
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
|
||||
`
|
||||
|
||||
type InsertTemplateBuildParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseTemplate string `json:"base_template"`
|
||||
Recipe []byte `json:"recipe"`
|
||||
Healthcheck string `json:"healthcheck"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
TotalSteps int32 `json:"total_steps"`
|
||||
TemplateID pgtype.UUID `json:"template_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
SkipPrePost bool `json:"skip_pre_post"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) {
|
||||
row := q.db.QueryRow(ctx, insertTemplateBuild,
|
||||
arg.ID,
|
||||
arg.Name,
|
||||
arg.BaseTemplate,
|
||||
arg.Recipe,
|
||||
arg.Healthcheck,
|
||||
arg.Vcpus,
|
||||
arg.MemoryMb,
|
||||
arg.TotalSteps,
|
||||
arg.TemplateID,
|
||||
arg.TeamID,
|
||||
arg.SkipPrePost,
|
||||
)
|
||||
var i TemplateBuild
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.BaseTemplate,
|
||||
&i.Recipe,
|
||||
&i.Healthcheck,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.Status,
|
||||
&i.CurrentStep,
|
||||
&i.TotalSteps,
|
||||
&i.Logs,
|
||||
&i.Error,
|
||||
&i.SandboxID,
|
||||
&i.HostID,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.CompletedAt,
|
||||
&i.TemplateID,
|
||||
&i.TeamID,
|
||||
&i.SkipPrePost,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listTemplateBuilds = `-- name: ListTemplateBuilds :many
|
||||
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplateBuilds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []TemplateBuild
|
||||
for rows.Next() {
|
||||
var i TemplateBuild
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.BaseTemplate,
|
||||
&i.Recipe,
|
||||
&i.Healthcheck,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.Status,
|
||||
&i.CurrentStep,
|
||||
&i.TotalSteps,
|
||||
&i.Logs,
|
||||
&i.Error,
|
||||
&i.SandboxID,
|
||||
&i.HostID,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.CompletedAt,
|
||||
&i.TemplateID,
|
||||
&i.TeamID,
|
||||
&i.SkipPrePost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateBuildError = `-- name: UpdateBuildError :exec
|
||||
UPDATE template_builds
|
||||
SET error = $2, status = 'failed', completed_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateBuildErrorParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error {
|
||||
_, err := q.db.Exec(ctx, updateBuildError, arg.ID, arg.Error)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateBuildProgress = `-- name: UpdateBuildProgress :exec
|
||||
UPDATE template_builds
|
||||
SET current_step = $2, logs = $3
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateBuildProgressParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
CurrentStep int32 `json:"current_step"`
|
||||
Logs []byte `json:"logs"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error {
|
||||
_, err := q.db.Exec(ctx, updateBuildProgress, arg.ID, arg.CurrentStep, arg.Logs)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateBuildSandbox = `-- name: UpdateBuildSandbox :exec
|
||||
UPDATE template_builds
|
||||
SET sandbox_id = $2, host_id = $3
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateBuildSandboxParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
SandboxID pgtype.UUID `json:"sandbox_id"`
|
||||
HostID pgtype.UUID `json:"host_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error {
|
||||
_, err := q.db.Exec(ctx, updateBuildSandbox, arg.ID, arg.SandboxID, arg.HostID)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateBuildStatus = `-- name: UpdateBuildStatus :one
|
||||
UPDATE template_builds
|
||||
SET status = $2,
|
||||
started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END,
|
||||
completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END
|
||||
WHERE id = $1
|
||||
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
|
||||
`
|
||||
|
||||
type UpdateBuildStatusParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) {
|
||||
row := q.db.QueryRow(ctx, updateBuildStatus, arg.ID, arg.Status)
|
||||
var i TemplateBuild
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.BaseTemplate,
|
||||
&i.Recipe,
|
||||
&i.Healthcheck,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.Status,
|
||||
&i.CurrentStep,
|
||||
&i.TotalSteps,
|
||||
&i.Logs,
|
||||
&i.Error,
|
||||
&i.SandboxID,
|
||||
&i.HostID,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.CompletedAt,
|
||||
&i.TemplateID,
|
||||
&i.TeamID,
|
||||
&i.SkipPrePost,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -1,351 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: templates.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteTemplate = `-- name: DeleteTemplate :exec
|
||||
DELETE FROM templates WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteTemplate(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteTemplate, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteTemplateByTeam = `-- name: DeleteTemplateByTeam :exec
|
||||
DELETE FROM templates WHERE name = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type DeleteTemplateByTeamParams struct {
|
||||
Name string `json:"name"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteTemplateByTeam, arg.Name, arg.TeamID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteTemplatesByTeam = `-- name: DeleteTemplatesByTeam :exec
|
||||
DELETE FROM templates WHERE team_id = $1
|
||||
`
|
||||
|
||||
// Bulk delete all templates owned by a team (for team soft-delete cleanup).
|
||||
func (q *Queries) DeleteTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteTemplatesByTeam, teamID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getPlatformTemplateByName = `-- name: GetPlatformTemplateByName :one
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1
|
||||
`
|
||||
|
||||
// Check if a global (platform) template exists with the given name.
|
||||
func (q *Queries) GetPlatformTemplateByName(ctx context.Context, name string) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, getPlatformTemplateByName, name)
|
||||
var i Template
|
||||
err := row.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTemplate = `-- name: GetTemplate :one
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetTemplate(ctx context.Context, id pgtype.UUID) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, getTemplate, id)
|
||||
var i Template
|
||||
err := row.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTemplateByName = `-- name: GetTemplateByName :one
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 AND name = $2
|
||||
`
|
||||
|
||||
type GetTemplateByNameParams struct {
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Look up a template by team_id and name (exact team match, no global fallback).
|
||||
func (q *Queries) GetTemplateByName(ctx context.Context, arg GetTemplateByNameParams) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, getTemplateByName, arg.TeamID, arg.Name)
|
||||
var i Template
|
||||
err := row.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTemplateByTeam = `-- name: GetTemplateByTeam :one
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000')
|
||||
`
|
||||
|
||||
type GetTemplateByTeamParams struct {
|
||||
Name string `json:"name"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
// Platform templates (team_id = 00000000-...) are visible to all teams.
|
||||
func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID)
|
||||
var i Template
|
||||
err := row.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertTemplate = `-- name: InsertTemplate :one
|
||||
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id
|
||||
`
|
||||
|
||||
type InsertTemplateParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, insertTemplate,
|
||||
arg.ID,
|
||||
arg.Name,
|
||||
arg.Type,
|
||||
arg.Vcpus,
|
||||
arg.MemoryMb,
|
||||
arg.SizeBytes,
|
||||
arg.TeamID,
|
||||
)
|
||||
var i Template
|
||||
err := row.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listTemplates = `-- name: ListTemplates :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
// Platform templates are visible to all teams.
|
||||
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
type ListTemplatesByTeamAndTypeParams struct {
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// Platform templates are visible to all teams.
|
||||
func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listTemplatesByTeamOnly = `-- name: ListTemplatesByTeamOnly :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
// List templates owned by a specific team (NOT including platform templates).
|
||||
func (q *Queries) ListTemplatesByTeamOnly(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByTeamOnly, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listTemplatesByType = `-- name: ListTemplatesByType :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE type = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByType, type_)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
&i.ID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
@ -1,276 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: users.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteAdminPermission = `-- name: DeleteAdminPermission :exec
|
||||
DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2
|
||||
`
|
||||
|
||||
type DeleteAdminPermissionParams struct {
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteAdminPermission, arg.UserID, arg.Permission)
|
||||
return err
|
||||
}
|
||||
|
||||
const getAdminPermissions = `-- name: GetAdminPermissions :many
|
||||
SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission
|
||||
`
|
||||
|
||||
func (q *Queries) GetAdminPermissions(ctx context.Context, userID pgtype.UUID) ([]AdminPermission, error) {
|
||||
rows, err := q.db.Query(ctx, getAdminPermissions, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []AdminPermission
|
||||
for rows.Next() {
|
||||
var i AdminPermission
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Permission,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getAdminUsers = `-- name: GetAdminUsers :many
|
||||
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE is_admin = TRUE ORDER BY created_at
|
||||
`
|
||||
|
||||
func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
|
||||
rows, err := q.db.Query(ctx, getAdminUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []User
|
||||
for rows.Next() {
|
||||
var i User
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getUserByEmail = `-- name: GetUserByEmail :one
|
||||
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE email = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByEmail, email)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByID = `-- name: GetUserByID :one
|
||||
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByID, id)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const hasAdminPermission = `-- name: HasAdminPermission :one
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM admin_permissions WHERE user_id = $1 AND permission = $2
|
||||
) AS has_permission
|
||||
`
|
||||
|
||||
type HasAdminPermissionParams struct {
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) {
|
||||
row := q.db.QueryRow(ctx, hasAdminPermission, arg.UserID, arg.Permission)
|
||||
var has_permission bool
|
||||
err := row.Scan(&has_permission)
|
||||
return has_permission, err
|
||||
}
|
||||
|
||||
const insertAdminPermission = `-- name: InsertAdminPermission :exec
|
||||
INSERT INTO admin_permissions (id, user_id, permission)
|
||||
VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
type InsertAdminPermissionParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error {
|
||||
_, err := q.db.Exec(ctx, insertAdminPermission, arg.ID, arg.UserID, arg.Permission)
|
||||
return err
|
||||
}
|
||||
|
||||
const insertUser = `-- name: InsertUser :one
|
||||
INSERT INTO users (id, email, password_hash, name)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
|
||||
`
|
||||
|
||||
type InsertUserParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash pgtype.Text `json:"password_hash"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
|
||||
row := q.db.QueryRow(ctx, insertUser,
|
||||
arg.ID,
|
||||
arg.Email,
|
||||
arg.PasswordHash,
|
||||
arg.Name,
|
||||
)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertUserOAuth = `-- name: InsertUserOAuth :one
|
||||
INSERT INTO users (id, email, name)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
|
||||
`
|
||||
|
||||
type InsertUserOAuthParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) {
|
||||
row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email, arg.Name)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Name,
|
||||
&i.IsAdmin,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const searchUsersByEmailPrefix = `-- name: SearchUsersByEmailPrefix :many
|
||||
SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10
|
||||
`
|
||||
|
||||
type SearchUsersByEmailPrefixRow struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (q *Queries) SearchUsersByEmailPrefix(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUsersByEmailPrefixRow, error) {
|
||||
rows, err := q.db.Query(ctx, searchUsersByEmailPrefix, dollar_1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []SearchUsersByEmailPrefixRow
|
||||
for rows.Next() {
|
||||
var i SearchUsersByEmailPrefixRow
|
||||
if err := rows.Scan(&i.ID, &i.Email); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const setUserAdmin = `-- name: SetUserAdmin :exec
|
||||
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
type SetUserAdminParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
}
|
||||
|
||||
func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error {
|
||||
_, err := q.db.Exec(ctx, setUserAdmin, arg.ID, arg.IsAdmin)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateUserName = `-- name: UpdateUserName :exec
|
||||
UPDATE users SET name = $2, updated_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateUserNameParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUserName(ctx context.Context, arg UpdateUserNameParams) error {
|
||||
_, err := q.db.Exec(ctx, updateUserName, arg.ID, arg.Name)
|
||||
return err
|
||||
}
|
||||
233
internal/email/email.go
Normal file
233
internal/email/email.go
Normal file
@ -0,0 +1,233 @@
|
||||
// Package email provides transactional email sending via SMTP.
|
||||
//
|
||||
// Emails are rendered from embedded Go templates (html/template + text/template)
|
||||
// and sent as multipart/alternative MIME messages. When SMTP is not configured
|
||||
// (Host is empty), a no-op mailer is returned that logs and discards.
|
||||
package email
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"mime/quotedprintable"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Config holds SMTP connection credentials. All fields except Host are
|
||||
// optional — omitting Host disables email entirely (no-op mailer).
|
||||
type Config struct {
|
||||
Host string // SMTP server hostname
|
||||
Port int // SMTP server port (default 587)
|
||||
Username string // SMTP auth username
|
||||
Password string // SMTP auth password
|
||||
FromEmail string // envelope sender address
|
||||
}
|
||||
|
||||
// Mailer sends transactional emails.
|
||||
type Mailer interface {
|
||||
Send(ctx context.Context, to string, subject string, data EmailData) error
|
||||
}
|
||||
|
||||
// EmailData is the generic payload for all transactional emails.
|
||||
// Templates conditionally render each field based on presence.
|
||||
type EmailData struct {
|
||||
RecipientName string // optional — used after "Hello"
|
||||
Message string // main body (plain text; HTML template wraps it)
|
||||
Button *Button // optional CTA button
|
||||
Closing string // optional closing/footer message
|
||||
}
|
||||
|
||||
// Button represents a call-to-action link rendered as a button in HTML
|
||||
// and as a plain URL in the text variant.
|
||||
type Button struct {
|
||||
Text string // button label
|
||||
URL string // target URL
|
||||
}
|
||||
|
||||
// New constructs a Mailer. If cfg.Host is empty, returns a no-op mailer
|
||||
// that logs at debug level and discards. Panics if templates fail to parse
|
||||
// (indicates a build-time bug in embedded templates).
|
||||
func New(cfg Config) Mailer {
|
||||
if cfg.Host == "" {
|
||||
slog.Info("email: SMTP not configured, using no-op mailer")
|
||||
return &noopMailer{}
|
||||
}
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = 587
|
||||
}
|
||||
tmpl := mustLoadTemplates()
|
||||
slog.Info("email: SMTP configured", "host", cfg.Host, "port", cfg.Port, "from", cfg.FromEmail)
|
||||
return &mailer{cfg: cfg, tmpl: tmpl}
|
||||
}
|
||||
|
||||
// mailer is the live SMTP implementation.
|
||||
type mailer struct {
|
||||
cfg Config
|
||||
tmpl *templates
|
||||
}
|
||||
|
||||
func (m *mailer) Send(ctx context.Context, to string, subject string, data EmailData) error {
|
||||
if data.Button != nil {
|
||||
u, err := url.Parse(data.Button.URL)
|
||||
if err != nil || (u.Scheme != "https" && u.Scheme != "http") {
|
||||
return fmt.Errorf("invalid button URL scheme: %s", data.Button.URL)
|
||||
}
|
||||
}
|
||||
|
||||
htmlBody, err := m.tmpl.renderHTML(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("render html: %w", err)
|
||||
}
|
||||
textBody, err := m.tmpl.renderText(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("render text: %w", err)
|
||||
}
|
||||
|
||||
msg, err := buildMIME(m.cfg.FromEmail, to, subject, htmlBody, textBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build mime: %w", err)
|
||||
}
|
||||
|
||||
if err := m.send(to, msg); err != nil {
|
||||
return fmt.Errorf("send email to %s: %w", to, err)
|
||||
}
|
||||
|
||||
slog.Info("email: sent", "to", to, "subject", subject)
|
||||
return nil
|
||||
}
|
||||
|
||||
// send dials the SMTP server and delivers the message.
|
||||
// Port 465 uses implicit TLS; all other ports use STARTTLS.
|
||||
func (m *mailer) send(to string, msg []byte) error {
|
||||
addr := net.JoinHostPort(m.cfg.Host, strconv.Itoa(m.cfg.Port))
|
||||
auth := smtp.PlainAuth("", m.cfg.Username, m.cfg.Password, m.cfg.Host)
|
||||
|
||||
if m.cfg.Port == 465 {
|
||||
return m.sendImplicitTLS(addr, auth, to, msg)
|
||||
}
|
||||
// STARTTLS (port 587 or other).
|
||||
return smtp.SendMail(addr, auth, m.cfg.FromEmail, []string{to}, msg)
|
||||
}
|
||||
|
||||
// sendImplicitTLS handles port 465 (SMTPS) where the entire connection is TLS.
func (m *mailer) sendImplicitTLS(addr string, auth smtp.Auth, to string, msg []byte) error {
	// ServerName enables certificate verification against the configured host.
	conn, err := tls.Dial("tcp", addr, &tls.Config{ServerName: m.cfg.Host})
	if err != nil {
		return fmt.Errorf("tls dial: %w", err)
	}
	defer conn.Close()

	c, err := smtp.NewClient(conn, m.cfg.Host)
	if err != nil {
		return fmt.Errorf("smtp client: %w", err)
	}
	// c.Close also closes conn; after a successful Quit both deferred closes
	// are harmless no-ops (their errors are intentionally ignored).
	defer c.Close()

	if err := c.Auth(auth); err != nil {
		return fmt.Errorf("smtp auth: %w", err)
	}
	// Standard SMTP envelope: MAIL FROM, RCPT TO, then DATA.
	if err := c.Mail(m.cfg.FromEmail); err != nil {
		return fmt.Errorf("smtp mail: %w", err)
	}
	if err := c.Rcpt(to); err != nil {
		return fmt.Errorf("smtp rcpt: %w", err)
	}

	w, err := c.Data()
	if err != nil {
		return fmt.Errorf("smtp data: %w", err)
	}
	if _, err := w.Write(msg); err != nil {
		return fmt.Errorf("smtp write: %w", err)
	}
	// Closing the data writer finalizes the message body on the server side.
	if err := w.Close(); err != nil {
		return fmt.Errorf("smtp close data: %w", err)
	}
	return c.Quit()
}
|
||||
|
||||
// buildMIME assembles a multipart/alternative message with text and HTML parts.
|
||||
// Both parts are quoted-printable encoded per RFC 2045.
|
||||
func buildMIME(from, to, subject, htmlBody, textBody string) ([]byte, error) {
|
||||
var headerBuf bytes.Buffer
|
||||
var bodyBuf bytes.Buffer
|
||||
|
||||
// Sanitize header values to prevent header injection.
|
||||
from = sanitizeHeader(from)
|
||||
to = sanitizeHeader(to)
|
||||
|
||||
// Encode "From" with display name.
|
||||
encodedFrom := mime.QEncoding.Encode("utf-8", "Wrenn") + " <" + from + ">"
|
||||
|
||||
// Build multipart body first to get the boundary.
|
||||
mw := multipart.NewWriter(&bodyBuf)
|
||||
|
||||
// Text part (first = lowest preference per RFC 2046).
|
||||
textPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"text/plain; charset=utf-8"},
|
||||
"Content-Transfer-Encoding": {"quoted-printable"},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qpw := quotedprintable.NewWriter(textPart)
|
||||
if _, err := qpw.Write([]byte(textBody)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := qpw.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// HTML part (second = highest preference).
|
||||
htmlPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"text/html; charset=utf-8"},
|
||||
"Content-Transfer-Encoding": {"quoted-printable"},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qpw = quotedprintable.NewWriter(htmlPart)
|
||||
if _, err := qpw.Write([]byte(htmlBody)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := qpw.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mw.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write headers.
|
||||
fmt.Fprintf(&headerBuf, "From: %s\r\n", encodedFrom)
|
||||
fmt.Fprintf(&headerBuf, "To: %s\r\n", to)
|
||||
fmt.Fprintf(&headerBuf, "Subject: %s\r\n", mime.QEncoding.Encode("utf-8", subject))
|
||||
fmt.Fprintf(&headerBuf, "MIME-Version: 1.0\r\n")
|
||||
fmt.Fprintf(&headerBuf, "Content-Type: multipart/alternative; boundary=\"%s\"\r\n", mw.Boundary())
|
||||
fmt.Fprintf(&headerBuf, "\r\n")
|
||||
|
||||
headerBuf.Write(bodyBuf.Bytes())
|
||||
return headerBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
// sanitizeHeader strips CR and LF characters to prevent SMTP header injection.
|
||||
func sanitizeHeader(s string) string {
|
||||
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
|
||||
}
|
||||
|
||||
// noopMailer discards emails when SMTP is not configured.
type noopMailer struct{}

// Send logs the would-be delivery at debug level and reports success
// without sending anything.
func (n *noopMailer) Send(_ context.Context, to string, subject string, _ EmailData) error {
	slog.Debug("email: no-op send", "to", to, "subject", subject)
	return nil
}
|
||||
191
internal/email/email_test.go
Normal file
191
internal/email/email_test.go
Normal file
@ -0,0 +1,191 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNoopMailerDoesNotError verifies the no-op mailer silently accepts sends.
func TestNoopMailerDoesNotError(t *testing.T) {
	m := &noopMailer{}
	err := m.Send(context.Background(), "test@example.com", "Test Subject", EmailData{
		RecipientName: "Alice",
		Message:       "Hello world",
	})
	if err != nil {
		t.Fatalf("noopMailer.Send() returned error: %v", err)
	}
}
|
||||
|
||||
// TestNewReturnsNoopWhenHostEmpty verifies that an unset SMTP host selects
// the discard implementation.
func TestNewReturnsNoopWhenHostEmpty(t *testing.T) {
	m := New(Config{})
	if _, ok := m.(*noopMailer); !ok {
		t.Fatalf("expected noopMailer, got %T", m)
	}
}
|
||||
|
||||
// TestNewReturnsMailerWhenHostSet verifies that a configured host selects
// the live SMTP implementation.
func TestNewReturnsMailerWhenHostSet(t *testing.T) {
	m := New(Config{Host: "smtp.example.com"})
	if _, ok := m.(*mailer); !ok {
		t.Fatalf("expected *mailer, got %T", m)
	}
}
|
||||
|
||||
// TestTemplateRenderHTML renders the embedded HTML template with several
// field combinations and checks required substrings and basic structure.
func TestTemplateRenderHTML(t *testing.T) {
	tmpl := mustLoadTemplates()

	tests := []struct {
		name string
		data EmailData
		want []string // substrings that must appear in output
	}{
		{
			name: "with all fields",
			data: EmailData{
				RecipientName: "Alice",
				Message:       "Welcome to Wrenn!",
				Button:        &Button{Text: "Get Started", URL: "https://wrenn.dev"},
				Closing:       "See you soon.",
			},
			want: []string{"Alice", "Welcome to Wrenn!", "Get Started", "https://wrenn.dev", "See you soon."},
		},
		{
			name: "message only",
			data: EmailData{
				Message: "Your password has been changed.",
			},
			want: []string{"Your password has been changed."},
		},
		{
			name: "with button no closing",
			data: EmailData{
				RecipientName: "Bob",
				Message:       "Reset your password.",
				Button:        &Button{Text: "Reset Password", URL: "https://wrenn.dev/reset?token=abc"},
			},
			want: []string{"Bob", "Reset your password.", "Reset Password", "https://wrenn.dev/reset?token=abc"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			html, err := tmpl.renderHTML(tt.data)
			if err != nil {
				t.Fatalf("renderHTML() error: %v", err)
			}
			for _, s := range tt.want {
				if !strings.Contains(html, s) {
					t.Errorf("renderHTML() missing substring %q", s)
				}
			}
			// Verify basic HTML structure.
			if !strings.Contains(html, "<!DOCTYPE html>") {
				t.Error("renderHTML() missing DOCTYPE")
			}
			if !strings.Contains(html, "wrenn.dev") {
				t.Error("renderHTML() missing wrenn.dev reference")
			}
		})
	}
}
|
||||
|
||||
// TestTemplateRenderText renders the embedded plain-text template and checks
// the greeting, message, button line, and closing appear as expected.
func TestTemplateRenderText(t *testing.T) {
	tmpl := mustLoadTemplates()

	tests := []struct {
		name string
		data EmailData
		want []string
	}{
		{
			name: "with all fields",
			data: EmailData{
				RecipientName: "Alice",
				Message:       "Welcome to Wrenn!",
				Button:        &Button{Text: "Get Started", URL: "https://wrenn.dev"},
				Closing:       "See you soon.",
			},
			want: []string{"Hello Alice", "Welcome to Wrenn!", "Get Started: https://wrenn.dev", "See you soon."},
		},
		{
			name: "message only",
			data: EmailData{
				Message: "Done.",
			},
			want: []string{"Hello,", "Done."},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			text, err := tmpl.renderText(tt.data)
			if err != nil {
				t.Fatalf("renderText() error: %v", err)
			}
			for _, s := range tt.want {
				if !strings.Contains(text, s) {
					t.Errorf("renderText() missing substring %q\nGot:\n%s", s, text)
				}
			}
		})
	}
}
|
||||
|
||||
// TestBuildMIME checks that the assembled message carries the expected
// headers and both alternative body parts.
func TestBuildMIME(t *testing.T) {
	msg, err := buildMIME("noreply@wrenn.dev", "user@example.com", "Test Subject", "<h1>HTML</h1>", "Plain text")
	if err != nil {
		t.Fatalf("buildMIME() error: %v", err)
	}

	s := string(msg)
	if !strings.Contains(s, "From:") {
		t.Error("missing From header")
	}
	if !strings.Contains(s, "To: user@example.com") {
		t.Error("missing To header")
	}
	if !strings.Contains(s, "Wrenn") {
		t.Error("missing Wrenn sender name")
	}
	if !strings.Contains(s, "multipart/alternative") {
		t.Error("missing multipart/alternative content type")
	}
	if !strings.Contains(s, "text/plain") {
		t.Error("missing text/plain part")
	}
	if !strings.Contains(s, "text/html") {
		t.Error("missing text/html part")
	}
}
|
||||
|
||||
// TestBuildMIMENonASCII checks that non-ASCII body bytes are
// quoted-printable encoded rather than emitted raw.
func TestBuildMIMENonASCII(t *testing.T) {
	msg, err := buildMIME("noreply@wrenn.dev", "user@example.com", "Test", "<p>\u00c5ngstr\u00f6m</p>", "Hello \u00c5ngstr\u00f6m")
	if err != nil {
		t.Fatalf("buildMIME() error: %v", err)
	}

	s := string(msg)
	// Non-ASCII characters should be QP-encoded, not appear as raw bytes.
	// \u00c5 (U+00C5, 0xC3 0x85 in UTF-8) should be encoded as =C3=85.
	if !strings.Contains(s, "=C3=85") {
		t.Error("non-ASCII character not quoted-printable encoded")
	}
}
|
||||
|
||||
// TestSanitizeHeader checks CR/LF stripping for typical injection payloads.
func TestSanitizeHeader(t *testing.T) {
	tests := []struct {
		input string
		want  string
	}{
		{"normal@example.com", "normal@example.com"},
		{"injected\r\nBcc: evil@example.com", "injectedBcc: evil@example.com"},
		{"has\nnewline", "hasnewline"},
		{"has\rcarriage", "hascarriage"},
	}
	for _, tt := range tests {
		got := sanitizeHeader(tt.input)
		if got != tt.want {
			t.Errorf("sanitizeHeader(%q) = %q, want %q", tt.input, got, tt.want)
		}
	}
}
|
||||
52
internal/email/templates.go
Normal file
52
internal/email/templates.go
Normal file
@ -0,0 +1,52 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
text_template "text/template"
|
||||
)
|
||||
|
||||
//go:embed templates/*.html templates/*.txt
|
||||
var templateFS embed.FS
|
||||
|
||||
// templates holds the parsed HTML and plain-text template sets.
type templates struct {
	html *template.Template      // html/template set (auto-escaped)
	text *text_template.Template // text/template set (no escaping)
}
|
||||
|
||||
// mustLoadTemplates parses all embedded templates. Panics on error
|
||||
// because malformed templates are a build-time bug.
|
||||
func mustLoadTemplates() *templates {
|
||||
html, err := template.ParseFS(templateFS, "templates/*.html")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("email: failed to parse HTML templates: %v", err))
|
||||
}
|
||||
|
||||
text, err := text_template.ParseFS(templateFS, "templates/*.txt")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("email: failed to parse text templates: %v", err))
|
||||
}
|
||||
|
||||
return &templates{html: html, text: text}
|
||||
}
|
||||
|
||||
// renderHTML executes the HTML base template with the given data.
|
||||
func (t *templates) renderHTML(data EmailData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := t.html.ExecuteTemplate(&buf, "base.html", data); err != nil {
|
||||
return "", fmt.Errorf("execute html template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// renderText executes the plain-text base template with the given data.
|
||||
func (t *templates) renderText(data EmailData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := t.text.ExecuteTemplate(&buf, "base.txt", data); err != nil {
|
||||
return "", fmt.Errorf("execute text template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
119
internal/email/templates/base.html
Normal file
119
internal/email/templates/base.html
Normal file
@ -0,0 +1,119 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<title>Wrenn</title>
|
||||
<!--[if mso]>
|
||||
<noscript>
|
||||
<xml>
|
||||
<o:OfficeDocumentSettings>
|
||||
<o:PixelsPerInch>96</o:PixelsPerInch>
|
||||
</o:OfficeDocumentSettings>
|
||||
</xml>
|
||||
</noscript>
|
||||
<![endif]-->
|
||||
<style type="text/css">
|
||||
body, table, td, a { -webkit-text-size-adjust: 100%; -ms-text-size-adjust: 100%; }
|
||||
table, td { mso-table-lspace: 0pt; mso-table-rspace: 0pt; }
|
||||
img { -ms-interpolation-mode: bicubic; border: 0; height: auto; line-height: 100%; outline: none; text-decoration: none; }
|
||||
body { margin: 0; padding: 0; width: 100% !important; }
|
||||
</style>
|
||||
</head>
|
||||
<body style="margin: 0; padding: 0; background-color: #f4f3f1; font-family: 'Manrope', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; -webkit-font-smoothing: antialiased;">
|
||||
|
||||
<!-- Outer wrapper -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" width="100%" style="background-color: #f4f3f1;">
|
||||
<tr>
|
||||
<td align="center" style="padding: 40px 16px;">
|
||||
|
||||
<!-- Logo + Wordmark -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" width="560" style="max-width: 560px;">
|
||||
<tr>
|
||||
<td align="left" style="padding-bottom: 32px;">
|
||||
<a href="https://wrenn.dev" style="text-decoration: none;">
|
||||
<table role="presentation" cellpadding="0" cellspacing="0">
|
||||
<tr>
|
||||
<td style="vertical-align: middle;">
|
||||
<img src="https://wrenn.dev/logo.png" alt="Wrenn" width="36" height="36" style="display: block; border-radius: 6px;">
|
||||
</td>
|
||||
<td style="vertical-align: middle; padding-left: 10px;">
|
||||
<span style="font-family: 'Alice', Georgia, 'Times New Roman', serif; font-size: 22px; color: #1a1917; letter-spacing: 0.01em;">Wrenn</span>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Card -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" width="560" style="max-width: 560px; background-color: #ffffff; border: 1px solid #e5e4e0; border-radius: 8px;">
|
||||
<tr>
|
||||
<td style="padding: 44px 48px;">
|
||||
|
||||
<!-- Greeting -->
|
||||
<p style="margin: 0 0 8px 0; font-size: 15px; line-height: 1.6; color: #3a3835;">
|
||||
Hello{{if .RecipientName}} {{.RecipientName}}{{end}},
|
||||
</p>
|
||||
|
||||
<!-- Message -->
|
||||
<p style="margin: 0 0 36px 0; font-size: 15px; line-height: 1.7; color: #3a3835;">
|
||||
{{.Message}}
|
||||
</p>
|
||||
|
||||
<!-- Button -->
|
||||
{{if .Button}}
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" style="margin: 0 0 36px 0;">
|
||||
<tr>
|
||||
<td align="center" style="background-color: #5e8c58; border-radius: 5px;">
|
||||
<!--[if mso]>
|
||||
<v:roundrect xmlns:v="urn:schemas-microsoft-com:vml" xmlns:w="urn:schemas-microsoft-com:office:word" href="{{.Button.URL}}" style="height:44px;v-text-anchor:middle;width:200px;" arcsize="12%" strokecolor="#5e8c58" fillcolor="#5e8c58">
|
||||
<w:anchorlock/>
|
||||
<center style="color:#ffffff;font-family:'Manrope',-apple-system,sans-serif;font-size:14px;font-weight:600;">{{.Button.Text}}</center>
|
||||
</v:roundrect>
|
||||
<![endif]-->
|
||||
<!--[if !mso]><!-->
|
||||
<a href="{{.Button.URL}}" target="_blank" style="display: inline-block; padding: 12px 32px; font-size: 14px; font-weight: 600; color: #ffffff; text-decoration: none; border-radius: 5px; background-color: #5e8c58;">
|
||||
{{.Button.Text}}
|
||||
</a>
|
||||
<!--<![endif]-->
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<p style="margin: 0 0 12px 0; font-size: 12px; line-height: 1.5; color: #b5b0a8;">
|
||||
If the button doesn't work, copy and paste this URL into your browser:<br>
|
||||
<a href="{{.Button.URL}}" style="color: #5e8c58; word-break: break-all;">{{.Button.URL}}</a>
|
||||
</p>
|
||||
{{end}}
|
||||
|
||||
<!-- Closing -->
|
||||
{{if .Closing}}
|
||||
<p style="margin: {{if .Button}}20px{{else}}0{{end}} 0 0 0; font-size: 15px; line-height: 1.7; color: #3a3835;">
|
||||
{{.Closing}}
|
||||
</p>
|
||||
{{end}}
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Footer -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" width="560" style="max-width: 560px;">
|
||||
<tr>
|
||||
<td style="padding: 32px 0 16px 0; text-align: center;">
|
||||
<p style="margin: 0; font-size: 12px; line-height: 1.5; color: #9b9790;">
|
||||
This is a transactional email from <a href="https://wrenn.dev" style="color: #5e8c58; text-decoration: none;">Wrenn</a>.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
13
internal/email/templates/base.txt
Normal file
13
internal/email/templates/base.txt
Normal file
@ -0,0 +1,13 @@
|
||||
Hello{{if .RecipientName}} {{.RecipientName}}{{end}},
|
||||
|
||||
{{.Message}}
|
||||
{{if .Button}}
|
||||
|
||||
{{.Button.Text}}: {{.Button.URL}}
|
||||
{{end}}{{if .Closing}}
|
||||
|
||||
{{.Closing}}
|
||||
{{end}}
|
||||
|
||||
---
|
||||
This is a transactional email from Wrenn (https://wrenn.dev).
|
||||
@ -3,6 +3,7 @@ package envdclient
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@ -268,6 +269,82 @@ func (c *Client) ReadFile(ctx context.Context, path string) ([]byte, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// PrepareSnapshot calls envd's POST /snapshot/prepare endpoint, which quiesces
|
||||
// continuous goroutines (port scanner, forwarder) and forces a GC cycle before
|
||||
// Firecracker takes a VM snapshot. This ensures the Go runtime's page allocator
|
||||
// is in a consistent state when vCPUs are frozen.
|
||||
//
|
||||
// Best-effort: the caller should log a warning on error but not abort the pause.
|
||||
func (c *Client) PrepareSnapshot(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/snapshot/prepare", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("prepare snapshot: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("prepare snapshot: status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostInit calls envd's POST /init endpoint, which triggers a re-read of
// Firecracker MMDS metadata. This updates WRENN_SANDBOX_ID, WRENN_TEMPLATE_ID
// env vars and the corresponding files under /run/wrenn/ inside the guest.
// Must be called after snapshot restore so envd picks up the new sandbox's metadata.
//
// It is shorthand for PostInitWithDefaults with no default user or env vars.
func (c *Client) PostInit(ctx context.Context) error {
	return c.PostInitWithDefaults(ctx, "", nil)
}
|
||||
|
||||
// PostInitWithDefaults calls envd's POST /init endpoint with optional default
|
||||
// user and environment variables. These are applied to envd's defaults so all
|
||||
// subsequent process executions use them.
|
||||
func (c *Client) PostInitWithDefaults(ctx context.Context, defaultUser string, envVars map[string]string) error {
|
||||
var body io.Reader
|
||||
if defaultUser != "" || len(envVars) > 0 {
|
||||
payload := make(map[string]any)
|
||||
if defaultUser != "" {
|
||||
payload["defaultUser"] = defaultUser
|
||||
}
|
||||
if len(envVars) > 0 {
|
||||
payload["envVars"] = envVars
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal init body: %w", err)
|
||||
}
|
||||
body = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("post init: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("post init: status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDir lists directory contents inside the sandbox.
|
||||
func (c *Client) ListDir(ctx context.Context, path string, depth uint32) (*envdpb.ListDirResponse, error) {
|
||||
req := connect.NewRequest(&envdpb.ListDirRequest{
|
||||
@ -282,3 +359,30 @@ func (c *Client) ListDir(ctx context.Context, path string, depth uint32) (*envdp
|
||||
|
||||
return resp.Msg, nil
|
||||
}
|
||||
|
||||
// MakeDir creates a directory inside the sandbox.
|
||||
func (c *Client) MakeDir(ctx context.Context, path string) (*envdpb.MakeDirResponse, error) {
|
||||
req := connect.NewRequest(&envdpb.MakeDirRequest{
|
||||
Path: path,
|
||||
})
|
||||
|
||||
resp, err := c.filesystem.MakeDir(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("make dir: %w", err)
|
||||
}
|
||||
|
||||
return resp.Msg, nil
|
||||
}
|
||||
|
||||
// Remove removes a file or directory inside the sandbox.
|
||||
func (c *Client) Remove(ctx context.Context, path string) error {
|
||||
req := connect.NewRequest(&envdpb.RemoveRequest{
|
||||
Path: path,
|
||||
})
|
||||
|
||||
if _, err := c.filesystem.Remove(ctx, req); err != nil {
|
||||
return fmt.Errorf("remove: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -2,7 +2,9 @@ package envdclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
@ -31,6 +33,38 @@ func (c *Client) WaitUntilReady(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// FetchVersion queries envd's health endpoint and returns the reported version.
|
||||
func (c *Client) FetchVersion(ctx context.Context) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.healthURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("build health request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch envd version: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
return "", fmt.Errorf("health check returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil || len(body) == 0 {
|
||||
return "", nil // envd may not support version reporting yet
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return "", nil // non-JSON response, old envd
|
||||
}
|
||||
|
||||
return data.Version, nil
|
||||
}
|
||||
|
||||
// healthCheck sends a single GET /health request to envd.
|
||||
func (c *Client) healthCheck(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.healthURL, nil)
|
||||
|
||||
187
internal/envdclient/process.go
Normal file
187
internal/envdclient/process.go
Normal file
@ -0,0 +1,187 @@
|
||||
package envdclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
||||
)
|
||||
|
||||
// ProcessInfo holds metadata about a running process inside the sandbox.
type ProcessInfo struct {
	PID  uint32   // guest-side process ID
	Tag  string   // optional stable identifier; empty when the process has none
	Cmd  string   // command as reported by the process config (empty if config absent)
	Args []string // arguments as reported by the process config
}
|
||||
|
||||
// StartBackground starts a process that runs independently of the RPC stream.
// It opens a Start stream, reads the first StartEvent to obtain the PID,
// then closes the stream. The process continues running inside the VM because
// envd binds it to context.Background().
func (c *Client) StartBackground(ctx context.Context, tag, cmd string, args []string, envs map[string]string, cwd string) (uint32, error) {
	// Stdin disabled: nothing will ever write to a background process.
	stdin := false
	cfg := &envdpb.ProcessConfig{
		Cmd:  cmd,
		Args: args,
		Envs: envs,
	}
	if cwd != "" {
		cfg.Cwd = &cwd
	}

	req := connect.NewRequest(&envdpb.StartRequest{
		Process: cfg,
		Tag:     &tag,
		Stdin:   &stdin,
	})

	stream, err := c.process.Start(ctx, req)
	if err != nil {
		return 0, fmt.Errorf("start background process: %w", err)
	}
	// Closing the stream only detaches; the process keeps running in the VM.
	defer stream.Close()

	// Read events until we get the StartEvent with the PID.
	for stream.Receive() {
		msg := stream.Msg()
		if msg.Event == nil {
			continue
		}
		if start, ok := msg.Event.GetEvent().(*envdpb.ProcessEvent_Start); ok {
			return start.Start.GetPid(), nil
		}
	}

	// Receive returned false: surface a stream error, or fall through to the
	// "no start event" error when the stream ended cleanly.
	if err := stream.Err(); err != nil && err != io.EOF {
		return 0, fmt.Errorf("start background process stream: %w", err)
	}

	return 0, fmt.Errorf("start background process: no start event received")
}
|
||||
|
||||
// ConnectProcess re-attaches to a running process by PID or tag and returns
|
||||
// a channel of streaming events. The channel is closed when the process ends
|
||||
// or the context is cancelled.
|
||||
func (c *Client) ConnectProcess(ctx context.Context, pid uint32, tag string) (<-chan ExecStreamEvent, error) {
|
||||
var selector *envdpb.ProcessSelector
|
||||
if tag != "" {
|
||||
selector = &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
}
|
||||
} else {
|
||||
selector = &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Pid{Pid: pid},
|
||||
}
|
||||
}
|
||||
|
||||
stream, err := c.process.Connect(ctx, connect.NewRequest(&envdpb.ConnectRequest{
|
||||
Process: selector,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect process: %w", err)
|
||||
}
|
||||
|
||||
ch := make(chan ExecStreamEvent, 16)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer stream.Close()
|
||||
|
||||
for stream.Receive() {
|
||||
msg := stream.Msg()
|
||||
if msg.Event == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var ev ExecStreamEvent
|
||||
switch e := msg.Event.GetEvent().(type) {
|
||||
case *envdpb.ProcessEvent_Start:
|
||||
ev = ExecStreamEvent{Type: "start", PID: e.Start.GetPid()}
|
||||
|
||||
case *envdpb.ProcessEvent_Data:
|
||||
switch o := e.Data.GetOutput().(type) {
|
||||
case *envdpb.ProcessEvent_DataEvent_Stdout:
|
||||
ev = ExecStreamEvent{Type: "stdout", Data: o.Stdout}
|
||||
case *envdpb.ProcessEvent_DataEvent_Stderr:
|
||||
ev = ExecStreamEvent{Type: "stderr", Data: o.Stderr}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_End:
|
||||
ev = ExecStreamEvent{Type: "end", ExitCode: e.End.GetExitCode()}
|
||||
if e.End.Error != nil {
|
||||
ev.Error = e.End.GetError()
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_Keepalive:
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case ch <- ev:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil && err != io.EOF {
|
||||
slog.Debug("connect process stream error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// ListProcesses returns all running processes inside the sandbox.
|
||||
func (c *Client) ListProcesses(ctx context.Context) ([]ProcessInfo, error) {
|
||||
resp, err := c.process.List(ctx, connect.NewRequest(&envdpb.ListRequest{}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list processes: %w", err)
|
||||
}
|
||||
|
||||
procs := make([]ProcessInfo, 0, len(resp.Msg.Processes))
|
||||
for _, p := range resp.Msg.Processes {
|
||||
info := ProcessInfo{
|
||||
PID: p.Pid,
|
||||
}
|
||||
if p.Tag != nil {
|
||||
info.Tag = *p.Tag
|
||||
}
|
||||
if p.Config != nil {
|
||||
info.Cmd = p.Config.Cmd
|
||||
info.Args = p.Config.Args
|
||||
}
|
||||
procs = append(procs, info)
|
||||
}
|
||||
|
||||
return procs, nil
|
||||
}
|
||||
|
||||
// KillProcess sends a signal to a process identified by PID or tag.
|
||||
func (c *Client) KillProcess(ctx context.Context, pid uint32, tag string, signal envdpb.Signal) error {
|
||||
var selector *envdpb.ProcessSelector
|
||||
if tag != "" {
|
||||
selector = &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
}
|
||||
} else {
|
||||
selector = &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Pid{Pid: pid},
|
||||
}
|
||||
}
|
||||
|
||||
_, err := c.process.SendSignal(ctx, connect.NewRequest(&envdpb.SendSignalRequest{
|
||||
Process: selector,
|
||||
Signal: signal,
|
||||
}))
|
||||
if err != nil {
|
||||
return fmt.Errorf("kill process: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
220
internal/envdclient/pty.go
Normal file
220
internal/envdclient/pty.go
Normal file
@ -0,0 +1,220 @@
|
||||
package envdclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
||||
)
|
||||
|
||||
// PtyEvent represents a single event from a PTY output stream.
type PtyEvent struct {
	Type     string // "started", "output", "end"
	PID      uint32 // presumably set on "started" events; populated by drainPtyStream (not visible here) — confirm
	Data     []byte // raw terminal bytes; presumably carried by "output" events — confirm
	ExitCode int32  // presumably set on "end" events — confirm
	Error    string // optional error message accompanying an event
}
|
||||
|
||||
// PtyStart starts a new PTY process in the guest and returns a channel of events.
|
||||
// The tag is the stable identifier used to reconnect via PtyConnect.
|
||||
// The channel is closed when the process ends or ctx is cancelled.
|
||||
// NOTE: The user parameter from PtyAttachRequest is not yet supported by envd's
|
||||
// ProcessConfig proto. When envd adds user support, thread it through here.
|
||||
func (c *Client) PtyStart(ctx context.Context, tag, cmd string, args []string, cols, rows uint32, envs map[string]string, cwd string) (<-chan PtyEvent, error) {
|
||||
stdin := true
|
||||
cfg := &envdpb.ProcessConfig{
|
||||
Cmd: cmd,
|
||||
Args: args,
|
||||
Envs: envs,
|
||||
}
|
||||
if cwd != "" {
|
||||
cfg.Cwd = &cwd
|
||||
}
|
||||
|
||||
req := connect.NewRequest(&envdpb.StartRequest{
|
||||
Process: cfg,
|
||||
Pty: &envdpb.PTY{
|
||||
Size: &envdpb.PTY_Size{
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
},
|
||||
},
|
||||
Tag: &tag,
|
||||
Stdin: &stdin,
|
||||
})
|
||||
|
||||
stream, err := c.process.Start(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pty start: %w", err)
|
||||
}
|
||||
|
||||
return drainPtyStream(ctx, &startStream{s: stream}, true), nil
|
||||
}
|
||||
|
||||
// PtyConnect re-attaches to an existing PTY process by tag.
|
||||
// Returns a channel of output events starting from the current point.
|
||||
func (c *Client) PtyConnect(ctx context.Context, tag string) (<-chan PtyEvent, error) {
|
||||
req := connect.NewRequest(&envdpb.ConnectRequest{
|
||||
Process: &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
},
|
||||
})
|
||||
|
||||
stream, err := c.process.Connect(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pty connect: %w", err)
|
||||
}
|
||||
|
||||
return drainPtyStream(ctx, &connectStream{s: stream}, false), nil
|
||||
}
|
||||
|
||||
// PtySendInput sends raw bytes to the PTY process identified by tag.
|
||||
func (c *Client) PtySendInput(ctx context.Context, tag string, data []byte) error {
|
||||
req := connect.NewRequest(&envdpb.SendInputRequest{
|
||||
Process: &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
},
|
||||
Input: &envdpb.ProcessInput{
|
||||
Input: &envdpb.ProcessInput_Pty{Pty: data},
|
||||
},
|
||||
})
|
||||
|
||||
if _, err := c.process.SendInput(ctx, req); err != nil {
|
||||
return fmt.Errorf("pty send input: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PtyResize updates the terminal dimensions for the PTY process identified by tag.
|
||||
func (c *Client) PtyResize(ctx context.Context, tag string, cols, rows uint32) error {
|
||||
req := connect.NewRequest(&envdpb.UpdateRequest{
|
||||
Process: &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
},
|
||||
Pty: &envdpb.PTY{
|
||||
Size: &envdpb.PTY_Size{
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if _, err := c.process.Update(ctx, req); err != nil {
|
||||
return fmt.Errorf("pty resize: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PtyKill sends SIGKILL to the PTY process identified by tag.
|
||||
func (c *Client) PtyKill(ctx context.Context, tag string) error {
|
||||
req := connect.NewRequest(&envdpb.SendSignalRequest{
|
||||
Process: &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
},
|
||||
Signal: envdpb.Signal_SIGNAL_SIGKILL,
|
||||
})
|
||||
|
||||
if _, err := c.process.SendSignal(ctx, req); err != nil {
|
||||
return fmt.Errorf("pty kill: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// eventStream is an interface covering both StartResponse and ConnectResponse streams.
type eventStream interface {
	// Receive advances to the next message; false means the stream ended.
	Receive() bool
	// Err returns the stream's terminal error, if any.
	Err() error
	// Close releases the underlying stream.
	Close() error
}
|
||||
|
||||
// startStream adapts a Start RPC stream to the eventStream/eventProvider interfaces.
type startStream struct {
	s *connect.ServerStreamForClient[envdpb.StartResponse]
}

func (s *startStream) Receive() bool { return s.s.Receive() }
func (s *startStream) Err() error    { return s.s.Err() }
func (s *startStream) Close() error  { return s.s.Close() }

// Event returns the process event carried by the current message.
func (s *startStream) Event() *envdpb.ProcessEvent {
	return s.s.Msg().GetEvent()
}
|
||||
|
||||
// connectStream adapts a Connect RPC stream to the eventStream/eventProvider interfaces.
type connectStream struct {
	s *connect.ServerStreamForClient[envdpb.ConnectResponse]
}

func (s *connectStream) Receive() bool { return s.s.Receive() }
func (s *connectStream) Err() error    { return s.s.Err() }
func (s *connectStream) Close() error  { return s.s.Close() }

// Event returns the process event carried by the current message.
func (s *connectStream) Event() *envdpb.ProcessEvent {
	return s.s.Msg().GetEvent()
}
|
||||
|
||||
// eventProvider is an eventStream that can also surface the decoded process
// event of the current message; implemented by startStream and connectStream.
type eventProvider interface {
	eventStream
	Event() *envdpb.ProcessEvent
}
|
||||
|
||||
// drainPtyStream reads events from either a Start or Connect stream and maps
|
||||
// them into PtyEvent values on a channel.
|
||||
func drainPtyStream(ctx context.Context, stream eventProvider, expectStart bool) <-chan PtyEvent {
|
||||
ch := make(chan PtyEvent, 16)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer stream.Close()
|
||||
|
||||
for stream.Receive() {
|
||||
event := stream.Event()
|
||||
if event == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var ev PtyEvent
|
||||
switch e := event.GetEvent().(type) {
|
||||
case *envdpb.ProcessEvent_Start:
|
||||
if expectStart {
|
||||
ev = PtyEvent{Type: "started", PID: e.Start.GetPid()}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_Data:
|
||||
switch o := e.Data.GetOutput().(type) {
|
||||
case *envdpb.ProcessEvent_DataEvent_Pty:
|
||||
ev = PtyEvent{Type: "output", Data: o.Pty}
|
||||
case *envdpb.ProcessEvent_DataEvent_Stdout:
|
||||
ev = PtyEvent{Type: "output", Data: o.Stdout}
|
||||
case *envdpb.ProcessEvent_DataEvent_Stderr:
|
||||
ev = PtyEvent{Type: "output", Data: o.Stderr}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_End:
|
||||
ev = PtyEvent{Type: "end", ExitCode: e.End.GetExitCode()}
|
||||
if e.End.Error != nil {
|
||||
ev.Error = e.End.GetError()
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_Keepalive:
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case ch <- ev:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil && err != io.EOF {
|
||||
slog.Debug("pty stream error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
@ -1,73 +0,0 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EventPublisher pushes events onto the notification stream.
|
||||
// Satisfied by *channels.Publisher.
|
||||
type EventPublisher interface {
|
||||
Publish(ctx context.Context, e Event)
|
||||
}
|
||||
|
||||
// ActorKind identifies what initiated an event.
|
||||
type ActorKind string
|
||||
|
||||
const (
|
||||
ActorUser ActorKind = "user"
|
||||
ActorAPIKey ActorKind = "api_key"
|
||||
ActorSystem ActorKind = "system"
|
||||
)
|
||||
|
||||
// Actor describes who triggered an event.
|
||||
type Actor struct {
|
||||
Type ActorKind `json:"type"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// Resource identifies the object the event relates to.
|
||||
type Resource struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// Event is the canonical notification payload published to the Redis stream
|
||||
// and delivered to channel subscribers.
|
||||
type Event struct {
|
||||
Event string `json:"event"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
TeamID string `json:"team_id"`
|
||||
Actor Actor `json:"actor"`
|
||||
Resource Resource `json:"resource"`
|
||||
}
|
||||
|
||||
// Event type constants.
|
||||
const (
|
||||
CapsuleCreated = "capsule.created"
|
||||
CapsuleRunning = "capsule.running"
|
||||
CapsulePaused = "capsule.paused"
|
||||
CapsuleDestroyed = "capsule.destroyed"
|
||||
SnapshotCreated = "template.snapshot.created"
|
||||
SnapshotDeleted = "template.snapshot.deleted"
|
||||
HostUp = "host.up"
|
||||
HostDown = "host.down"
|
||||
)
|
||||
|
||||
// AllEventTypes is the complete set of valid event type strings.
|
||||
var AllEventTypes = []string{
|
||||
CapsuleCreated,
|
||||
CapsuleRunning,
|
||||
CapsulePaused,
|
||||
CapsuleDestroyed,
|
||||
SnapshotCreated,
|
||||
SnapshotDeleted,
|
||||
HostUp,
|
||||
HostDown,
|
||||
}
|
||||
|
||||
// Now returns the current time formatted for event timestamps.
|
||||
func Now() string {
|
||||
return time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
|
||||
@ -68,10 +69,18 @@ func (s *Server) CreateSandbox(
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
|
||||
}
|
||||
|
||||
// Apply template defaults (user, env vars) if provided.
|
||||
if msg.DefaultUser != "" || len(msg.DefaultEnv) > 0 {
|
||||
if err := s.mgr.SetDefaults(ctx, sb.ID, msg.DefaultUser, msg.DefaultEnv); err != nil {
|
||||
slog.Warn("failed to set sandbox defaults", "sandbox", sb.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.CreateSandboxResponse{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
HostIp: sb.HostIP.String(),
|
||||
Metadata: sb.Metadata,
|
||||
}), nil
|
||||
}
|
||||
|
||||
@ -99,14 +108,24 @@ func (s *Server) ResumeSandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ResumeSandboxRequest],
|
||||
) (*connect.Response[pb.ResumeSandboxResponse], error) {
|
||||
sb, err := s.mgr.Resume(ctx, req.Msg.SandboxId, int(req.Msg.TimeoutSec))
|
||||
msg := req.Msg
|
||||
sb, err := s.mgr.Resume(ctx, msg.SandboxId, int(msg.TimeoutSec), msg.KernelVersion)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
|
||||
// Apply template defaults (user, env vars) if provided.
|
||||
if msg.DefaultUser != "" || len(msg.DefaultEnv) > 0 {
|
||||
if err := s.mgr.SetDefaults(ctx, sb.ID, msg.DefaultUser, msg.DefaultEnv); err != nil {
|
||||
slog.Warn("failed to set sandbox defaults on resume", "sandbox", sb.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ResumeSandboxResponse{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
HostIp: sb.HostIP.String(),
|
||||
Metadata: sb.Metadata,
|
||||
}), nil
|
||||
}
|
||||
|
||||
@ -252,6 +271,69 @@ func (s *Server) ReadFile(
|
||||
return connect.NewResponse(&pb.ReadFileResponse{Content: content}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ListDir(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ListDirRequest],
|
||||
) (*connect.Response[pb.ListDirResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
resp, err := client.ListDir(ctx, msg.Path, msg.Depth)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list dir: %w", err))
|
||||
}
|
||||
|
||||
entries := make([]*pb.FileEntry, 0, len(resp.Entries))
|
||||
for _, e := range resp.Entries {
|
||||
entries = append(entries, entryInfoToPB(e))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ListDirResponse{Entries: entries}), nil
|
||||
}
|
||||
|
||||
func (s *Server) MakeDir(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.MakeDirRequest],
|
||||
) (*connect.Response[pb.MakeDirResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
resp, err := client.MakeDir(ctx, msg.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("make dir: %w", err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.MakeDirResponse{
|
||||
Entry: entryInfoToPB(resp.Entry),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) RemovePath(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.RemovePathRequest],
|
||||
) (*connect.Response[pb.RemovePathResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
if err := client.Remove(ctx, msg.Path); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("remove: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.RemovePathResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ExecStream(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ExecStreamRequest],
|
||||
@ -436,6 +518,16 @@ func (s *Server) ReadFileStream(
|
||||
// Stream file content in 64KB chunks.
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
// Bail out early if the client disconnected or the context was cancelled.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return connect.NewError(connect.CodeDeadlineExceeded, ctx.Err())
|
||||
}
|
||||
return connect.NewError(connect.CodeCanceled, ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
@ -474,6 +566,7 @@ func (s *Server) ListSandboxes(
|
||||
CreatedAtUnix: sb.CreatedAt.Unix(),
|
||||
LastActiveAtUnix: sb.LastActiveAt.Unix(),
|
||||
TimeoutSec: int32(sb.TimeoutSec),
|
||||
Metadata: sb.Metadata,
|
||||
}
|
||||
}
|
||||
|
||||
@ -545,3 +638,269 @@ func metricPointsToPB(pts []sandbox.MetricPoint) []*pb.MetricPoint {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Server) PtyAttach(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtyAttachRequest],
|
||||
stream *connect.ServerStream[pb.PtyAttachResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
events, err := s.mgr.PtyAttach(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Cols, msg.Rows, msg.Envs, msg.Cwd)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("pty attach: %w", err))
|
||||
}
|
||||
|
||||
for ev := range events {
|
||||
var resp pb.PtyAttachResponse
|
||||
switch ev.Type {
|
||||
case "started":
|
||||
resp.Event = &pb.PtyAttachResponse_Started{
|
||||
Started: &pb.PtyStarted{Pid: ev.PID, Tag: msg.Tag},
|
||||
}
|
||||
case "output":
|
||||
resp.Event = &pb.PtyAttachResponse_Output{
|
||||
Output: &pb.PtyOutput{Data: ev.Data},
|
||||
}
|
||||
case "end":
|
||||
resp.Event = &pb.PtyAttachResponse_Exited{
|
||||
Exited: &pb.PtyExited{ExitCode: ev.ExitCode, Error: ev.Error},
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) PtySendInput(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtySendInputRequest],
|
||||
) (*connect.Response[pb.PtySendInputResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
if err := s.mgr.PtySendInput(ctx, msg.SandboxId, msg.Tag, msg.Data); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty send input: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.PtySendInputResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) PtyResize(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtyResizeRequest],
|
||||
) (*connect.Response[pb.PtyResizeResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
if err := s.mgr.PtyResize(ctx, msg.SandboxId, msg.Tag, msg.Cols, msg.Rows); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty resize: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.PtyResizeResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) PtyKill(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtyKillRequest],
|
||||
) (*connect.Response[pb.PtyKillResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
if err := s.mgr.PtyKill(ctx, msg.SandboxId, msg.Tag); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty kill: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.PtyKillResponse{}), nil
|
||||
}
|
||||
|
||||
// entryInfoToPB maps an envd EntryInfo to a hostagent FileEntry.
|
||||
func entryInfoToPB(e *envdpb.EntryInfo) *pb.FileEntry {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var fileType string
|
||||
switch e.Type {
|
||||
case envdpb.FileType_FILE_TYPE_FILE:
|
||||
fileType = "file"
|
||||
case envdpb.FileType_FILE_TYPE_DIRECTORY:
|
||||
fileType = "directory"
|
||||
case envdpb.FileType_FILE_TYPE_SYMLINK:
|
||||
fileType = "symlink"
|
||||
default:
|
||||
fileType = "unknown"
|
||||
}
|
||||
|
||||
entry := &pb.FileEntry{
|
||||
Name: e.Name,
|
||||
Path: e.Path,
|
||||
Type: fileType,
|
||||
Size: e.Size,
|
||||
Mode: e.Mode,
|
||||
Permissions: e.Permissions,
|
||||
Owner: e.Owner,
|
||||
Group: e.Group,
|
||||
}
|
||||
|
||||
if e.ModifiedTime != nil {
|
||||
entry.ModifiedAt = e.ModifiedTime.GetSeconds()
|
||||
}
|
||||
|
||||
if e.SymlinkTarget != nil {
|
||||
entry.SymlinkTarget = e.SymlinkTarget
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
// ── Background Processes ────────────────────────────────────────────
|
||||
|
||||
func (s *Server) StartBackground(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.StartBackgroundRequest],
|
||||
) (*connect.Response[pb.StartBackgroundResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
pid, err := s.mgr.StartBackground(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Envs, msg.Cwd)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("start background: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.StartBackgroundResponse{
|
||||
Pid: pid,
|
||||
Tag: msg.Tag,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ListProcesses(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ListProcessesRequest],
|
||||
) (*connect.Response[pb.ListProcessesResponse], error) {
|
||||
procs, err := s.mgr.ListProcesses(ctx, req.Msg.SandboxId)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list processes: %w", err))
|
||||
}
|
||||
|
||||
entries := make([]*pb.ProcessEntry, 0, len(procs))
|
||||
for _, p := range procs {
|
||||
entries = append(entries, &pb.ProcessEntry{
|
||||
Pid: p.PID,
|
||||
Tag: p.Tag,
|
||||
Cmd: p.Cmd,
|
||||
Args: p.Args,
|
||||
})
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ListProcessesResponse{
|
||||
Processes: entries,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) KillProcess(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.KillProcessRequest],
|
||||
) (*connect.Response[pb.KillProcessResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
// Resolve PID/tag selector.
|
||||
var pid uint32
|
||||
var tag string
|
||||
switch sel := msg.Selector.(type) {
|
||||
case *pb.KillProcessRequest_Pid:
|
||||
pid = sel.Pid
|
||||
case *pb.KillProcessRequest_Tag:
|
||||
tag = sel.Tag
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("pid or tag is required"))
|
||||
}
|
||||
|
||||
// Map signal string to envd enum.
|
||||
var signal envdpb.Signal
|
||||
switch msg.Signal {
|
||||
case "", "SIGKILL":
|
||||
signal = envdpb.Signal_SIGNAL_SIGKILL
|
||||
case "SIGTERM":
|
||||
signal = envdpb.Signal_SIGNAL_SIGTERM
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("unsupported signal: %s (use SIGKILL or SIGTERM)", msg.Signal))
|
||||
}
|
||||
|
||||
if err := s.mgr.KillProcess(ctx, msg.SandboxId, pid, tag, signal); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("kill process: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.KillProcessResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ConnectProcess(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ConnectProcessRequest],
|
||||
stream *connect.ServerStream[pb.ConnectProcessResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
var pid uint32
|
||||
var tag string
|
||||
switch sel := msg.Selector.(type) {
|
||||
case *pb.ConnectProcessRequest_Pid:
|
||||
pid = sel.Pid
|
||||
case *pb.ConnectProcessRequest_Tag:
|
||||
tag = sel.Tag
|
||||
default:
|
||||
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("pid or tag is required"))
|
||||
}
|
||||
|
||||
events, err := s.mgr.ConnectProcess(ctx, msg.SandboxId, pid, tag)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("connect process: %w", err))
|
||||
}
|
||||
|
||||
for ev := range events {
|
||||
var resp pb.ConnectProcessResponse
|
||||
switch ev.Type {
|
||||
case "start":
|
||||
resp.Event = &pb.ConnectProcessResponse_Start{
|
||||
Start: &pb.ExecStreamStart{Pid: ev.PID},
|
||||
}
|
||||
case "stdout":
|
||||
resp.Event = &pb.ConnectProcessResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stdout{Stdout: ev.Data},
|
||||
},
|
||||
}
|
||||
case "stderr":
|
||||
resp.Event = &pb.ConnectProcessResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stderr{Stderr: ev.Data},
|
||||
},
|
||||
}
|
||||
case "end":
|
||||
resp.Event = &pb.ConnectProcessResponse_End{
|
||||
End: &pb.ExecStreamEnd{
|
||||
ExitCode: ev.ExitCode,
|
||||
Error: ev.Error,
|
||||
},
|
||||
}
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1,186 +0,0 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const (
|
||||
base36Alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
base36IDLen = 25 // ceil(128 * log2 / log36) = 25 chars for a full UUID
|
||||
)
|
||||
|
||||
var base36Base = big.NewInt(36)
|
||||
|
||||
// --- Generation ---
|
||||
|
||||
// newUUID returns a new random (v4) UUID wrapped in pgtype.UUID for direct DB use.
|
||||
func newUUID() pgtype.UUID {
|
||||
return pgtype.UUID{Bytes: uuid.New(), Valid: true}
|
||||
}
|
||||
|
||||
func NewSandboxID() pgtype.UUID { return newUUID() }
|
||||
func NewUserID() pgtype.UUID { return newUUID() }
|
||||
func NewTeamID() pgtype.UUID { return newUUID() }
|
||||
func NewAPIKeyID() pgtype.UUID { return newUUID() }
|
||||
func NewHostID() pgtype.UUID { return newUUID() }
|
||||
func NewHostTokenID() pgtype.UUID { return newUUID() }
|
||||
func NewRefreshTokenID() pgtype.UUID { return newUUID() }
|
||||
func NewAuditLogID() pgtype.UUID { return newUUID() }
|
||||
func NewBuildID() pgtype.UUID { return newUUID() }
|
||||
func NewAdminPermissionID() pgtype.UUID { return newUUID() }
|
||||
func NewChannelID() pgtype.UUID { return newUUID() }
|
||||
|
||||
func NewTemplateID() pgtype.UUID { return newUUID() }
|
||||
|
||||
// NewSnapshotName generates a snapshot name: "template-" + 8 hex chars.
|
||||
func NewSnapshotName() string {
|
||||
return "template-" + hex8()
|
||||
}
|
||||
|
||||
// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy".
|
||||
func NewTeamSlug() string {
|
||||
b := make([]byte, 6)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:])
|
||||
}
|
||||
|
||||
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
|
||||
func NewRegistrationToken() string {
|
||||
return hexToken(32)
|
||||
}
|
||||
|
||||
// NewRefreshToken generates a 64-char hex token (32 bytes of entropy).
|
||||
func NewRefreshToken() string {
|
||||
return hexToken(32)
|
||||
}
|
||||
|
||||
// --- Formatting (pgtype.UUID → prefixed string for API/RPC output) ---
|
||||
|
||||
const (
|
||||
PrefixSandbox = "cl-"
|
||||
PrefixUser = "usr-"
|
||||
PrefixTeam = "team-"
|
||||
PrefixAPIKey = "key-"
|
||||
PrefixHost = "host-"
|
||||
PrefixHostToken = "htok-"
|
||||
PrefixRefreshToken = "hrt-"
|
||||
PrefixAuditLog = "log-"
|
||||
PrefixBuild = "bld-"
|
||||
PrefixAdminPermission = "perm-"
|
||||
PrefixChannel = "ch-"
|
||||
)
|
||||
|
||||
// UUIDToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z).
|
||||
func UUIDToBase36(b [16]byte) string {
|
||||
n := new(big.Int).SetBytes(b[:])
|
||||
buf := make([]byte, base36IDLen)
|
||||
mod := new(big.Int)
|
||||
for i := base36IDLen - 1; i >= 0; i-- {
|
||||
n.DivMod(n, base36Base, mod)
|
||||
buf[i] = base36Alphabet[mod.Int64()]
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// base36ToUUID decodes a 25-char base36 string back to 16 UUID bytes.
|
||||
func base36ToUUID(s string) ([16]byte, error) {
|
||||
if len(s) != base36IDLen {
|
||||
return [16]byte{}, fmt.Errorf("expected %d-char base36 ID, got %d", base36IDLen, len(s))
|
||||
}
|
||||
n := new(big.Int)
|
||||
for _, c := range s {
|
||||
idx := strings.IndexRune(base36Alphabet, c)
|
||||
if idx < 0 {
|
||||
return [16]byte{}, fmt.Errorf("invalid base36 character: %c", c)
|
||||
}
|
||||
n.Mul(n, base36Base)
|
||||
n.Add(n, big.NewInt(int64(idx)))
|
||||
}
|
||||
b := n.Bytes()
|
||||
var out [16]byte
|
||||
// big.Int.Bytes() strips leading zeros; right-align into 16-byte array.
|
||||
copy(out[16-len(b):], b)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func formatUUID(prefix string, id pgtype.UUID) string {
|
||||
return prefix + UUIDToBase36(id.Bytes)
|
||||
}
|
||||
|
||||
func FormatSandboxID(id pgtype.UUID) string { return formatUUID(PrefixSandbox, id) }
|
||||
func FormatUserID(id pgtype.UUID) string { return formatUUID(PrefixUser, id) }
|
||||
func FormatTeamID(id pgtype.UUID) string { return formatUUID(PrefixTeam, id) }
|
||||
func FormatAPIKeyID(id pgtype.UUID) string { return formatUUID(PrefixAPIKey, id) }
|
||||
func FormatHostID(id pgtype.UUID) string { return formatUUID(PrefixHost, id) }
|
||||
func FormatHostTokenID(id pgtype.UUID) string { return formatUUID(PrefixHostToken, id) }
|
||||
func FormatRefreshTokenID(id pgtype.UUID) string { return formatUUID(PrefixRefreshToken, id) }
|
||||
func FormatAuditLogID(id pgtype.UUID) string { return formatUUID(PrefixAuditLog, id) }
|
||||
func FormatBuildID(id pgtype.UUID) string { return formatUUID(PrefixBuild, id) }
|
||||
func FormatChannelID(id pgtype.UUID) string { return formatUUID(PrefixChannel, id) }
|
||||
|
||||
// --- Parsing (prefixed string from API/RPC input → pgtype.UUID) ---
|
||||
|
||||
func parseUUID(prefix, s string) (pgtype.UUID, error) {
|
||||
if !strings.HasPrefix(s, prefix) {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid ID: expected %q prefix, got %q", prefix, s)
|
||||
}
|
||||
b, err := base36ToUUID(strings.TrimPrefix(s, prefix))
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid ID %q: %w", s, err)
|
||||
}
|
||||
return pgtype.UUID{Bytes: b, Valid: true}, nil
|
||||
}
|
||||
|
||||
func ParseSandboxID(s string) (pgtype.UUID, error) { return parseUUID(PrefixSandbox, s) }
|
||||
func ParseUserID(s string) (pgtype.UUID, error) { return parseUUID(PrefixUser, s) }
|
||||
func ParseTeamID(s string) (pgtype.UUID, error) { return parseUUID(PrefixTeam, s) }
|
||||
func ParseAPIKeyID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAPIKey, s) }
|
||||
func ParseHostID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHost, s) }
|
||||
func ParseHostTokenID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHostToken, s) }
|
||||
func ParseAuditLogID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAuditLog, s) }
|
||||
func ParseBuildID(s string) (pgtype.UUID, error) { return parseUUID(PrefixBuild, s) }
|
||||
func ParseChannelID(s string) (pgtype.UUID, error) { return parseUUID(PrefixChannel, s) }
|
||||
|
||||
// --- Well-known IDs ---
|
||||
|
||||
// PlatformTeamID is the all-zeros UUID reserved for platform-owned resources
|
||||
// (e.g. base templates, shared infrastructure).
|
||||
var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
|
||||
|
||||
// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal"
|
||||
// template. When both team_id and template_id are zero, the host agent uses
|
||||
// the minimal rootfs at WRENN_DIR/images/minimal/.
|
||||
var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
|
||||
|
||||
// UUIDString converts a pgtype.UUID to a standard hyphenated UUID string
|
||||
// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format.
|
||||
func UUIDString(id pgtype.UUID) string {
|
||||
return uuid.UUID(id.Bytes).String()
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func hex8() string {
|
||||
b := make([]byte, 4)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func hexToken(nBytes int) string {
|
||||
b := make([]byte, nBytes)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
@ -1,118 +0,0 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
func TestBase36RoundTrip(t *testing.T) {
|
||||
for i := 0; i < 1000; i++ {
|
||||
orig := uuid.New()
|
||||
encoded := UUIDToBase36(orig)
|
||||
|
||||
if len(encoded) != base36IDLen {
|
||||
t.Fatalf("expected %d chars, got %d: %s", base36IDLen, len(encoded), encoded)
|
||||
}
|
||||
|
||||
decoded, err := base36ToUUID(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
if decoded != orig {
|
||||
t.Fatalf("round-trip failed: %v → %s → %v", orig, encoded, decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase36ZeroUUID(t *testing.T) {
|
||||
var zero [16]byte
|
||||
encoded := UUIDToBase36(zero)
|
||||
if encoded != "0000000000000000000000000" {
|
||||
t.Fatalf("zero UUID should encode to all zeros, got %s", encoded)
|
||||
}
|
||||
decoded, err := base36ToUUID(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
if decoded != zero {
|
||||
t.Fatalf("round-trip failed for zero UUID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatParseRoundTrip(t *testing.T) {
|
||||
id := NewSandboxID()
|
||||
formatted := FormatSandboxID(id)
|
||||
|
||||
if formatted[:3] != "cl-" {
|
||||
t.Fatalf("expected cl- prefix, got %s", formatted)
|
||||
}
|
||||
if len(formatted) != 3+base36IDLen {
|
||||
t.Fatalf("expected %d chars total, got %d: %s", 3+base36IDLen, len(formatted), formatted)
|
||||
}
|
||||
|
||||
parsed, err := ParseSandboxID(formatted)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
if parsed != id {
|
||||
t.Fatalf("round-trip failed: %v → %s → %v", id, formatted, parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase36InvalidInput(t *testing.T) {
|
||||
// Wrong length.
|
||||
if _, err := base36ToUUID("abc"); err == nil {
|
||||
t.Fatal("expected error for short input")
|
||||
}
|
||||
// Invalid character.
|
||||
if _, err := base36ToUUID("000000000000000000000000!"); err == nil {
|
||||
t.Fatal("expected error for invalid character")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlatformTeamIDFormats(t *testing.T) {
|
||||
formatted := FormatTeamID(PlatformTeamID)
|
||||
parsed, err := ParseTeamID(formatted)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
if parsed != PlatformTeamID {
|
||||
t.Fatalf("platform team ID round-trip failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxUUID(t *testing.T) {
|
||||
max := [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
|
||||
encoded := UUIDToBase36(max)
|
||||
if len(encoded) != base36IDLen {
|
||||
t.Fatalf("max UUID encoding wrong length: %d", len(encoded))
|
||||
}
|
||||
decoded, err := base36ToUUID(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
if decoded != max {
|
||||
t.Fatalf("round-trip failed for max UUID")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFormatSandboxID(b *testing.B) {
|
||||
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
FormatSandboxID(id)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseSandboxID(b *testing.B) {
|
||||
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
|
||||
s := FormatSandboxID(id)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ParseSandboxID(s)
|
||||
}
|
||||
}
|
||||
@ -1,11 +1,15 @@
|
||||
package layout
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// IsMinimal reports whether the given team and template IDs represent the
|
||||
@ -47,6 +51,75 @@ func KernelPath(wrennDir string) string {
|
||||
return filepath.Join(wrennDir, "kernels", "vmlinux")
|
||||
}
|
||||
|
||||
// KernelPathVersioned returns the path to a specific kernel version.
|
||||
func KernelPathVersioned(wrennDir, version string) string {
|
||||
return filepath.Join(wrennDir, "kernels", "vmlinux-"+version)
|
||||
}
|
||||
|
||||
// LatestKernel scans the kernels directory for files matching vmlinux-{semver}
|
||||
// and returns the path and version of the latest one (by semver sort).
|
||||
func LatestKernel(wrennDir string) (path, version string, err error) {
|
||||
dir := filepath.Join(wrennDir, "kernels")
|
||||
return latestVersionedFile(dir, "vmlinux-")
|
||||
}
|
||||
|
||||
// latestVersionedFile scans dir for files with the given prefix, extracts the
|
||||
// version suffix, sorts by semver, and returns the path and version of the latest.
|
||||
func latestVersionedFile(dir, prefix string) (path, version string, err error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
var versions []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if v, ok := strings.CutPrefix(name, prefix); ok && v != "" {
|
||||
versions = append(versions, v)
|
||||
}
|
||||
}
|
||||
|
||||
if len(versions) == 0 {
|
||||
return "", "", fmt.Errorf("no %s* files found in %s", prefix, dir)
|
||||
}
|
||||
|
||||
sort.Slice(versions, func(i, j int) bool {
|
||||
return compareSemver(versions[i], versions[j]) < 0
|
||||
})
|
||||
|
||||
latest := versions[len(versions)-1]
|
||||
return filepath.Join(dir, prefix+latest), latest, nil
|
||||
}
|
||||
|
||||
// compareSemver compares two dotted-numeric version strings.
// Returns -1 if a < b, 0 if equal, 1 if a > b. Missing components are
// treated as zero, so "2" and "2.0" compare equal; non-numeric components
// also fall back to zero.
func compareSemver(a, b string) int {
	as := strings.Split(a, ".")
	bs := strings.Split(b, ".")

	// component returns the i-th numeric component, or 0 when absent/unparsable.
	component := func(parts []string, i int) int {
		var v int
		if i < len(parts) {
			_, _ = fmt.Sscanf(parts[i], "%d", &v)
		}
		return v
	}

	n := max(len(as), len(bs))
	for i := 0; i < n; i++ {
		switch x, y := component(as, i), component(bs, i); {
		case x < y:
			return -1
		case x > y:
			return 1
		}
	}
	return 0
}
|
||||
|
||||
// ImagesRoot returns the root images directory.
|
||||
func ImagesRoot(wrennDir string) string {
|
||||
return filepath.Join(wrennDir, "images")
|
||||
|
||||
@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
func TestIsMinimal(t *testing.T) {
|
||||
|
||||
@ -1,125 +0,0 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// HostClientPool maintains a cache of Connect RPC clients keyed by host ID.
// Clients are created lazily on first access and evicted when a host is removed
// or goes unreachable. The pool is safe for concurrent use.
type HostClientPool struct {
	mu         sync.RWMutex                                       // guards clients
	clients    map[string]hostagentv1connect.HostAgentServiceClient // host ID → cached client
	httpClient *http.Client // shared by every client created from this pool
	scheme     string       // "http://" or "https://"
}
|
||||
|
||||
// NewHostClientPool creates a pool that connects to agents over plain HTTP.
|
||||
// Use NewHostClientPoolTLS when mTLS is required.
|
||||
func NewHostClientPool() *HostClientPool {
|
||||
return &HostClientPool{
|
||||
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
|
||||
httpClient: &http.Client{Timeout: 10 * time.Minute},
|
||||
scheme: "http://",
|
||||
}
|
||||
}
|
||||
|
||||
// NewHostClientPoolTLS creates a pool that connects to agents over mTLS.
|
||||
// tlsCfg should already carry the CP client cert and CA trust anchor
|
||||
// (use auth.CPClientTLSConfig to construct it).
|
||||
func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: tlsCfg,
|
||||
}
|
||||
return &HostClientPool{
|
||||
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
Transport: transport,
|
||||
},
|
||||
scheme: "https://",
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a Connect RPC client for the given host, creating one if necessary.
// address is the host agent address (ip:port or full URL). The scheme is added if absent.
func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgentServiceClient {
	// Fast path: read lock only, since most calls hit an already-cached client.
	p.mu.RLock()
	c, ok := p.clients[hostID]
	p.mu.RUnlock()
	if ok {
		return c
	}

	// Slow path: take the write lock to create and cache a new client.
	p.mu.Lock()
	defer p.mu.Unlock()
	// Double-check after acquiring write lock.
	// Another goroutine may have inserted the client between RUnlock and Lock.
	if c, ok = p.clients[hostID]; ok {
		return c
	}
	c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address))
	p.clients[hostID] = c
	return c
}
|
||||
|
||||
// GetForHost is a convenience wrapper that extracts the address from a db.Host
|
||||
// and returns an error if the host has no address recorded yet.
|
||||
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
|
||||
if h.Address == "" {
|
||||
return nil, fmt.Errorf("host %s has no address", id.FormatHostID(h.ID))
|
||||
}
|
||||
return p.Get(id.FormatHostID(h.ID), h.Address), nil
|
||||
}
|
||||
|
||||
// Evict removes the cached client for the given host, forcing a new client to be
|
||||
// created on the next call to Get. Call this when a host's address changes or when
|
||||
// a host is deleted.
|
||||
func (p *HostClientPool) Evict(hostID string) {
|
||||
p.mu.Lock()
|
||||
delete(p.clients, hostID)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// ensureScheme prepends the pool's configured scheme if the address has none.
|
||||
func (p *HostClientPool) ensureScheme(addr string) string {
|
||||
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
|
||||
return addr
|
||||
}
|
||||
return p.scheme + addr
|
||||
}
|
||||
|
||||
// Transport returns the http.RoundTripper used by this pool. Use this when you
|
||||
// need to make raw HTTP requests to agent addresses with the same TLS settings
|
||||
// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy).
|
||||
func (p *HostClientPool) Transport() http.RoundTripper {
|
||||
if p.httpClient.Transport != nil {
|
||||
return p.httpClient.Transport
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
// ResolveAddr prepends the pool's configured scheme to addr if it has none.
// Use this when constructing URLs that must use the same transport as the pool
// (e.g., the sandbox proxy handler). Calling Get/GetForHost internally does
// the same thing, but ResolveAddr exposes it for callers that only need the URL.
func (p *HostClientPool) ResolveAddr(addr string) string {
	// Thin public wrapper over the unexported helper used by Get.
	return p.ensureScheme(addr)
}
|
||||
|
||||
// EnsureScheme adds "http://" if the address has no scheme.
//
// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting.
func EnsureScheme(addr string) string {
	switch {
	case strings.HasPrefix(addr, "http://"), strings.HasPrefix(addr, "https://"):
		return addr
	default:
		return "http://" + addr
	}
}
|
||||
@ -1 +0,0 @@
|
||||
package lifecycle
|
||||
@ -30,4 +30,5 @@ type Sandbox struct {
|
||||
RootfsPath string
|
||||
CreatedAt time.Time
|
||||
LastActiveAt time.Time
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
@ -24,7 +24,7 @@ func (a *SlotAllocator) Allocate() (int, error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
for i := 1; i <= 65534; i++ {
|
||||
for i := 1; i <= 32767; i++ {
|
||||
if !a.inUse[i] {
|
||||
a.inUse[i] = true
|
||||
return i, nil
|
||||
|
||||
@ -131,26 +131,31 @@ type Slot struct {
|
||||
}
|
||||
|
||||
// NewSlot computes the addressing for the given slot index (1-based).
|
||||
// Index must be in [1, 32767] so that veth offset (index*2) fits in 16 bits.
|
||||
func NewSlot(index int) *Slot {
|
||||
if index < 1 || index > 32767 {
|
||||
panic(fmt.Sprintf("slot index %d out of range [1, 32767]", index))
|
||||
}
|
||||
|
||||
hostBaseIP := net.ParseIP(hostBase).To4()
|
||||
vrtBaseIP := net.ParseIP(vrtBase).To4()
|
||||
|
||||
hostIP := make(net.IP, 4)
|
||||
copy(hostIP, hostBaseIP)
|
||||
hostIP[2] += byte(index >> 8)
|
||||
hostIP[3] += byte(index & 0xFF)
|
||||
hostIP[2] = hostBaseIP[2] + byte(index>>8)
|
||||
hostIP[3] = hostBaseIP[3] + byte(index&0xFF)
|
||||
|
||||
vethOffset := index * vrtAddressesPerSlot
|
||||
vethIP := make(net.IP, 4)
|
||||
copy(vethIP, vrtBaseIP)
|
||||
vethIP[2] += byte(vethOffset >> 8)
|
||||
vethIP[3] += byte(vethOffset & 0xFF)
|
||||
vethIP[2] = vrtBaseIP[2] + byte(vethOffset>>8)
|
||||
vethIP[3] = vrtBaseIP[3] + byte(vethOffset&0xFF)
|
||||
|
||||
vpeerOffset := vethOffset + 1
|
||||
vpeerIP := make(net.IP, 4)
|
||||
copy(vpeerIP, vrtBaseIP)
|
||||
vpeerIP[2] += byte(vpeerOffset >> 8)
|
||||
vpeerIP[3] += byte(vpeerOffset & 0xFF)
|
||||
vpeerIP[2] = vrtBaseIP[2] + byte(vpeerOffset>>8)
|
||||
vpeerIP[3] = vrtBaseIP[3] + byte(vpeerOffset&0xFF)
|
||||
|
||||
return &Slot{
|
||||
Index: index,
|
||||
|
||||
@ -7,10 +7,11 @@ import (
|
||||
)
|
||||
|
||||
// ExecContext holds mutable state that persists across recipe steps.
// It is initialized empty and updated by ENV, WORKDIR, and USER steps.
type ExecContext struct {
	WorkDir string            // directory to cd into before each command; "" = unset
	EnvVars map[string]string // variables prepended to each command
	User    string // Current unix user for command execution.
}
|
||||
|
||||
// This regex matches:
|
||||
@ -25,7 +26,20 @@ var envRegex = regexp.MustCompile(`\$\$|\$\{([a-zA-Z0-9_]*)\}|\$([a-zA-Z0-9_]+)`
|
||||
// If WORKDIR and/or ENV are set, they are prepended as a shell preamble:
|
||||
//
|
||||
// cd '/the/dir' && KEY='val' /bin/sh -c 'original command'
|
||||
//
|
||||
// If USER is set to a non-root user, the entire command is wrapped with su:
|
||||
//
|
||||
// su <user> -s /bin/sh -c '<preamble + command>'
|
||||
func (c *ExecContext) WrappedCommand(cmd string) string {
|
||||
inner := c.innerCommand(cmd)
|
||||
if c.User != "" && c.User != "root" {
|
||||
return "su " + shellescape(c.User) + " -s /bin/sh -c " + shellescape(inner)
|
||||
}
|
||||
return inner
|
||||
}
|
||||
|
||||
// innerCommand builds the command with workdir/env preamble but without user wrapping.
|
||||
func (c *ExecContext) innerCommand(cmd string) string {
|
||||
prefix := c.shellPrefix()
|
||||
if prefix == "" {
|
||||
return cmd
|
||||
@ -42,7 +56,11 @@ func (c *ExecContext) WrappedCommand(cmd string) string {
|
||||
// simultaneously before a healthcheck is evaluated.
|
||||
func (c *ExecContext) StartCommand(cmd string) string {
|
||||
prefix := c.shellPrefix()
|
||||
return prefix + "nohup /bin/sh -c " + shellescape(cmd) + " >/dev/null 2>&1 &"
|
||||
inner := prefix + "nohup /bin/sh -c " + shellescape(cmd) + " >/dev/null 2>&1 &"
|
||||
if c.User != "" && c.User != "root" {
|
||||
return "su " + shellescape(c.User) + " -s /bin/sh -c " + shellescape(inner)
|
||||
}
|
||||
return inner
|
||||
}
|
||||
|
||||
// shellPrefix builds the "cd ... && KEY=val " preamble for a shell command.
|
||||
@ -97,8 +115,11 @@ func expandEnv(s string, vars map[string]string) string {
|
||||
})
|
||||
}
|
||||
|
||||
// Shellescape wraps s in single quotes, escaping any embedded single quotes.
// This is POSIX-safe for paths, env values, and shell commands.
func Shellescape(s string) string {
	escaped := strings.ReplaceAll(s, "'", `'\''`)
	return "'" + escaped + "'"
}

// shellescape is the package-internal alias for Shellescape.
func shellescape(s string) string { return Shellescape(s) }
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -16,6 +17,10 @@ import (
|
||||
// explicit --timeout flag.
|
||||
const DefaultStepTimeout = 30 * time.Second
|
||||
|
||||
// BuildFilesDir is the directory inside the sandbox where uploaded build
|
||||
// archives are extracted. COPY instructions reference paths relative to this.
|
||||
const BuildFilesDir = "/tmp/build-files"
|
||||
|
||||
// BuildLogEntry is the per-step record stored in template_builds.logs (JSONB).
|
||||
type BuildLogEntry struct {
|
||||
Step int `json:"step"`
|
||||
@ -32,13 +37,18 @@ type BuildLogEntry struct {
|
||||
// the method on the hostagent Connect RPC client.
|
||||
type ExecFunc func(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
|
||||
|
||||
// ProgressFunc is called after each step with the current step counter and
|
||||
// accumulated log entries. Used for per-step DB progress updates.
|
||||
type ProgressFunc func(step int, entries []BuildLogEntry)
|
||||
|
||||
// Execute runs steps sequentially against sandboxID using execFn.
|
||||
//
|
||||
// - phase labels the log entries (e.g., "pre-build", "recipe", "post-build").
|
||||
// - startStep is the 1-based offset so entries are globally numbered across phases.
|
||||
// - defaultTimeout applies to RUN steps with no per-step --timeout; 0 → 10 minutes.
|
||||
// - bctx is mutated in place as ENV/WORKDIR steps execute, and carries forward
|
||||
// - bctx is mutated in place as ENV/WORKDIR/USER steps execute, and carries forward
|
||||
// into subsequent phases when the caller passes the same pointer.
|
||||
// - onProgress is called after each step for live progress updates (may be nil).
|
||||
//
|
||||
// Returns all log entries appended during this call, the next step counter
|
||||
// value, and whether all steps succeeded. On false the last entry contains
|
||||
@ -53,6 +63,7 @@ func Execute(
|
||||
defaultTimeout time.Duration,
|
||||
bctx *ExecContext,
|
||||
execFn ExecFunc,
|
||||
onProgress ProgressFunc,
|
||||
) (entries []BuildLogEntry, nextStep int, ok bool) {
|
||||
if defaultTimeout <= 0 {
|
||||
defaultTimeout = 10 * time.Minute
|
||||
@ -72,19 +83,30 @@ func Execute(
|
||||
entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true})
|
||||
|
||||
case KindWORKDIR:
|
||||
// Create the directory if it doesn't exist.
|
||||
mkdirEntry := execRawShell(ctx, st.Raw, sandboxID, phase, step, 10*time.Second, execFn,
|
||||
"mkdir -p "+shellescape(st.Path))
|
||||
if !mkdirEntry.Ok {
|
||||
entries = append(entries, mkdirEntry)
|
||||
return entries, step, false
|
||||
}
|
||||
bctx.WorkDir = st.Path
|
||||
entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true})
|
||||
mkdirEntry.Ok = true
|
||||
entries = append(entries, mkdirEntry)
|
||||
|
||||
case KindUSER, KindCOPY:
|
||||
verb := strings.ToUpper(strings.Fields(st.Raw)[0])
|
||||
entries = append(entries, BuildLogEntry{
|
||||
Step: step,
|
||||
Phase: phase,
|
||||
Cmd: st.Raw,
|
||||
Stderr: verb + " is not yet supported",
|
||||
Ok: false,
|
||||
})
|
||||
return entries, step, false
|
||||
case KindUSER:
|
||||
entry, succeeded := execUser(ctx, st, sandboxID, phase, step, bctx, execFn)
|
||||
entries = append(entries, entry)
|
||||
if !succeeded {
|
||||
return entries, step, false
|
||||
}
|
||||
|
||||
case KindCOPY:
|
||||
entry, succeeded := execCopy(ctx, st, sandboxID, phase, step, bctx, execFn)
|
||||
entries = append(entries, entry)
|
||||
if !succeeded {
|
||||
return entries, step, false
|
||||
}
|
||||
|
||||
case KindSTART:
|
||||
entry, succeeded := execStart(ctx, st, sandboxID, phase, step, bctx, execFn)
|
||||
@ -104,6 +126,10 @@ func Execute(
|
||||
return entries, step, false
|
||||
}
|
||||
}
|
||||
|
||||
if onProgress != nil {
|
||||
onProgress(step, entries)
|
||||
}
|
||||
}
|
||||
return entries, step, true
|
||||
}
|
||||
@ -145,6 +171,123 @@ func execRun(
|
||||
return entry, entry.Ok
|
||||
}
|
||||
|
||||
// execUser creates a unix user (if not exists), grants passwordless sudo,
// and updates bctx.User for subsequent steps.
//
// On failure the returned BuildLogEntry carries the error output and the
// second return value is false; bctx is left untouched in that case.
func execUser(
	ctx context.Context,
	st Step,
	sandboxID, phase string,
	step int,
	bctx *ExecContext,
	execFn ExecFunc,
) (BuildLogEntry, bool) {
	// The username was validated and stored in Key by parseUSER.
	username := st.Key
	// Create user if not exists, with home directory and bash shell.
	// Grant passwordless sudo access (E2B convention).
	// Uses printf %s to avoid shell injection in the sudoers line.
	script := fmt.Sprintf(
		"id %s >/dev/null 2>&1 || (adduser --disabled-password --gecos '' --shell /bin/bash %s && printf '%%s ALL=(ALL) NOPASSWD:ALL\\n' %s >> /etc/sudoers)",
		shellescape(username), shellescape(username), shellescape(username),
	)

	// Run as root, outside the ExecContext wrapping.
	entry := execRawShell(ctx, st.Raw, sandboxID, phase, step, 30*time.Second, execFn, script)
	if entry.Ok {
		bctx.User = username
		// Update HOME so ~ expands correctly in subsequent RUN/WORKDIR steps.
		if bctx.EnvVars == nil {
			bctx.EnvVars = make(map[string]string)
		}
		if username == "root" {
			bctx.EnvVars["HOME"] = "/root"
		} else {
			// Matches the default home created by adduser above.
			bctx.EnvVars["HOME"] = "/home/" + username
		}
	}
	return entry, entry.Ok
}
|
||||
|
||||
// execCopy copies a file or directory from the build archive (extracted at
// BuildFilesDir) to the destination path inside the sandbox. Ownership is
// set to the current user from bctx.
//
// Returns the log entry for the step and whether it succeeded.
func execCopy(
	ctx context.Context,
	st Step,
	sandboxID, phase string,
	step int,
	bctx *ExecContext,
	execFn ExecFunc,
) (BuildLogEntry, bool) {
	// Validate all source paths: must be relative and not escape the archive directory.
	var srcPaths []string
	for _, s := range st.Srcs {
		// path.Clean collapses "a/../../x" so a single prefix check suffices.
		cleaned := path.Clean(s)
		if strings.HasPrefix(cleaned, "..") || strings.HasPrefix(cleaned, "/") {
			return BuildLogEntry{
				Step:   step,
				Phase:  phase,
				Cmd:    st.Raw,
				Stderr: fmt.Sprintf("COPY source must be a relative path within the archive: %q", s),
			}, false
		}
		// Quote each source so spaces/special chars survive the shell.
		srcPaths = append(srcPaths, shellescape(BuildFilesDir+"/"+cleaned))
	}

	dst := st.Dst
	// Resolve relative destination against the current WORKDIR.
	if dst != "" && dst[0] != '/' && bctx.WorkDir != "" {
		dst = bctx.WorkDir + "/" + dst
	}
	// Files are owned by the current USER (default root) so later steps run
	// as that user can read/write them.
	owner := "root"
	if bctx.User != "" {
		owner = bctx.User
	}
	script := fmt.Sprintf(
		"cp -r %s %s && chown -R %s:%s %s",
		strings.Join(srcPaths, " "), shellescape(dst), shellescape(owner), shellescape(owner), shellescape(dst),
	)

	// Run as root, outside the ExecContext wrapping.
	entry := execRawShell(ctx, st.Raw, sandboxID, phase, step, 60*time.Second, execFn, script)
	return entry, entry.Ok
}
|
||||
|
||||
// execRawShell runs a shell command directly (as root) without ExecContext
// wrapping. Used for internal operations like user creation and file copy.
//
// The timeout is applied twice: as a context deadline on the RPC and as
// TimeoutSec in the request so the agent can enforce it guest-side.
// Transport-level errors are reported in the entry's Stderr with Ok=false;
// otherwise Ok reflects the command's exit code.
func execRawShell(
	ctx context.Context,
	raw, sandboxID, phase string,
	step int,
	timeout time.Duration,
	execFn ExecFunc,
	shellCmd string,
) BuildLogEntry {
	execCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	start := time.Now()
	resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{
		SandboxId:  sandboxID,
		Cmd:        "/bin/sh",
		Args:       []string{"-c", shellCmd},
		TimeoutSec: int32(timeout.Seconds()),
	}))

	// The entry logs the original instruction text (raw), not the internal
	// shell command that implemented it.
	entry := BuildLogEntry{
		Step:    step,
		Phase:   phase,
		Cmd:     raw,
		Elapsed: time.Since(start).Milliseconds(),
	}
	if err != nil {
		entry.Stderr = fmt.Sprintf("exec error: %v", err)
		return entry
	}
	entry.Stdout = string(resp.Msg.Stdout)
	entry.Stderr = string(resp.Msg.Stderr)
	entry.Exit = resp.Msg.ExitCode
	entry.Ok = resp.Msg.ExitCode == 0
	return entry
}
|
||||
|
||||
func execStart(
|
||||
ctx context.Context,
|
||||
st Step,
|
||||
|
||||
@ -24,9 +24,11 @@ type Step struct {
|
||||
Raw string // original string, preserved for logging
|
||||
Shell string // KindRUN, KindSTART: the shell command text
|
||||
Timeout time.Duration // KindRUN: 0 means use caller's default
|
||||
Key string // KindENV: variable name
|
||||
Key string // KindENV: variable name; KindUSER: username
|
||||
Value string // KindENV: variable value
|
||||
Path string // KindWORKDIR: directory path
|
||||
Srcs []string // KindCOPY: source paths (relative to build archive)
|
||||
Dst string // KindCOPY: destination path inside sandbox
|
||||
}
|
||||
|
||||
// ParseStep parses a single recipe instruction string into a Step.
|
||||
@ -61,9 +63,9 @@ func ParseStep(s string) (Step, error) {
|
||||
case "WORKDIR":
|
||||
return parseWORKDIR(s, rest)
|
||||
case "USER":
|
||||
return Step{Kind: KindUSER, Raw: s}, nil
|
||||
return parseUSER(s, rest)
|
||||
case "COPY":
|
||||
return Step{Kind: KindCOPY, Raw: s}, nil
|
||||
return parseCOPY(s, rest)
|
||||
default:
|
||||
return Step{}, fmt.Errorf("unknown instruction %q (expected RUN, START, ENV, WORKDIR, USER, or COPY)", keyword)
|
||||
}
|
||||
@ -127,3 +129,33 @@ func parseWORKDIR(raw, path string) (Step, error) {
|
||||
}
|
||||
return Step{Kind: KindWORKDIR, Raw: raw, Path: path}, nil
|
||||
}
|
||||
|
||||
func parseUSER(raw, username string) (Step, error) {
|
||||
if username == "" {
|
||||
return Step{}, fmt.Errorf("USER requires a username: %q", raw)
|
||||
}
|
||||
// Validate: alphanumeric, hyphens, underscores only; must start with a letter or underscore.
|
||||
for i, c := range username {
|
||||
if i == 0 && !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') {
|
||||
return Step{}, fmt.Errorf("USER username must start with a letter or underscore: %q", raw)
|
||||
}
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-') {
|
||||
return Step{}, fmt.Errorf("USER username contains invalid character %q: %q", string(c), raw)
|
||||
}
|
||||
}
|
||||
return Step{Kind: KindUSER, Raw: raw, Key: username}, nil
|
||||
}
|
||||
|
||||
func parseCOPY(raw, rest string) (Step, error) {
|
||||
if rest == "" {
|
||||
return Step{}, fmt.Errorf("COPY requires <src>... <dst>: %q", raw)
|
||||
}
|
||||
parts := strings.Fields(rest)
|
||||
if len(parts) < 2 {
|
||||
return Step{}, fmt.Errorf("COPY requires <src>... <dst>: %q", raw)
|
||||
}
|
||||
// Last argument is the destination, everything before is sources.
|
||||
dst := parts[len(parts)-1]
|
||||
srcs := parts[:len(parts)-1]
|
||||
return Step{Kind: KindCOPY, Raw: raw, Srcs: srcs, Dst: dst}, nil
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package recipe
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -111,16 +112,42 @@ func TestParseStep(t *testing.T) {
|
||||
input: "WORKDIR",
|
||||
wantErr: true,
|
||||
},
|
||||
// USER and COPY stubs
|
||||
// USER
|
||||
{
|
||||
name: "USER stub",
|
||||
name: "USER basic",
|
||||
input: "USER www-data",
|
||||
want: Step{Kind: KindUSER, Raw: "USER www-data"},
|
||||
want: Step{Kind: KindUSER, Raw: "USER www-data", Key: "www-data"},
|
||||
},
|
||||
{
|
||||
name: "COPY stub",
|
||||
name: "USER empty",
|
||||
input: "USER",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "USER invalid chars",
|
||||
input: "USER bad user",
|
||||
wantErr: true,
|
||||
},
|
||||
// COPY
|
||||
{
|
||||
name: "COPY basic",
|
||||
input: "COPY config.yaml /etc/app/config.yaml",
|
||||
want: Step{Kind: KindCOPY, Raw: "COPY config.yaml /etc/app/config.yaml"},
|
||||
want: Step{Kind: KindCOPY, Raw: "COPY config.yaml /etc/app/config.yaml", Srcs: []string{"config.yaml"}, Dst: "/etc/app/config.yaml"},
|
||||
},
|
||||
{
|
||||
name: "COPY multiple sources",
|
||||
input: "COPY a.txt b.txt /dest/",
|
||||
want: Step{Kind: KindCOPY, Raw: "COPY a.txt b.txt /dest/", Srcs: []string{"a.txt", "b.txt"}, Dst: "/dest/"},
|
||||
},
|
||||
{
|
||||
name: "COPY missing dst",
|
||||
input: "COPY config.yaml",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "COPY empty",
|
||||
input: "COPY",
|
||||
wantErr: true,
|
||||
},
|
||||
// Unknown keyword
|
||||
{
|
||||
@ -148,7 +175,7 @@ func TestParseStep(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("ParseStep(%q) unexpected error: %v", tc.input, err)
|
||||
}
|
||||
if got != tc.want {
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Errorf("ParseStep(%q)\n got %+v\n want %+v", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
|
||||
30
internal/sandbox/fcversion.go
Normal file
30
internal/sandbox/fcversion.go
Normal file
@ -0,0 +1,30 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectFirecrackerVersion runs the firecracker binary with --version and
// parses the semver from the output (e.g. "Firecracker v1.14.1" → "1.14.1").
func DetectFirecrackerVersion(binaryPath string) (string, error) {
	out, err := exec.Command(binaryPath, "--version").Output()
	if err != nil {
		return "", fmt.Errorf("run %s --version: %w", binaryPath, err)
	}

	// Output is typically "Firecracker v1.14.1\n" or similar.
	line := strings.TrimSpace(string(out))
	for _, field := range strings.Fields(line) {
		v := strings.TrimPrefix(field, "v")
		// A candidate either carried a "v" prefix or contains a dot.
		if v == field && !strings.Contains(field, ".") {
			continue
		}
		// Require at least one dot so bare tokens like "v2" are rejected.
		if strings.Count(v, ".") >= 1 {
			return v, nil
		}
	}

	return "", fmt.Errorf("could not parse version from firecracker output: %q", line)
}
|
||||
@ -6,9 +6,11 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// DefaultDiskSizeMB is the standard disk size for base images. Images smaller
|
||||
@ -66,6 +68,73 @@ func EnsureImageSizes(wrennDir string, targetMB int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseSizeToMB parses a human-readable size string into megabytes.
// Supported suffixes: G, Gi (gibibytes), M, Mi (mebibytes).
// Examples: "5G" → 5120, "2Gi" → 2048, "1000M" → 1000, "512Mi" → 512.
func ParseSizeToMB(s string) (int, error) {
	s = strings.TrimSpace(s)
	if s == "" {
		return 0, fmt.Errorf("empty size string")
	}

	// Split into the leading numeric run (digits and dots) and the suffix.
	end := 0
	for end < len(s) {
		c := s[end]
		if c != '.' && (c < '0' || c > '9') {
			break
		}
		end++
	}
	if end == 0 {
		return 0, fmt.Errorf("invalid size %q: no numeric value", s)
	}

	value, err := strconv.ParseFloat(s[:end], 64)
	if err != nil {
		return 0, fmt.Errorf("invalid size %q: %w", s, err)
	}

	suffix := strings.TrimSpace(s[end:])
	switch suffix {
	case "G", "Gi":
		return int(value * 1024), nil
	case "M", "Mi", "":
		return int(value), nil
	default:
		return 0, fmt.Errorf("invalid size %q: unknown suffix %q (use G, Gi, M, or Mi)", s, suffix)
	}
}
|
||||
|
||||
// ShrinkMinimalImage shrinks the built-in minimal rootfs back to its minimum
// size using resize2fs -M. This is the inverse of EnsureImageSizes and should
// be called during graceful shutdown so the image is stored compactly on disk.
func ShrinkMinimalImage(wrennDir string) {
	// Locate the platform-owned minimal template's rootfs image.
	minimalRootfs := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
	shrinkImage(minimalRootfs)
}
||||
|
||||
// shrinkImage shrinks a single rootfs image to its minimum size.
//
// Best-effort: any failure is logged with slog.Warn and the image is left
// as-is; no error is returned to the caller.
func shrinkImage(rootfs string) {
	// Nothing to do if the image doesn't exist.
	if _, err := os.Stat(rootfs); err != nil {
		return
	}

	slog.Info("shrinking base image", "path", rootfs)

	// resize2fs requires a clean filesystem, so run e2fsck first.
	if out, err := exec.Command("e2fsck", "-fy", rootfs).CombinedOutput(); err != nil {
		// e2fsck exit code 1 means errors were found and corrected (per
		// e2fsck(8)), which is fine; only treat exit codes > 1 as failure.
		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 {
			slog.Warn("e2fsck before shrink failed", "path", rootfs, "output", string(out), "error", err)
			return
		}
	}

	// -M shrinks the filesystem (and file) to its minimum possible size.
	if out, err := exec.Command("resize2fs", "-M", rootfs).CombinedOutput(); err != nil {
		slog.Warn("resize2fs -M failed", "path", rootfs, "output", string(out), "error", err)
		return
	}

	slog.Info("base image shrunk", "path", rootfs)
}
|
||||
|
||||
// expandImage expands a single rootfs image if it is smaller than targetBytes.
|
||||
func expandImage(rootfs string, targetBytes int64, targetMB int) error {
|
||||
info, err := os.Stat(rootfs)
|
||||
|
||||
@ -17,19 +17,28 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/devicemapper"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/models"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/network"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/snapshot"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/uffd"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/vm"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
||||
)
|
||||
|
||||
// Config holds the paths and defaults for the sandbox manager.
|
||||
type Config struct {
|
||||
WrennDir string // root directory (e.g. /var/lib/wrenn); all sub-paths derived via layout package
|
||||
EnvdTimeout time.Duration
|
||||
WrennDir string // root directory (e.g. /var/lib/wrenn); all sub-paths derived via layout package
|
||||
EnvdTimeout time.Duration
|
||||
DefaultRootfsSizeMB int // target size for template rootfs images; 0 → DefaultDiskSizeMB
|
||||
|
||||
// Resolved at startup by the host agent.
|
||||
KernelPath string // path to the latest vmlinux-x.y.z
|
||||
KernelVersion string // semver extracted from filename
|
||||
FirecrackerBin string // path to the firecracker binary
|
||||
FirecrackerVersion string // semver from firecracker --version
|
||||
AgentVersion string // host agent version (injected via ldflags)
|
||||
}
|
||||
|
||||
// Manager orchestrates sandbox lifecycle: VM, network, filesystem, envd.
|
||||
@ -84,6 +93,35 @@ type snapshotParent struct {
|
||||
// preventing the crash.
|
||||
const maxDiffGenerations = 8
|
||||
|
||||
// buildMetadata constructs the metadata map with version information.
|
||||
func (m *Manager) buildMetadata(envdVersion string) map[string]string {
|
||||
meta := map[string]string{
|
||||
"kernel_version": m.cfg.KernelVersion,
|
||||
"firecracker_version": m.cfg.FirecrackerVersion,
|
||||
"agent_version": m.cfg.AgentVersion,
|
||||
}
|
||||
if envdVersion != "" {
|
||||
meta["envd_version"] = envdVersion
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
// resolveKernelPath returns the kernel path for the given version hint.
|
||||
// If the exact version exists on disk, it is used. Otherwise, falls back to
|
||||
// the latest kernel (m.cfg.KernelPath).
|
||||
func (m *Manager) resolveKernelPath(versionHint string) string {
|
||||
if versionHint == "" {
|
||||
return m.cfg.KernelPath
|
||||
}
|
||||
exact := layout.KernelPathVersioned(m.cfg.WrennDir, versionHint)
|
||||
if _, err := os.Stat(exact); err == nil {
|
||||
return exact
|
||||
}
|
||||
slog.Warn("requested kernel version not found, using latest",
|
||||
"requested", versionHint, "latest", m.cfg.KernelVersion)
|
||||
return m.cfg.KernelPath
|
||||
}
|
||||
|
||||
// New creates a new sandbox manager.
|
||||
func New(cfg Config) *Manager {
|
||||
if cfg.EnvdTimeout == 0 {
|
||||
@ -173,7 +211,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
|
||||
vmCfg := vm.VMConfig{
|
||||
SandboxID: sandboxID,
|
||||
TemplateID: id.UUIDString(templateID),
|
||||
KernelPath: layout.KernelPath(m.cfg.WrennDir),
|
||||
KernelPath: m.cfg.KernelPath,
|
||||
RootfsPath: dmDev.DevicePath,
|
||||
VCPUs: vcpus,
|
||||
MemoryMB: memoryMB,
|
||||
@ -183,6 +221,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
|
||||
GuestIP: slot.GuestIP,
|
||||
GatewayIP: slot.TapIP,
|
||||
NetMask: slot.GuestNetMask,
|
||||
FirecrackerBin: m.cfg.FirecrackerBin,
|
||||
}
|
||||
|
||||
if _, err := m.vm.Create(ctx, vmCfg); err != nil {
|
||||
@ -209,6 +248,9 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
|
||||
return nil, fmt.Errorf("wait for envd: %w", err)
|
||||
}
|
||||
|
||||
// Fetch envd version (best-effort).
|
||||
envdVersion, _ := client.FetchVersion(ctx)
|
||||
|
||||
now := time.Now()
|
||||
sb := &sandboxState{
|
||||
Sandbox: models.Sandbox{
|
||||
@ -224,6 +266,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
|
||||
RootfsPath: dmDev.DevicePath,
|
||||
CreatedAt: now,
|
||||
LastActiveAt: now,
|
||||
Metadata: m.buildMetadata(envdVersion),
|
||||
},
|
||||
slot: slot,
|
||||
client: client,
|
||||
@ -326,6 +369,20 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
|
||||
sb.connTracker.Drain(2 * time.Second)
|
||||
slog.Debug("pause: proxy connections drained", "id", sandboxID)
|
||||
|
||||
// Step 0b: Signal envd to quiesce continuous goroutines (port scanner,
|
||||
// forwarder) and run GC before freezing vCPUs. This prevents Go runtime
|
||||
// page allocator corruption ("bad summary data") on snapshot restore.
|
||||
// Best-effort: a failure is logged but does not abort the pause.
|
||||
func() {
|
||||
prepCtx, prepCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer prepCancel()
|
||||
if err := sb.client.PrepareSnapshot(prepCtx); err != nil {
|
||||
slog.Warn("pause: pre-snapshot quiesce failed (best-effort)", "id", sandboxID, "error", err)
|
||||
} else {
|
||||
slog.Debug("pause: envd goroutines quiesced", "id", sandboxID)
|
||||
}
|
||||
}()
|
||||
|
||||
pauseStart := time.Now()
|
||||
|
||||
// Step 1: Pause the VM (freeze vCPUs).
|
||||
@ -542,7 +599,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
|
||||
|
||||
// Resume restores a paused sandbox from its snapshot using UFFD for
|
||||
// lazy memory loading. The sandbox gets a new network slot.
|
||||
func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) (*models.Sandbox, error) {
|
||||
func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int, kernelVersion string) (*models.Sandbox, error) {
|
||||
pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID)
|
||||
if _, err := os.Stat(pauseDir); err != nil {
|
||||
return nil, fmt.Errorf("no snapshot found for sandbox %s", sandboxID)
|
||||
@ -656,7 +713,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
|
||||
// Restore VM from snapshot.
|
||||
vmCfg := vm.VMConfig{
|
||||
SandboxID: sandboxID,
|
||||
KernelPath: layout.KernelPath(m.cfg.WrennDir),
|
||||
KernelPath: m.resolveKernelPath(kernelVersion),
|
||||
RootfsPath: dmDev.DevicePath,
|
||||
VCPUs: 1, // Placeholder; overridden by snapshot.
|
||||
MemoryMB: int(header.Metadata.Size / (1024 * 1024)), // Placeholder; overridden by snapshot.
|
||||
@ -666,6 +723,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
|
||||
GuestIP: slot.GuestIP,
|
||||
GatewayIP: slot.TapIP,
|
||||
NetMask: slot.GuestNetMask,
|
||||
FirecrackerBin: m.cfg.FirecrackerBin,
|
||||
}
|
||||
|
||||
resumeSnapPath := filepath.Join(pauseDir, snapshot.SnapFileName)
|
||||
@ -697,6 +755,14 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
|
||||
return nil, fmt.Errorf("wait for envd: %w", err)
|
||||
}
|
||||
|
||||
// Trigger envd to re-read MMDS so it picks up the new sandbox/template IDs.
|
||||
if err := client.PostInit(waitCtx); err != nil {
|
||||
slog.Warn("post-init failed after resume, metadata files may be stale", "sandbox", sandboxID, "error", err)
|
||||
}
|
||||
|
||||
// Fetch envd version (best-effort).
|
||||
envdVersion, _ := client.FetchVersion(ctx)
|
||||
|
||||
now := time.Now()
|
||||
sb := &sandboxState{
|
||||
Sandbox: models.Sandbox{
|
||||
@ -710,6 +776,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
|
||||
RootfsPath: dmDev.DevicePath,
|
||||
CreatedAt: now,
|
||||
LastActiveAt: now,
|
||||
Metadata: m.buildMetadata(envdVersion),
|
||||
},
|
||||
slot: slot,
|
||||
client: client,
|
||||
@ -880,6 +947,18 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, t
|
||||
return 0, fmt.Errorf("sandbox %s not found", sandboxID)
|
||||
}
|
||||
|
||||
// Flush guest page cache to disk before stopping the VM. Without this,
|
||||
// files written by the build (e.g. pip-installed packages) may exist in the
|
||||
// guest's page cache but not yet on the dm block device — flatten would then
|
||||
// capture 0-byte files.
|
||||
func() {
|
||||
syncCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
if _, err := sb.client.Exec(syncCtx, "/bin/sync"); err != nil {
|
||||
slog.Warn("flatten: guest sync failed (non-fatal)", "id", sb.ID, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Stop the VM but keep the dm device alive for flattening.
|
||||
m.stopSampler(sb)
|
||||
if err := m.vm.Destroy(ctx, sb.ID); err != nil {
|
||||
@ -919,8 +998,8 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, t
|
||||
// Clean up dm device and loop device now that flatten is complete.
|
||||
m.cleanupDM(sb)
|
||||
|
||||
// Shrink the flattened image to its minimum size so stored templates are
|
||||
// compact. EnsureImageSizes will re-expand them on the next agent startup.
|
||||
// Shrink the flattened image to its minimum size, then re-expand to the
|
||||
// configured default rootfs size so sandboxes see the full disk from boot.
|
||||
if out, err := exec.Command("e2fsck", "-fy", outputPath).CombinedOutput(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 {
|
||||
slog.Warn("e2fsck before shrink failed (non-fatal)", "output", string(out), "error", err)
|
||||
@ -930,6 +1009,15 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, t
|
||||
slog.Warn("resize2fs -M failed (non-fatal)", "output", string(out), "error", err)
|
||||
}
|
||||
|
||||
// Re-expand to default rootfs size.
|
||||
targetMB := m.cfg.DefaultRootfsSizeMB
|
||||
if targetMB <= 0 {
|
||||
targetMB = DefaultDiskSizeMB
|
||||
}
|
||||
if err := expandImage(outputPath, int64(targetMB)*1024*1024, targetMB); err != nil {
|
||||
slog.Warn("failed to expand template to default size (non-fatal)", "error", err)
|
||||
}
|
||||
|
||||
sizeBytes, err := snapshot.DirSize(flattenDstDir, "")
|
||||
if err != nil {
|
||||
slog.Warn("failed to calculate template size", "error", err)
|
||||
@ -1057,7 +1145,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
|
||||
vmCfg := vm.VMConfig{
|
||||
SandboxID: sandboxID,
|
||||
TemplateID: id.UUIDString(templateID),
|
||||
KernelPath: layout.KernelPath(m.cfg.WrennDir),
|
||||
KernelPath: m.cfg.KernelPath,
|
||||
RootfsPath: dmDev.DevicePath,
|
||||
VCPUs: vcpus,
|
||||
MemoryMB: memoryMB,
|
||||
@ -1067,6 +1155,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
|
||||
GuestIP: slot.GuestIP,
|
||||
GatewayIP: slot.TapIP,
|
||||
NetMask: slot.GuestNetMask,
|
||||
FirecrackerBin: m.cfg.FirecrackerBin,
|
||||
}
|
||||
|
||||
snapPath := filepath.Join(tmplDir, snapshot.SnapFileName)
|
||||
@ -1098,6 +1187,14 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
|
||||
return nil, fmt.Errorf("wait for envd: %w", err)
|
||||
}
|
||||
|
||||
// Trigger envd to re-read MMDS so it picks up the new sandbox/template IDs.
|
||||
if err := client.PostInit(waitCtx); err != nil {
|
||||
slog.Warn("post-init failed after template restore, metadata files may be stale", "sandbox", sandboxID, "error", err)
|
||||
}
|
||||
|
||||
// Fetch envd version (best-effort).
|
||||
envdVersion, _ := client.FetchVersion(ctx)
|
||||
|
||||
now := time.Now()
|
||||
sb := &sandboxState{
|
||||
Sandbox: models.Sandbox{
|
||||
@ -1113,6 +1210,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
|
||||
RootfsPath: dmDev.DevicePath,
|
||||
CreatedAt: now,
|
||||
LastActiveAt: now,
|
||||
Metadata: m.buildMetadata(envdVersion),
|
||||
},
|
||||
slot: slot,
|
||||
client: client,
|
||||
@ -1213,6 +1311,155 @@ func (m *Manager) GetClient(sandboxID string) (*envdclient.Client, error) {
|
||||
return sb.client, nil
|
||||
}
|
||||
|
||||
// SetDefaults calls envd's PostInit to configure the default user and
|
||||
// environment variables for a running sandbox. This is called by the host
|
||||
// agent after sandbox creation or resume when the template specifies defaults.
|
||||
func (m *Manager) SetDefaults(ctx context.Context, sandboxID, defaultUser string, defaultEnv map[string]string) error {
|
||||
if defaultUser == "" && len(defaultEnv) == 0 {
|
||||
return nil
|
||||
}
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
return sb.client.PostInitWithDefaults(ctx, defaultUser, defaultEnv)
|
||||
}
|
||||
|
||||
// PtyAttach starts a new PTY process or reconnects to an existing one.
|
||||
// If cmd is non-empty, starts a new process. If empty, reconnects using tag.
|
||||
func (m *Manager) PtyAttach(ctx context.Context, sandboxID, tag, cmd string, args []string, cols, rows uint32, envs map[string]string, cwd string) (<-chan envdclient.PtyEvent, error) {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return nil, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
sb.LastActiveAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
if cmd != "" {
|
||||
return sb.client.PtyStart(ctx, tag, cmd, args, cols, rows, envs, cwd)
|
||||
}
|
||||
return sb.client.PtyConnect(ctx, tag)
|
||||
}
|
||||
|
||||
// PtySendInput sends raw bytes to a PTY process in a sandbox.
|
||||
func (m *Manager) PtySendInput(ctx context.Context, sandboxID, tag string, data []byte) error {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
sb.LastActiveAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
return sb.client.PtySendInput(ctx, tag, data)
|
||||
}
|
||||
|
||||
// PtyResize updates the terminal dimensions for a PTY process in a sandbox.
|
||||
func (m *Manager) PtyResize(ctx context.Context, sandboxID, tag string, cols, rows uint32) error {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
return sb.client.PtyResize(ctx, tag, cols, rows)
|
||||
}
|
||||
|
||||
// PtyKill sends SIGKILL to a PTY process in a sandbox.
|
||||
func (m *Manager) PtyKill(ctx context.Context, sandboxID, tag string) error {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
return sb.client.PtyKill(ctx, tag)
|
||||
}
|
||||
|
||||
// StartBackground starts a background process inside a sandbox.
|
||||
func (m *Manager) StartBackground(ctx context.Context, sandboxID, tag, cmd string, args []string, envs map[string]string, cwd string) (uint32, error) {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return 0, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
sb.LastActiveAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
return sb.client.StartBackground(ctx, tag, cmd, args, envs, cwd)
|
||||
}
|
||||
|
||||
// ConnectProcess re-attaches to a running process inside a sandbox.
|
||||
func (m *Manager) ConnectProcess(ctx context.Context, sandboxID string, pid uint32, tag string) (<-chan envdclient.ExecStreamEvent, error) {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return nil, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
sb.LastActiveAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
return sb.client.ConnectProcess(ctx, pid, tag)
|
||||
}
|
||||
|
||||
// ListProcesses returns all running processes inside a sandbox.
|
||||
func (m *Manager) ListProcesses(ctx context.Context, sandboxID string) ([]envdclient.ProcessInfo, error) {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return nil, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
sb.LastActiveAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
return sb.client.ListProcesses(ctx)
|
||||
}
|
||||
|
||||
// KillProcess sends a signal to a process inside a sandbox.
|
||||
func (m *Manager) KillProcess(ctx context.Context, sandboxID string, pid uint32, tag string, signal envdpb.Signal) error {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
sb.LastActiveAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
return sb.client.KillProcess(ctx, pid, tag, signal)
|
||||
}
|
||||
|
||||
// AcquireProxyConn atomically looks up a sandbox by ID and registers an
|
||||
// in-flight proxy connection. Returns the sandbox's host-reachable IP, the
|
||||
// connection tracker, and true on success. The caller must call
|
||||
|
||||
@ -1 +0,0 @@
|
||||
package scheduler
|
||||
@ -1,71 +0,0 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
)
|
||||
|
||||
// HostScheduler selects a host for a new sandbox. Implementations may use
|
||||
// different strategies (round-robin, least-loaded, tag-based, etc.).
|
||||
type HostScheduler interface {
|
||||
// SelectHost returns a host that can accept a new sandbox.
|
||||
// For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID
|
||||
// are considered. For non-BYOC teams, only online regular (platform) hosts
|
||||
// are considered. Returns an error if no suitable host is available.
|
||||
SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error)
|
||||
}
|
||||
|
||||
// RoundRobinScheduler cycles through eligible online hosts in round-robin order.
|
||||
// It re-fetches the host list on every call so that newly registered or
|
||||
// recovered hosts are considered immediately.
|
||||
type RoundRobinScheduler struct {
|
||||
db *db.Queries
|
||||
counter atomic.Int64
|
||||
}
|
||||
|
||||
// NewRoundRobinScheduler creates a RoundRobinScheduler backed by the given DB.
|
||||
func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
|
||||
return &RoundRobinScheduler{db: queries}
|
||||
}
|
||||
|
||||
// SelectHost returns the next eligible online host in round-robin order.
|
||||
func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error) {
|
||||
hosts, err := s.db.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
|
||||
var eligible []db.Host
|
||||
for _, h := range hosts {
|
||||
if h.Status != "online" || h.Address == "" {
|
||||
continue
|
||||
}
|
||||
if isByoc {
|
||||
// BYOC team: only use hosts belonging to this team.
|
||||
if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID != teamID {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// Non-BYOC team: only use platform (regular) hosts.
|
||||
if h.Type != "regular" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
eligible = append(eligible, h)
|
||||
}
|
||||
|
||||
if len(eligible) == 0 {
|
||||
if isByoc {
|
||||
return db.Host{}, fmt.Errorf("no online BYOC hosts available for team")
|
||||
}
|
||||
return db.Host{}, fmt.Errorf("no online platform hosts available")
|
||||
}
|
||||
|
||||
idx := s.counter.Add(1) - 1
|
||||
return eligible[int(idx%int64(len(eligible)))], nil
|
||||
}
|
||||
@ -1 +0,0 @@
|
||||
package scheduler
|
||||
@ -1 +0,0 @@
|
||||
package scheduler
|
||||
@ -1,65 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
)
|
||||
|
||||
// APIKeyService provides API key operations shared between the REST API and the dashboard.
|
||||
type APIKeyService struct {
|
||||
DB *db.Queries
|
||||
}
|
||||
|
||||
// APIKeyCreateResult holds the result of creating an API key, including the
|
||||
// plaintext key which is only available at creation time.
|
||||
type APIKeyCreateResult struct {
|
||||
Row db.TeamApiKey
|
||||
Plaintext string
|
||||
}
|
||||
|
||||
// Create generates a new API key for the given team.
|
||||
func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
|
||||
if name == "" {
|
||||
name = "Unnamed API Key"
|
||||
}
|
||||
|
||||
plaintext, hash, err := auth.GenerateAPIKey()
|
||||
if err != nil {
|
||||
return APIKeyCreateResult{}, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
|
||||
row, err := s.DB.InsertAPIKey(ctx, db.InsertAPIKeyParams{
|
||||
ID: id.NewAPIKeyID(),
|
||||
TeamID: teamID,
|
||||
Name: name,
|
||||
KeyHash: hash,
|
||||
KeyPrefix: auth.APIKeyPrefix(plaintext),
|
||||
CreatedBy: userID,
|
||||
})
|
||||
if err != nil {
|
||||
return APIKeyCreateResult{}, fmt.Errorf("insert key: %w", err)
|
||||
}
|
||||
|
||||
return APIKeyCreateResult{Row: row, Plaintext: plaintext}, nil
|
||||
}
|
||||
|
||||
// List returns all API keys belonging to the given team.
|
||||
func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
|
||||
return s.DB.ListAPIKeysByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// ListWithCreator returns all API keys for the team, joined with the creator's email.
|
||||
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
|
||||
return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
|
||||
}
|
||||
|
||||
// Delete removes an API key by ID, scoped to the given team.
|
||||
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
|
||||
return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
|
||||
}
|
||||
@ -1,113 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
)
|
||||
|
||||
const auditMaxLimit = 200
|
||||
|
||||
// AuditEntry is a single audit log record returned by List.
|
||||
type AuditEntry struct {
|
||||
ID string
|
||||
TeamID string
|
||||
ActorType string
|
||||
ActorID string // empty for system
|
||||
ActorName string // empty for system
|
||||
ResourceType string
|
||||
ResourceID string // empty when not applicable
|
||||
Action string
|
||||
Scope string
|
||||
Status string // 'success', 'info', 'warning', 'error'
|
||||
Metadata map[string]any
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// AuditListParams controls the ListAuditLogs query.
|
||||
type AuditListParams struct {
|
||||
TeamID pgtype.UUID
|
||||
AdminScoped bool // true → include admin-scoped events; false → team-scoped only
|
||||
ResourceTypes []string // empty = no filter; multiple values = OR match
|
||||
Actions []string // empty = no filter; multiple values = OR match
|
||||
Before time.Time // zero = no cursor (start from latest)
|
||||
BeforeID pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
|
||||
Limit int // clamped to auditMaxLimit by the handler
|
||||
}
|
||||
|
||||
// AuditService provides the read side of the audit log.
|
||||
type AuditService struct {
|
||||
DB *db.Queries
|
||||
}
|
||||
|
||||
// List returns a page of audit log entries for the given team.
|
||||
func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntry, error) {
|
||||
limit := p.Limit
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if limit > auditMaxLimit {
|
||||
limit = auditMaxLimit
|
||||
}
|
||||
|
||||
scopes := []string{"team"}
|
||||
if p.AdminScoped {
|
||||
scopes = append(scopes, "admin")
|
||||
}
|
||||
|
||||
var before pgtype.Timestamptz
|
||||
if !p.Before.IsZero() {
|
||||
before = pgtype.Timestamptz{Time: p.Before, Valid: true}
|
||||
}
|
||||
|
||||
resourceTypes := p.ResourceTypes
|
||||
if resourceTypes == nil {
|
||||
resourceTypes = []string{}
|
||||
}
|
||||
actions := p.Actions
|
||||
if actions == nil {
|
||||
actions = []string{}
|
||||
}
|
||||
|
||||
rows, err := s.DB.ListAuditLogs(ctx, db.ListAuditLogsParams{
|
||||
TeamID: p.TeamID,
|
||||
Column2: scopes,
|
||||
Column3: resourceTypes,
|
||||
Column4: actions,
|
||||
Column5: before,
|
||||
ID: p.BeforeID,
|
||||
Limit: int32(limit),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list audit logs: %w", err)
|
||||
}
|
||||
|
||||
entries := make([]AuditEntry, len(rows))
|
||||
for i, row := range rows {
|
||||
var meta map[string]any
|
||||
if len(row.Metadata) > 0 {
|
||||
_ = json.Unmarshal(row.Metadata, &meta)
|
||||
}
|
||||
entries[i] = AuditEntry{
|
||||
ID: id.FormatAuditLogID(row.ID),
|
||||
TeamID: id.FormatTeamID(row.TeamID),
|
||||
ActorType: row.ActorType,
|
||||
ActorID: row.ActorID.String,
|
||||
ActorName: row.ActorName,
|
||||
ResourceType: row.ResourceType,
|
||||
ResourceID: row.ResourceID.String,
|
||||
Action: row.Action,
|
||||
Scope: row.Scope,
|
||||
Status: row.Status,
|
||||
Metadata: meta,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
}
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
@ -1,605 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/recipe"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/scheduler"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
const (
|
||||
buildQueueKey = "wrenn:build_queue"
|
||||
buildCommandTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// preBuildCmds run before the user recipe to prepare the build environment.
|
||||
var preBuildCmds = []string{
|
||||
"RUN apt update",
|
||||
}
|
||||
|
||||
// postBuildCmds run after the user recipe to clean up caches and reduce image size.
|
||||
var postBuildCmds = []string{
|
||||
"RUN apt clean",
|
||||
"RUN apt autoremove -y",
|
||||
"RUN rm -rf /var/lib/apt/lists/*",
|
||||
}
|
||||
|
||||
// buildAgentClient is the subset of the host agent client used by the build worker.
|
||||
type buildAgentClient interface {
|
||||
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
|
||||
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
|
||||
Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
|
||||
CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
|
||||
FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
|
||||
}
|
||||
|
||||
// BuildService handles template build orchestration.
|
||||
type BuildService struct {
|
||||
DB *db.Queries
|
||||
Redis *redis.Client
|
||||
Pool *lifecycle.HostClientPool
|
||||
Scheduler scheduler.HostScheduler
|
||||
|
||||
mu sync.Mutex
|
||||
cancelMap map[string]context.CancelFunc // buildID → per-build cancel func
|
||||
}
|
||||
|
||||
// BuildCreateParams holds the parameters for creating a template build.
|
||||
type BuildCreateParams struct {
|
||||
Name string
|
||||
BaseTemplate string
|
||||
Recipe []string
|
||||
Healthcheck string
|
||||
VCPUs int32
|
||||
MemoryMB int32
|
||||
SkipPrePost bool
|
||||
}
|
||||
|
||||
// Create inserts a new build record and enqueues it to Redis.
|
||||
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
|
||||
if p.BaseTemplate == "" {
|
||||
p.BaseTemplate = "minimal"
|
||||
}
|
||||
if p.VCPUs <= 0 {
|
||||
p.VCPUs = 1
|
||||
}
|
||||
if p.MemoryMB <= 0 {
|
||||
p.MemoryMB = 512
|
||||
}
|
||||
|
||||
recipeJSON, err := json.Marshal(p.Recipe)
|
||||
if err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
|
||||
}
|
||||
|
||||
buildID := id.NewBuildID()
|
||||
buildIDStr := id.FormatBuildID(buildID)
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
defaultSteps := len(preBuildCmds) + len(postBuildCmds)
|
||||
if p.SkipPrePost {
|
||||
defaultSteps = 0
|
||||
}
|
||||
|
||||
build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
|
||||
ID: buildID,
|
||||
Name: p.Name,
|
||||
BaseTemplate: p.BaseTemplate,
|
||||
Recipe: recipeJSON,
|
||||
Healthcheck: p.Healthcheck,
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TotalSteps: int32(len(p.Recipe) + defaultSteps),
|
||||
TemplateID: newTemplateID,
|
||||
TeamID: id.PlatformTeamID,
|
||||
SkipPrePost: p.SkipPrePost,
|
||||
})
|
||||
if err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
|
||||
}
|
||||
|
||||
// Enqueue build ID (as formatted string) to Redis for workers to pick up.
|
||||
if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
|
||||
}
|
||||
|
||||
return build, nil
|
||||
}
|
||||
|
||||
// Get returns a single build by ID.
|
||||
func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
|
||||
return s.DB.GetTemplateBuild(ctx, buildID)
|
||||
}
|
||||
|
||||
// List returns all builds ordered by creation time.
|
||||
func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
|
||||
return s.DB.ListTemplateBuilds(ctx)
|
||||
}
|
||||
|
||||
// Cancel cancels a pending or running build. For pending builds the status is
|
||||
// updated in the DB and the worker skips it when dequeued. For running builds
|
||||
// the per-build context is cancelled, which causes the current exec step to
|
||||
// abort; executeBuild then detects the cancellation and records the status.
|
||||
func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
|
||||
build, err := s.DB.GetTemplateBuild(ctx, buildID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get build: %w", err)
|
||||
}
|
||||
switch build.Status {
|
||||
case "success", "failed", "cancelled":
|
||||
return fmt.Errorf("build is already %s", build.Status)
|
||||
}
|
||||
|
||||
// Mark cancelled in DB first. This handles both pending builds (which haven't
|
||||
// been picked up yet) and acts as a flag for executeBuild to check on start.
|
||||
if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{
|
||||
ID: buildID, Status: "cancelled",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update build status: %w", err)
|
||||
}
|
||||
|
||||
// If the build is currently running, signal its context.
|
||||
buildIDStr := id.FormatBuildID(buildID)
|
||||
s.mu.Lock()
|
||||
cancel, running := s.cancelMap[buildIDStr]
|
||||
s.mu.Unlock()
|
||||
if running {
|
||||
cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartWorkers launches n goroutines that consume from the Redis build queue.
|
||||
// The returned cancel function stops all workers.
|
||||
func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
for i := range n {
|
||||
go s.worker(ctx, i)
|
||||
}
|
||||
slog.Info("build workers started", "count", n)
|
||||
return cancel
|
||||
}
|
||||
|
||||
func (s *BuildService) worker(ctx context.Context, workerID int) {
|
||||
log := slog.With("worker", workerID)
|
||||
for {
|
||||
// BLPOP blocks until a build ID is available or context is cancelled.
|
||||
result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
log.Info("build worker shutting down")
|
||||
return
|
||||
}
|
||||
log.Error("redis BLPOP error", "error", err)
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
// result[0] is the key, result[1] is the build ID (formatted string).
|
||||
buildIDStr := result[1]
|
||||
log.Info("picked up build", "build_id", buildIDStr)
|
||||
s.executeBuild(ctx, buildIDStr)
|
||||
}
|
||||
}
|
||||
|
||||
// executeBuild runs one dequeued template build end-to-end: it resolves the
// build record, provisions a sandbox on a platform host, executes the
// pre-build / user / post-build recipe phases, optionally waits on a
// healthcheck, and finally snapshots (or flattens) the sandbox into a new
// platform template. Failures are recorded via failBuild; cancellation (via
// Cancel) is observed through buildCtx and leaves the "cancelled" status
// already written by Cancel intact.
func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
	log := slog.With("build_id", buildIDStr)

	buildID, err := id.ParseBuildID(buildIDStr)
	if err != nil {
		log.Error("invalid build ID from queue", "error", err)
		return
	}

	// Create a per-build context so this build can be cancelled independently of
	// the worker. Register in cancelMap before fetching the build so that a
	// concurrent Cancel call can always find and signal it.
	buildCtx, buildCancel := context.WithCancel(ctx)
	defer buildCancel()

	s.mu.Lock()
	if s.cancelMap == nil {
		s.cancelMap = make(map[string]context.CancelFunc)
	}
	s.cancelMap[buildIDStr] = buildCancel
	s.mu.Unlock()
	// Unregister on exit so Cancel never fires a stale cancel func.
	defer func() {
		s.mu.Lock()
		delete(s.cancelMap, buildIDStr)
		s.mu.Unlock()
	}()

	build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
	if err != nil {
		log.Error("failed to fetch build", "error", err)
		return
	}

	// Skip if already cancelled (Cancel was called before we dequeued).
	if build.Status == "cancelled" {
		log.Info("build already cancelled, skipping")
		return
	}

	// Mark as running.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "running",
	}); err != nil {
		log.Error("failed to update build status", "error", err)
		return
	}

	// Parse user recipe (stored as a JSON array of command strings).
	var userRecipe []string
	if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err))
		return
	}

	// Pick a platform host and create a sandbox.
	host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false)
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err))
		return
	}

	agent, err := s.Pool.GetForHost(host)
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err))
		return
	}

	sandboxID := id.NewSandboxID()
	sandboxIDStr := id.FormatSandboxID(sandboxID)
	log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))

	// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
	baseTeamID := id.PlatformTeamID
	baseTemplateID := id.MinimalTemplateID
	if build.BaseTemplate != "minimal" {
		baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
		if err != nil {
			s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
			return
		}
		baseTeamID = baseTmpl.TeamID
		baseTemplateID = baseTmpl.ID
	}

	resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
		SandboxId:  sandboxIDStr,
		Template:   build.BaseTemplate,
		TeamId:     id.UUIDString(baseTeamID),
		TemplateId: id.UUIDString(baseTemplateID),
		Vcpus:      build.Vcpus,
		MemoryMb:   build.MemoryMb,
		TimeoutSec: 0,    // no auto-pause for builds
		DiskSizeMb: 5120, // 5 GB for template builds
	}))
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
		return
	}
	// Response payload is unused; only the error above is significant.
	_ = resp

	// Record sandbox/host association. Best-effort: the error is deliberately
	// discarded and the build proceeds whether or not the association is stored.
	_ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
		ID:        buildID,
		SandboxID: sandboxID,
		HostID:    host.ID,
	})

	// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
	// valid; panic on error is appropriate here since it would be a programmer mistake.
	preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
	}
	userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
	if err != nil {
		s.destroySandbox(buildCtx, agent, sandboxIDStr)
		s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err))
		return
	}
	postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid post-build recipe: %v", err))
	}

	// Accumulated log entries across all phases, and the global step counter
	// threaded through each recipe.Execute call.
	var logs []recipe.BuildLogEntry
	step := 0

	// Capture the sandbox's real environment so recipe commands see the same
	// PATH/HOME as an interactive shell; fall back to sane defaults.
	envVars, err := s.fetchSandboxEnv(buildCtx, agent, sandboxIDStr)
	if err != nil {
		log.Warn("failed to fetch sandbox env, using defaults", "error", err)
		envVars = map[string]string{
			"PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
			"HOME": "/root",
		}
	}
	bctx := &recipe.ExecContext{EnvVars: envVars}

	// runPhase executes one phase, appends its log entries, persists progress,
	// and on failure tears down the sandbox and (unless the build was cancelled)
	// marks the build failed. Returns false when the caller should stop.
	runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
		newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec)
		logs = append(logs, newEntries...)
		step = nextStep
		s.updateLogs(buildCtx, buildID, step, logs)
		if !ok {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			// If the build was cancelled, status is already set — don't overwrite with "failed".
			if buildCtx.Err() != nil {
				return false
			}
			// NOTE(review): assumes Execute returns at least one entry when it
			// reports failure without cancellation — otherwise this indexes an
			// empty slice; confirm against recipe.Execute's contract.
			last := newEntries[len(newEntries)-1]
			reason := last.Stderr
			if reason == "" {
				reason = fmt.Sprintf("exit code %d", last.Exit)
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
		}
		return ok
	}

	if !build.SkipPrePost {
		if !runPhase("pre-build", preBuildSteps, 0) {
			return
		}
	}
	if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
		return
	}
	if !build.SkipPrePost {
		if !runPhase("post-build", postBuildSteps, 0) {
			return
		}
	}

	// Healthcheck or direct snapshot.
	var sizeBytes int64
	if build.Healthcheck != "" {
		hc, err := recipe.ParseHealthcheck(build.Healthcheck)
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
			return
		}
		log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
		if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc); err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			// Cancelled mid-healthcheck: leave the "cancelled" status alone.
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
			return
		}

		// Healthcheck passed → full snapshot (with memory/CPU state).
		log.Info("healthcheck passed, creating snapshot")
		snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
			return
		}
		sizeBytes = snapResp.Msg.SizeBytes
	} else {
		// No healthcheck → image-only template (rootfs only).
		log.Info("no healthcheck, flattening rootfs")
		flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
			return
		}
		sizeBytes = flatResp.Msg.SizeBytes
	}

	// Insert into templates table as a global (platform) template.
	templateType := "base"
	if build.Healthcheck != "" {
		templateType = "snapshot"
	}

	if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
		ID:        build.TemplateID,
		Name:      build.Name,
		Type:      templateType,
		Vcpus:     build.Vcpus,
		MemoryMb:  build.MemoryMb,
		SizeBytes: sizeBytes,
		TeamID:    id.PlatformTeamID,
	}); err != nil {
		log.Error("failed to insert template record", "error", err)
		// Build succeeded on disk, just DB record failed — don't mark as failed.
	}

	// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
	// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
	// No additional destroy needed.

	// Mark build as success.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "success",
	}); err != nil {
		log.Error("failed to mark build as success", "error", err)
	}

	log.Info("template build completed successfully", "name", build.Name)
}
|
||||
|
||||
// waitForHealthcheck repeatedly executes the healthcheck command inside the
|
||||
// sandbox according to the config's interval, timeout, start-period, and
|
||||
// retries.
|
||||
// During the start period, failures are not counted toward the retry budget.
|
||||
// Returns nil on the first successful check, or an error if retries are
|
||||
// exhausted, the deadline passes, or the context is cancelled.
|
||||
func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr string, hc recipe.HealthcheckConfig) error {
|
||||
ticker := time.NewTicker(hc.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// When retries > 0, set a deadline based on the retry budget.
|
||||
// When retries == 0 (unlimited), rely solely on the parent context deadline.
|
||||
var deadlineCh <-chan time.Time
|
||||
if hc.Retries > 0 {
|
||||
deadline := time.NewTimer(hc.StartPeriod + time.Duration(hc.Retries+1)*hc.Interval)
|
||||
defer deadline.Stop()
|
||||
deadlineCh = deadline.C
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
failCount := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-deadlineCh:
|
||||
return fmt.Errorf("healthcheck timed out: exceeded %d attempts over %s", failCount, time.Since(startedAt))
|
||||
case <-ticker.C:
|
||||
execCtx, cancel := context.WithTimeout(ctx, hc.Timeout)
|
||||
resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: "/bin/sh",
|
||||
Args: []string{"-c", hc.Cmd},
|
||||
TimeoutSec: int32(hc.Timeout.Seconds()),
|
||||
}))
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
slog.Debug("healthcheck exec error (retrying)", "error", err)
|
||||
if time.Since(startedAt) >= hc.StartPeriod {
|
||||
failCount++
|
||||
if hc.Retries > 0 && failCount >= hc.Retries {
|
||||
return fmt.Errorf("healthcheck failed after %d retries: exec error: %w", failCount, err)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if resp.Msg.ExitCode == 0 {
|
||||
return nil
|
||||
}
|
||||
slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode)
|
||||
if time.Since(startedAt) >= hc.StartPeriod {
|
||||
failCount++
|
||||
if hc.Retries > 0 && failCount >= hc.Retries {
|
||||
return fmt.Errorf("healthcheck failed after %d retries: exit code %d", failCount, resp.Msg.ExitCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
|
||||
logsJSON, err := json.Marshal(logs)
|
||||
if err != nil {
|
||||
slog.Warn("failed to marshal build logs", "error", err)
|
||||
return
|
||||
}
|
||||
if err := s.DB.UpdateBuildProgress(ctx, db.UpdateBuildProgressParams{
|
||||
ID: buildID,
|
||||
CurrentStep: int32(step),
|
||||
Logs: logsJSON,
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update build progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
|
||||
slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg)
|
||||
// Use a detached context so DB writes survive parent context cancellation (e.g. shutdown).
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{
|
||||
ID: buildID,
|
||||
Error: errMsg,
|
||||
}); err != nil {
|
||||
slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
|
||||
// Use a detached context so cleanup succeeds even during shutdown.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil {
|
||||
slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchSandboxEnv executes the 'env' command inside the specified sandbox via
|
||||
// the build agent and returns environment variables
|
||||
func (s *BuildService) fetchSandboxEnv(ctx context.Context,
|
||||
agent buildAgentClient, sandboxIDStr string) (map[string]string, error) {
|
||||
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: "/bin/sh",
|
||||
Args: []string{"-c", "env"},
|
||||
TimeoutSec: 10,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch env: %w", err)
|
||||
}
|
||||
|
||||
if resp.Msg.ExitCode != 0 {
|
||||
return nil, fmt.Errorf("fetch env: command exited with code %d",
|
||||
resp.Msg.ExitCode)
|
||||
}
|
||||
|
||||
return parseSandboxEnv(string(resp.Msg.Stdout)), nil
|
||||
}
|
||||
|
||||
// parseSandboxEnv converts the raw newline-separated output of an 'env'
|
||||
// command into a map.
|
||||
// It skips empty lines and malformed entries, and correctly handles values
|
||||
// containing '='.
|
||||
func parseSandboxEnv(raw string) map[string]string {
|
||||
envVars := make(map[string]string)
|
||||
|
||||
for line := range strings.SplitSeq(raw, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
envVars[parts[0]] = parts[1]
|
||||
}
|
||||
|
||||
return envVars
|
||||
}
|
||||
@ -1,628 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"time"

	"connectrpc.com/connect"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/redis/go-redis/v9"

	"git.omukk.dev/wrenn/wrenn/internal/auth"
	"git.omukk.dev/wrenn/wrenn/internal/db"
	"git.omukk.dev/wrenn/wrenn/internal/id"
	"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
	pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
|
||||
|
||||
// HostService provides host management operations: creation, agent
// registration, token refresh/rotation, heartbeats, listing, and deletion.
type HostService struct {
	DB    *db.Queries   // persistence for hosts, tokens, teams, and sandboxes
	Redis *redis.Client // holds one-time registration tokens under "host:reg:<token>"
	JWT   []byte        // secret used to sign host JWTs via auth.SignHostJWT
	Pool  *lifecycle.HostClientPool // per-host agent client pool — presumably used by deletion/cleanup paths; confirm elsewhere in this file
	CA    *auth.CA // nil disables mTLS cert issuance (dev/test environments)
}
|
||||
|
||||
// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
	Type             string      // "regular" (platform-managed) or "byoc"
	TeamID           pgtype.UUID // required for BYOC, zero value for regular
	Provider         string
	AvailabilityZone string
	RequestingUserID pgtype.UUID // recorded as the host's creator and the token's issuer
	IsRequestorAdmin bool        // platform admins bypass the team-role checks
}
|
||||
|
||||
// HostCreateResult holds the created host and the one-time registration token.
type HostCreateResult struct {
	Host              db.Host
	RegistrationToken string // plaintext one-time token, valid for regTokenTTL
}
|
||||
|
||||
// HostRegisterParams holds the parameters for host agent registration.
type HostRegisterParams struct {
	Token    string // one-time registration token from Create/RegenerateToken
	Arch     string
	CPUCores int32
	MemoryMB int32
	DiskGB   int32
	Address  string // agent's network address; also passed to mTLS cert issuance
}
|
||||
|
||||
// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
// refresh token, and optionally the host's mTLS certificate material.
type HostRegisterResult struct {
	Host         db.Host
	JWT          string // short-lived host JWT signed with the service secret
	RefreshToken string // long-lived opaque token; stored hashed server-side
	// mTLS cert material — empty when CA is not configured.
	CertPEM   string
	KeyPEM    string
	CACertPEM string
}
|
||||
|
||||
// HostRefreshResult holds a new JWT and rotated refresh token after a successful
// refresh, plus refreshed mTLS certificate material when CA is configured.
type HostRefreshResult struct {
	Host         db.Host
	JWT          string // freshly signed host JWT
	RefreshToken string // replacement token; the presented one is revoked
	// mTLS cert material — empty when CA is not configured.
	CertPEM   string
	KeyPEM    string
	CACertPEM string
}
|
||||
|
||||
// HostDeletePreview describes what will be affected by deleting a host.
type HostDeletePreview struct {
	Host       db.Host
	SandboxIDs []string // formatted IDs of sandboxes still active on the host (pending/starting/running/missing)
}
|
||||
|
||||
// regTokenPayload is the JSON stored in Redis for registration tokens.
// It links the opaque token back to the host being registered and to the
// audit-trail token row created alongside it.
type regTokenPayload struct {
	HostID  string `json:"host_id"`  // formatted host ID
	TokenID string `json:"token_id"` // formatted host-token (audit) ID
}
|
||||
|
||||
// regTokenTTL is how long a one-time registration token remains valid, both as
// the Redis key TTL and as the audit record's ExpiresAt.
const regTokenTTL = time.Hour
|
||||
|
||||
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
func requireAdminOrOwner(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
	}
}
|
||||
|
||||
// Create creates a new host record and generates a one-time registration token.
// Permission rules: "regular" hosts require a platform admin; "byoc" hosts
// require a team_id and either a platform admin or a team owner/admin.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
	if p.Type != "regular" && p.Type != "byoc" {
		return HostCreateResult{}, fmt.Errorf("invalid host type: must be 'regular' or 'byoc'")
	}

	if p.Type == "regular" {
		if !p.IsRequestorAdmin {
			return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
		}
	} else {
		// BYOC: platform admin, or team owner/admin.
		if !p.TeamID.Valid {
			return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
		}
		if !p.IsRequestorAdmin {
			membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
				UserID: p.RequestingUserID,
				TeamID: p.TeamID,
			})
			if errors.Is(err, pgx.ErrNoRows) {
				return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
			}
			if err != nil {
				return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
			}
			if err := requireAdminOrOwner(membership.Role); err != nil {
				return HostCreateResult{}, err
			}
		}
	}

	// Validate team exists, is not deleted, and has BYOC enabled.
	if p.TeamID.Valid {
		team, err := s.DB.GetTeam(ctx, p.TeamID)
		if err != nil || team.DeletedAt.Valid {
			return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
		}
		if !team.IsByoc {
			return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
		}
	}

	hostID := id.NewHostID()

	host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
		ID:               hostID,
		Type:             p.Type,
		TeamID:           p.TeamID,
		Provider:         p.Provider,
		AvailabilityZone: p.AvailabilityZone,
		CreatedBy:        p.RequestingUserID,
	})
	if err != nil {
		return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
	}

	// Generate registration token and store in Redis + Postgres audit trail.
	token := id.NewRegistrationToken()
	tokenID := id.NewHostTokenID()

	// Payload is a struct of two plain strings; Marshal cannot fail here.
	payload, _ := json.Marshal(regTokenPayload{
		HostID:  id.FormatHostID(hostID),
		TokenID: id.FormatHostTokenID(tokenID),
	})
	if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
		return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
	}

	// Audit record is best-effort: a failed insert is only logged and does not
	// invalidate the token already stored in Redis.
	now := time.Now()
	if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
		ID:        tokenID,
		HostID:    hostID,
		CreatedBy: p.RequestingUserID,
		ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
	}); err != nil {
		slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
	}

	return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
|
||||
|
||||
// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows retry when a previous registration attempt failed after
// the original token was consumed.
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
	host, err := s.DB.GetHost(ctx, hostID)
	if err != nil {
		return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
	}
	// Only pre-registration hosts may have their token re-issued.
	if host.Status != "pending" {
		return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
	}

	// Non-admins may only act on BYOC hosts belonging to their own team, and
	// must hold the owner/admin role on that team.
	if !isAdmin {
		if host.Type != "byoc" {
			return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
		}
		if !host.TeamID.Valid || host.TeamID != teamID {
			return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
		}
		membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
			UserID: userID,
			TeamID: teamID,
		})
		if errors.Is(err, pgx.ErrNoRows) {
			return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
		}
		if err != nil {
			return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
		}
		if err := requireAdminOrOwner(membership.Role); err != nil {
			return HostCreateResult{}, err
		}
	}

	token := id.NewRegistrationToken()
	tokenID := id.NewHostTokenID()

	// Payload is a struct of two plain strings; Marshal cannot fail here.
	payload, _ := json.Marshal(regTokenPayload{
		HostID:  id.FormatHostID(hostID),
		TokenID: id.FormatHostTokenID(tokenID),
	})
	if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
		return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
	}

	// Audit record is best-effort: a failed insert does not invalidate the token.
	now := time.Now()
	if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
		ID:        tokenID,
		HostID:    hostID,
		CreatedBy: userID,
		ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
	}); err != nil {
		slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
	}

	return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
|
||||
|
||||
// Register validates a one-time registration token, updates the host with
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
// When a CA is configured, it also issues the host's mTLS certificate material.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
	// Atomic consume: GetDel returns the value and deletes in one operation,
	// preventing concurrent requests from consuming the same token.
	raw, err := s.Redis.GetDel(ctx, "host:reg:"+p.Token).Bytes()
	if err == redis.Nil {
		return HostRegisterResult{}, fmt.Errorf("invalid or expired registration token")
	}
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("token lookup: %w", err)
	}

	var payload regTokenPayload
	if err := json.Unmarshal(raw, &payload); err != nil {
		return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
	}

	hostID, err := id.ParseHostID(payload.HostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
	}
	tokenID, err := id.ParseHostTokenID(payload.TokenID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
	}

	// Ensure the host referenced by the token still exists before doing any work.
	if _, err := s.DB.GetHost(ctx, hostID); err != nil {
		return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
	}

	// Sign JWT before mutating DB — if signing fails, the host stays pending.
	hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
	}

	// Issue mTLS certificate if CA is configured.
	var hc auth.HostCert
	if s.CA != nil {
		hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
		if err != nil {
			return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
		}
	}

	// Atomically update only if still pending (defense-in-depth against races).
	rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
		ID:              hostID,
		Arch:            p.Arch,
		CpuCores:        p.CPUCores,
		MemoryMb:        p.MemoryMB,
		DiskGb:          p.DiskGB,
		Address:         p.Address,
		CertFingerprint: hc.Fingerprint,
		// Valid tracks whether a cert was actually issued (CA configured).
		CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
	})
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
	}
	if rowsAffected == 0 {
		return HostRegisterResult{}, fmt.Errorf("host already registered or not found")
	}

	// Mark audit trail. Best-effort: failure here is only logged.
	if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
		slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
	}

	// Issue a long-lived refresh token.
	refreshToken, err := s.issueRefreshToken(ctx, hostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
	}

	// Re-fetch the host to get the updated state.
	host, err := s.DB.GetHost(ctx, hostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
	}

	result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
	if s.CA != nil {
		result.CertPEM = hc.CertPEM
		result.KeyPEM = hc.KeyPEM
		result.CACertPEM = s.CA.PEM
	}
	return result, nil
}
|
||||
|
||||
// Refresh validates a refresh token, rotates it (revokes old, issues new),
// and returns a fresh JWT plus the new refresh token.
// When a CA is configured, the host's mTLS certificate is also renewed.
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
	// Tokens are stored hashed; look up by hash so the plaintext never reaches the DB.
	hash := hashToken(refreshToken)

	row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
	if errors.Is(err, pgx.ErrNoRows) {
		return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
	}
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
	}

	host, err := s.DB.GetHost(ctx, row.HostID)
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
	}

	// Sign new JWT.
	hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
	}

	// Renew mTLS certificate if CA is configured.
	var hc auth.HostCert
	if s.CA != nil {
		hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
		if err != nil {
			return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
		}
		if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
			ID:              host.ID,
			CertFingerprint: hc.Fingerprint,
			CertExpiresAt:   pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
		}); err != nil {
			return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
		}
	}

	// Issue-then-revoke rotation: insert new token first so a crash between
	// the two DB calls leaves the host with two valid tokens rather than zero.
	newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
	}

	// Revoke old refresh token after the new one is safely persisted.
	if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
		return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
	}

	result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
	if s.CA != nil {
		result.CertPEM = hc.CertPEM
		result.KeyPEM = hc.KeyPEM
		result.CACertPEM = s.CA.PEM
	}
	return result, nil
}
|
||||
|
||||
// issueRefreshToken creates a new refresh token record in the DB and returns
|
||||
// the opaque token string.
|
||||
func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
|
||||
token := id.NewRefreshToken()
|
||||
hash := hashToken(token)
|
||||
now := time.Now()
|
||||
|
||||
if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
|
||||
ID: id.NewRefreshTokenID(),
|
||||
HostID: hostID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("insert refresh token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// hashToken returns the hex-encoded SHA-256 hash of the token.
func hashToken(token string) string {
	sum := sha256.Sum256([]byte(token))
	// hex.EncodeToString is the direct encoding; fmt's %x would go through
	// reflection-based formatting for the same result.
	return hex.EncodeToString(sum[:])
}
|
||||
|
||||
// Heartbeat updates the last heartbeat timestamp for a host and transitions
|
||||
// any 'unreachable' host back to 'online'. Returns a "host not found" error
|
||||
// (which becomes 404) if the host record no longer exists (e.g., was deleted).
|
||||
func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
|
||||
n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return fmt.Errorf("host not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns hosts visible to the caller.
|
||||
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
|
||||
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
|
||||
if isAdmin {
|
||||
return s.DB.ListHosts(ctx)
|
||||
}
|
||||
return s.DB.ListHostsByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// Get returns a single host, enforcing access control.
|
||||
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
if !isAdmin {
|
||||
if !host.TeamID.Valid || host.TeamID != teamID {
|
||||
return db.Host{}, fmt.Errorf("host not found")
|
||||
}
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// DeletePreview returns what would be affected by deleting the host, without
|
||||
// making any changes. Use this to show the user a confirmation prompt.
|
||||
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
|
||||
host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin)
|
||||
if err != nil {
|
||||
return HostDeletePreview{}, err
|
||||
}
|
||||
|
||||
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: hostID,
|
||||
Column2: []string{"pending", "starting", "running", "missing"},
|
||||
})
|
||||
if err != nil {
|
||||
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
|
||||
}
|
||||
|
||||
ids := make([]string, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
ids[i] = id.FormatSandboxID(sb.ID)
|
||||
}
|
||||
|
||||
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
|
||||
}
|
||||
|
||||
// Delete removes a host. Without force it returns an error listing active
|
||||
// sandboxes so the caller can present a confirmation. With force it gracefully
|
||||
// destroys all running sandboxes before deleting the host record.
|
||||
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
|
||||
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: hostID,
|
||||
Column2: []string{"pending", "starting", "running", "missing"},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("list sandboxes: %w", err)
|
||||
}
|
||||
|
||||
if len(sandboxes) > 0 && !force {
|
||||
ids := make([]string, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
ids[i] = id.FormatSandboxID(sb.ID)
|
||||
}
|
||||
return &HostHasSandboxesError{SandboxIDs: ids}
|
||||
}
|
||||
|
||||
hostIDStr := id.FormatHostID(hostID)
|
||||
|
||||
// Gracefully destroy running sandboxes and terminate the agent (best-effort).
|
||||
if host.Address != "" {
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err == nil {
|
||||
for _, sb := range sandboxes {
|
||||
if sb.Status == "running" || sb.Status == "starting" {
|
||||
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: id.FormatSandboxID(sb.ID),
|
||||
}))
|
||||
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
|
||||
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tell the agent to shut itself down immediately.
|
||||
if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
|
||||
slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mark all affected sandboxes as stopped in DB.
|
||||
if len(sandboxes) > 0 {
|
||||
sbIDs := make([]pgtype.UUID, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
sbIDs[i] = sb.ID
|
||||
}
|
||||
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: sbIDs,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Revoke all refresh tokens for this host.
|
||||
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
|
||||
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
|
||||
}
|
||||
|
||||
// Evict the client from the pool so no further RPCs are sent.
|
||||
if s.Pool != nil {
|
||||
s.Pool.Evict(id.FormatHostID(hostID))
|
||||
}
|
||||
|
||||
return s.DB.DeleteHost(ctx, hostID)
|
||||
}
|
||||
|
||||
// checkDeletePermission verifies the caller has permission to delete the given
|
||||
// host and returns the host record on success.
|
||||
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
|
||||
if isAdmin {
|
||||
return host, nil
|
||||
}
|
||||
|
||||
if host.Type != "byoc" {
|
||||
return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
|
||||
}
|
||||
if !host.TeamID.Valid || host.TeamID != teamID {
|
||||
return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
|
||||
}
|
||||
|
||||
if userID.Valid {
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
|
||||
}
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||
return db.Host{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// AddTag adds a tag to a host.
|
||||
func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.DB.AddHostTag(ctx, db.AddHostTagParams{HostID: hostID, Tag: tag})
|
||||
}
|
||||
|
||||
// RemoveTag removes a tag from a host.
|
||||
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.DB.RemoveHostTag(ctx, db.RemoveHostTagParams{HostID: hostID, Tag: tag})
|
||||
}
|
||||
|
||||
// ListTags returns all tags for a host.
|
||||
func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.DB.GetHostTags(ctx, hostID)
|
||||
}
|
||||
|
||||
// HostHasSandboxesError is returned by Delete when the host still has active
// sandboxes and force was not set. Callers should show the list to the user
// and re-invoke Delete with force=true if they confirm.
type HostHasSandboxesError struct {
	// SandboxIDs holds the formatted IDs of the active sandboxes.
	SandboxIDs []string
}

// Error implements the error interface.
func (e *HostHasSandboxesError) Error() string {
	count := len(e.SandboxIDs)
	return fmt.Sprintf("host has %d active sandbox(es): %v", count, e.SandboxIDs)
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user