forked from wrenn/wrenn
v0.1.0 (#17)
This commit is contained in:
@ -6,8 +6,8 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
|
||||
@ -17,10 +17,9 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by proxyTarget, used to map to HTTP status codes
|
||||
@ -44,7 +43,7 @@ func (e errProxySandboxNotRunning) Error() string {
|
||||
return fmt.Sprintf("sandbox is not running (status: %s)", e.status)
|
||||
}
|
||||
|
||||
// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair.
|
||||
// proxyCacheEntry caches the resolved agent URL for a sandbox.
|
||||
// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure
|
||||
// can capture the correct port without the cache key needing to include it.
|
||||
type proxyCacheEntry struct {
|
||||
@ -52,23 +51,13 @@ type proxyCacheEntry struct {
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation.
|
||||
type proxyCacheKey [32]byte
|
||||
|
||||
func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey {
|
||||
var k proxyCacheKey
|
||||
copy(k[:16], sandboxID.Bytes[:])
|
||||
copy(k[16:], teamID.Bytes[:])
|
||||
return k
|
||||
}
|
||||
|
||||
// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
|
||||
// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
|
||||
// requests are reverse-proxied through the host agent that owns the sandbox.
|
||||
// All other requests are passed through to the inner handler.
|
||||
//
|
||||
// Authentication is via X-API-Key header only (no JWT). The API key's team
|
||||
// must own the sandbox.
|
||||
// No authentication is required — sandbox URLs are unguessable and access is
|
||||
// scoped to the sandbox ID embedded in the hostname.
|
||||
type SandboxProxyWrapper struct {
|
||||
inner http.Handler
|
||||
db *db.Queries
|
||||
@ -76,7 +65,7 @@ type SandboxProxyWrapper struct {
|
||||
transport http.RoundTripper
|
||||
|
||||
cacheMu sync.Mutex
|
||||
cache map[proxyCacheKey]proxyCacheEntry
|
||||
cache map[pgtype.UUID]proxyCacheEntry
|
||||
}
|
||||
|
||||
// NewSandboxProxyWrapper creates a new proxy wrapper.
|
||||
@ -86,19 +75,15 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec
|
||||
db: queries,
|
||||
pool: pool,
|
||||
transport: pool.Transport(),
|
||||
cache: make(map[proxyCacheKey]proxyCacheEntry),
|
||||
cache: make(map[pgtype.UUID]proxyCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// proxyTarget looks up the cached agent URL for (sandboxID, teamID).
|
||||
// proxyTarget looks up the cached agent URL for sandboxID.
|
||||
// On a miss it queries the DB, resolves the address, and populates the cache.
|
||||
// The *httputil.ReverseProxy is built by the caller so the Director closure
|
||||
// captures the correct port without the cache key needing to include it.
|
||||
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) {
|
||||
cacheKey := makeProxyCacheKey(sandboxID, teamID)
|
||||
|
||||
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID pgtype.UUID) (*url.URL, error) {
|
||||
h.cacheMu.Lock()
|
||||
entry, ok := h.cache[cacheKey]
|
||||
entry, ok := h.cache[sandboxID]
|
||||
h.cacheMu.Unlock()
|
||||
|
||||
if ok && time.Now().Before(entry.expiresAt) {
|
||||
@ -106,10 +91,7 @@ func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID
|
||||
}
|
||||
|
||||
// Cache miss or expired — query DB.
|
||||
target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{
|
||||
ID: sandboxID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
target, err := h.db.GetSandboxProxyTarget(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return nil, errProxySandboxNotFound
|
||||
}
|
||||
@ -126,7 +108,7 @@ func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID
|
||||
}
|
||||
|
||||
h.cacheMu.Lock()
|
||||
h.cache[cacheKey] = proxyCacheEntry{
|
||||
h.cache[sandboxID] = proxyCacheEntry{
|
||||
agentURL: agentURL,
|
||||
expiresAt: time.Now().Add(proxyCacheTTL),
|
||||
}
|
||||
@ -135,11 +117,11 @@ func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID
|
||||
return agentURL, nil
|
||||
}
|
||||
|
||||
// evictProxyCache removes the cached entry for a (sandbox, team) pair.
|
||||
// evictProxyCache removes the cached entry for a sandbox.
|
||||
// Called on 502 so a stopped/moved sandbox is re-resolved on the next request.
|
||||
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) {
|
||||
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID pgtype.UUID) {
|
||||
h.cacheMu.Lock()
|
||||
delete(h.cache, makeProxyCacheKey(sandboxID, teamID))
|
||||
delete(h.cache, sandboxID)
|
||||
h.cacheMu.Unlock()
|
||||
}
|
||||
|
||||
@ -166,20 +148,13 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate: require API key or JWT, extract team ID.
|
||||
teamID, err := h.authenticateRequest(r)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID)
|
||||
agentURL, err := h.proxyTarget(r.Context(), sandboxID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errProxySandboxNotFound):
|
||||
@ -206,25 +181,9 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
"port", port,
|
||||
"error", err,
|
||||
)
|
||||
h.evictProxyCache(sandboxID, teamID)
|
||||
h.evictProxyCache(sandboxID)
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// authenticateRequest validates the request's API key and returns the team ID.
|
||||
// Only API key authentication is supported for sandbox proxy requests (not JWT).
|
||||
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("X-API-Key header required")
|
||||
}
|
||||
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := h.db.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid API key")
|
||||
}
|
||||
return row.TeamID, nil
|
||||
}
|
||||
|
||||
248
internal/api/handlers_admin_capsules.go
Normal file
248
internal/api/handlers_admin_capsules.go
Normal file
@ -0,0 +1,248 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type adminCapsuleHandler struct {
|
||||
svc *service.SandboxService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newAdminCapsuleHandler(svc *service.SandboxService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *adminCapsuleHandler {
|
||||
return &adminCapsuleHandler{svc: svc, db: db, pool: pool, audit: al}
|
||||
}
|
||||
|
||||
// Create handles POST /v1/admin/capsules.
|
||||
func (h *adminCapsuleHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSandboxRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Create(r.Context(), service.SandboxCreateParams{
|
||||
TeamID: id.PlatformTeamID,
|
||||
Template: req.Template,
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/admin/capsules.
|
||||
func (h *adminCapsuleHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxes, err := h.svc.List(r.Context(), id.PlatformTeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list sandboxes")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]sandboxResponse, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
resp[i] = sandboxToResponse(sb)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/admin/capsules/{id}.
|
||||
func (h *adminCapsuleHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.svc.Get(r.Context(), sandboxID, id.PlatformTeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Destroy handles DELETE /v1/admin/capsules/{id}.
|
||||
func (h *adminCapsuleHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
type adminSnapshotRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot.
|
||||
// Pauses the capsule, takes a snapshot as a platform template, then destroys the capsule.
|
||||
func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req adminSnapshotRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Verify sandbox exists and belongs to platform team BEFORE any
|
||||
// destructive operations (template overwrite).
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: id.PlatformTeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" && sb.Status != "paused" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name already exists as a platform template.
|
||||
if existing, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
// Delete old snapshot files from all hosts before removing the DB record.
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: id.PlatformTeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-mark sandbox as "paused" to prevent the reconciler from racing.
|
||||
if sb.Status == "running" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a detached context so the snapshot completes even if the client disconnects.
|
||||
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer snapCancel()
|
||||
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Name: req.Name,
|
||||
TeamId: formatUUIDForRPC(id.PlatformTeamID),
|
||||
TemplateId: formatUUIDForRPC(newTemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
// Snapshot failed — revert status.
|
||||
if sb.Status == "running" {
|
||||
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
}
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: id.PlatformTeamID,
|
||||
DefaultUser: "root",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: sb.Metadata,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
|
||||
return
|
||||
}
|
||||
|
||||
// Destroy the ephemeral capsule after successful snapshot.
|
||||
if err := h.svc.Destroy(snapCtx, sandboxID, id.PlatformTeamID); err != nil {
|
||||
slog.Error("failed to destroy capsule after snapshot", "sandbox_id", sandboxIDStr, "error", err)
|
||||
// Don't fail the response — the snapshot was created successfully.
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
}
|
||||
@ -6,11 +6,11 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type apiKeyHandler struct {
|
||||
@ -63,7 +63,7 @@ func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyRes
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: id.FormatUserID(k.CreatedBy),
|
||||
CreatorEmail: k.CreatorEmail,
|
||||
CreatorEmail: k.CreatorEmail.String,
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
|
||||
|
||||
@ -8,9 +8,9 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type auditHandler struct {
|
||||
|
||||
@ -2,19 +2,32 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const (
|
||||
activationKeyPrefix = "wrenn:activation:"
|
||||
activationTTL = 30 * time.Minute
|
||||
signupCooldown = 30 * time.Minute
|
||||
)
|
||||
|
||||
// loginTeam returns the team and role to stamp into a login JWT.
|
||||
@ -52,18 +65,89 @@ func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team,
|
||||
}, first.Role, nil
|
||||
}
|
||||
|
||||
// ensureDefaultTeam creates a default team for a user if they have none.
|
||||
// This happens on first login after activation or for edge cases where a user
|
||||
// has no teams. Returns the team, role, and whether the user was set as admin.
|
||||
func ensureDefaultTeam(ctx context.Context, qtx *db.Queries, pool *pgxpool.Pool, userID pgtype.UUID, userName string) (db.Team, string, bool, error) {
|
||||
// Try existing teams first.
|
||||
team, role, err := loginTeam(ctx, qtx, userID)
|
||||
if err == nil {
|
||||
return team, role, false, nil
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Team{}, "", false, err
|
||||
}
|
||||
|
||||
// No teams — create default team in a transaction.
|
||||
tx, err := pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx) //nolint:errcheck
|
||||
|
||||
txq := qtx.WithTx(tx)
|
||||
|
||||
// First active user to have a team becomes admin.
|
||||
activeCount, err := txq.CountActiveUsers(ctx)
|
||||
if err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("count active users: %w", err)
|
||||
}
|
||||
isFirstUser := activeCount == 1 // only this user is active
|
||||
|
||||
teamID := id.NewTeamID()
|
||||
teamRow, err := txq.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: userName + "'s Team",
|
||||
Slug: id.NewTeamSlug(),
|
||||
})
|
||||
if err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("insert team: %w", err)
|
||||
}
|
||||
|
||||
if err := txq.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
IsDefault: true,
|
||||
Role: "owner",
|
||||
}); err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("insert team member: %w", err)
|
||||
}
|
||||
|
||||
if isFirstUser {
|
||||
if err := txq.SetUserAdmin(ctx, db.SetUserAdminParams{ID: userID, IsAdmin: true}); err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("set admin: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return db.Team{}, "", false, fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
return db.Team{
|
||||
ID: teamRow.ID,
|
||||
Name: teamRow.Name,
|
||||
Slug: teamRow.Slug,
|
||||
IsByoc: teamRow.IsByoc,
|
||||
CreatedAt: teamRow.CreatedAt,
|
||||
DeletedAt: teamRow.DeletedAt,
|
||||
}, "owner", isFirstUser, nil
|
||||
}
|
||||
|
||||
type switchTeamRequest struct {
|
||||
TeamID string `json:"team_id"`
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
mailer email.Mailer
|
||||
rdb *redis.Client
|
||||
redirectURL string
|
||||
}
|
||||
|
||||
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authHandler {
|
||||
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, mailer email.Mailer, rdb *redis.Client, redirectURL string) *authHandler {
|
||||
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/")}
|
||||
}
|
||||
|
||||
type signupRequest struct {
|
||||
@ -77,6 +161,10 @@ type loginRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type activateRequest struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type authResponse struct {
|
||||
Token string `json:"token"`
|
||||
UserID string `json:"user_id"`
|
||||
@ -85,6 +173,10 @@ type authResponse struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type signupResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Signup handles POST /v1/auth/signup.
|
||||
func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
var req signupRequest
|
||||
@ -110,24 +202,41 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Check for existing user with this email.
|
||||
existing, err := h.db.GetUserByEmail(ctx, req.Email)
|
||||
if err == nil {
|
||||
// User exists — decide what to do based on status.
|
||||
switch existing.Status {
|
||||
case "inactive":
|
||||
// Unactivated user — allow re-signup after cooldown.
|
||||
if time.Since(existing.CreatedAt.Time) < signupCooldown {
|
||||
writeError(w, http.StatusConflict, "signup_cooldown",
|
||||
"an activation email was recently sent to this address — please check your inbox or try again later")
|
||||
return
|
||||
}
|
||||
// Cooldown passed — delete the old row and proceed with fresh signup.
|
||||
if err := h.db.HardDeleteUser(ctx, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to clean up previous signup")
|
||||
return
|
||||
}
|
||||
default:
|
||||
// active, disabled, deleted — email is taken.
|
||||
writeError(w, http.StatusConflict, "email_taken", "an account with this email already exists")
|
||||
return
|
||||
}
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user")
|
||||
return
|
||||
}
|
||||
|
||||
passwordHash, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
|
||||
return
|
||||
}
|
||||
|
||||
// Use a transaction to atomically create user + team + membership.
|
||||
tx, err := h.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to begin transaction")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(ctx) //nolint:errcheck
|
||||
|
||||
qtx := h.db.WithTx(tx)
|
||||
|
||||
userID := id.NewUserID()
|
||||
_, err = qtx.InsertUser(ctx, db.InsertUserParams{
|
||||
_, err = h.db.InsertUserInactive(ctx, db.InsertUserInactiveParams{
|
||||
ID: userID,
|
||||
Email: req.Email,
|
||||
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
|
||||
@ -143,44 +252,111 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create default team.
|
||||
teamID := id.NewTeamID()
|
||||
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: req.Name + "'s Team",
|
||||
Slug: id.NewTeamSlug(),
|
||||
// Generate activation token and store in Redis.
|
||||
rawToken := generateActivationToken()
|
||||
tokenHash := hashActivationToken(rawToken)
|
||||
redisKey := activationKeyPrefix + tokenHash
|
||||
|
||||
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(userID), activationTTL).Err(); err != nil {
|
||||
slog.Error("signup: failed to store activation token in redis", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create activation token")
|
||||
return
|
||||
}
|
||||
|
||||
activateURL := h.redirectURL + "/activate?token=" + rawToken
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := h.mailer.Send(sendCtx, req.Email, "Activate your Wrenn account", email.EmailData{
|
||||
RecipientName: req.Name,
|
||||
Message: "Welcome to Wrenn! Click the button below to activate your account. This link expires in 30 minutes.",
|
||||
Button: &email.Button{Text: "Activate Account", URL: activateURL},
|
||||
Closing: "If you didn't create this account, you can safely ignore this email.",
|
||||
}); err != nil {
|
||||
slog.Warn("signup: failed to send activation email", "email", req.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
writeJSON(w, http.StatusCreated, signupResponse{
|
||||
Message: "Account created. Please check your email to activate your account.",
|
||||
})
|
||||
}
|
||||
|
||||
// Activate handles POST /v1/auth/activate.
|
||||
func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
|
||||
var req activateRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tokenHash := hashActivationToken(req.Token)
|
||||
redisKey := activationKeyPrefix + tokenHash
|
||||
|
||||
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
writeError(w, http.StatusBadRequest, "invalid_token", "activation link is invalid or has expired")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to verify token")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := id.ParseUserID(userIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "invalid stored user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Status != "inactive" {
|
||||
writeError(w, http.StatusBadRequest, "already_activated", "this account has already been activated")
|
||||
return
|
||||
}
|
||||
|
||||
// Activate the user.
|
||||
if err := h.db.SetUserStatus(ctx, db.SetUserStatusParams{
|
||||
ID: userID,
|
||||
Status: "active",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to create team")
|
||||
slog.Error("activate: failed to set user status", "user_id", id.FormatUserID(userID), "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to activate user")
|
||||
return
|
||||
}
|
||||
|
||||
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
IsDefault: true,
|
||||
Role: "owner",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to add user to team")
|
||||
// Create default team and log them in.
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, userID, user.Name)
|
||||
if err != nil {
|
||||
slog.Error("activate: failed to create default team", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to set up account")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to commit signup")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner", false)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, authResponse{
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(userID),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
})
|
||||
}
|
||||
|
||||
@ -222,17 +398,36 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
switch user.Status {
|
||||
case "active":
|
||||
// OK — proceed.
|
||||
case "inactive":
|
||||
slog.Warn("login failed: account not activated", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusForbidden, "account_not_activated", "please check your email and activate your account before signing in")
|
||||
return
|
||||
case "disabled":
|
||||
slog.Warn("login failed: account disabled", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusForbidden, "account_disabled", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
case "deleted":
|
||||
slog.Warn("login failed: account deleted", "email", req.Email, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
default:
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure user has a default team (creates one on first login after activation).
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusForbidden, "no_team", "user is not a member of any team")
|
||||
return
|
||||
}
|
||||
slog.Error("login: failed to ensure default team", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
@ -322,3 +517,18 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
|
||||
Name: user.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func generateActivationToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func hashActivationToken(raw string) string {
|
||||
h := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
@ -3,19 +3,21 @@ package api
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/validate"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -54,6 +56,8 @@ type buildResponse struct {
|
||||
Error *string `json:"error,omitempty"`
|
||||
SandboxID *string `json:"sandbox_id,omitempty"`
|
||||
HostID *string `json:"host_id,omitempty"`
|
||||
DefaultUser string `json:"default_user"`
|
||||
DefaultEnv json.RawMessage `json:"default_env"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
@ -71,6 +75,8 @@ func buildToResponse(b db.TemplateBuild) buildResponse {
|
||||
CurrentStep: b.CurrentStep,
|
||||
TotalSteps: b.TotalSteps,
|
||||
Logs: b.Logs,
|
||||
DefaultUser: b.DefaultUser,
|
||||
DefaultEnv: b.DefaultEnv,
|
||||
}
|
||||
if b.Healthcheck != "" {
|
||||
resp.Healthcheck = &b.Healthcheck
|
||||
@ -101,11 +107,54 @@ func buildToResponse(b db.TemplateBuild) buildResponse {
|
||||
}
|
||||
|
||||
// Create handles POST /v1/admin/builds.
|
||||
// Accepts either JSON body or multipart/form-data with a "config" JSON part
|
||||
// and an optional "archive" file part (tar/tar.gz/zip for COPY commands).
|
||||
func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createBuildRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
var archive []byte
|
||||
var archiveName string
|
||||
|
||||
ct := r.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(ct, "multipart/") {
|
||||
// 100 MB max for multipart (archive + JSON config).
|
||||
if err := r.ParseMultipartForm(100 << 20); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "failed to parse multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse JSON config from "config" field.
|
||||
configStr := r.FormValue("config")
|
||||
if configStr == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "multipart form requires a 'config' JSON field")
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal([]byte(configStr), &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid config JSON in multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
// Read optional archive file (max 100 MB).
|
||||
file, header, err := r.FormFile("archive")
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
const maxArchiveSize = 100 << 20 // 100 MB
|
||||
lr := io.LimitReader(file, maxArchiveSize+1)
|
||||
archive, err = io.ReadAll(lr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "failed to read archive file")
|
||||
return
|
||||
}
|
||||
if int64(len(archive)) > maxArchiveSize {
|
||||
writeError(w, http.StatusRequestEntityTooLarge, "invalid_request", "archive exceeds 100 MB limit")
|
||||
return
|
||||
}
|
||||
archiveName = header.Filename
|
||||
}
|
||||
} else {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
@ -129,6 +178,8 @@ func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
SkipPrePost: req.SkipPrePost,
|
||||
Archive: archive,
|
||||
ArchiveName: archiveName,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to create build", "error", err)
|
||||
|
||||
@ -8,11 +8,11 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
type channelHandler struct {
|
||||
|
||||
@ -12,10 +12,10 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -29,9 +29,13 @@ func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler
|
||||
}
|
||||
|
||||
type execRequest struct {
|
||||
Cmd string `json:"cmd"`
|
||||
Args []string `json:"args"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
Cmd string `json:"cmd"`
|
||||
Args []string `json:"args"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
Background bool `json:"background"`
|
||||
Tag string `json:"tag"`
|
||||
Envs map[string]string `json:"envs"`
|
||||
Cwd string `json:"cwd"`
|
||||
}
|
||||
|
||||
type execResponse struct {
|
||||
@ -45,7 +49,14 @@ type execResponse struct {
|
||||
Encoding string `json:"encoding"`
|
||||
}
|
||||
|
||||
// Exec handles POST /v1/sandboxes/{id}/exec.
|
||||
type backgroundExecResponse struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Cmd string `json:"cmd"`
|
||||
PID uint32 `json:"pid"`
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
// Exec handles POST /v1/capsules/{id}/exec.
|
||||
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
@ -78,14 +89,54 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Background mode: start process and return immediately.
|
||||
if req.Background {
|
||||
tag := req.Tag
|
||||
if tag == "" {
|
||||
tag = "proc-" + id.NewPtyTag()
|
||||
}
|
||||
|
||||
bgResp, err := agent.StartBackground(ctx, connect.NewRequest(&pb.StartBackgroundRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Tag: tag,
|
||||
Cmd: req.Cmd,
|
||||
Args: req.Args,
|
||||
Envs: req.Envs,
|
||||
Cwd: req.Cwd,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
PID: bgResp.Msg.Pid,
|
||||
Tag: bgResp.Msg.Tag,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
|
||||
@ -12,20 +12,21 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool}
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@ -47,11 +48,10 @@ type wsOutMsg struct {
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
||||
}
|
||||
|
||||
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
|
||||
// ExecStream handles WS /v1/capsules/{id}/exec/stream.
|
||||
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
@ -59,13 +59,31 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -76,6 +94,20 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
sendWSError(conn, "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
// Read the start message.
|
||||
var startMsg wsStartMsg
|
||||
if err := conn.ReadJSON(&startMsg); err != nil {
|
||||
|
||||
@ -9,10 +9,10 @@ import (
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -25,7 +25,7 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
|
||||
return &filesHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// Upload handles POST /v1/sandboxes/{id}/files/write.
|
||||
// Upload handles POST /v1/capsules/{id}/files/write.
|
||||
// Expects multipart/form-data with:
|
||||
// - "path" text field: absolute destination path inside the sandbox
|
||||
// - "file" file field: binary content to write
|
||||
@ -105,7 +105,7 @@ type readFileRequest struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// Download handles POST /v1/sandboxes/{id}/files/read.
|
||||
// Download handles POST /v1/capsules/{id}/files/read.
|
||||
// Accepts JSON body with path, returns raw file content with Content-Disposition.
|
||||
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
|
||||
@ -10,10 +10,10 @@ import (
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -26,7 +26,7 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
|
||||
return &filesStreamHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
||||
// StreamUpload handles POST /v1/capsules/{id}/files/stream/write.
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
@ -150,7 +150,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
|
||||
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
|
||||
// Accepts JSON body with path, streams file content back without buffering.
|
||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
|
||||
236
internal/api/handlers_fs.go
Normal file
236
internal/api/handlers_fs.go
Normal file
@ -0,0 +1,236 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type fsHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newFSHandler(db *db.Queries, pool *lifecycle.HostClientPool) *fsHandler {
|
||||
return &fsHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
type listDirRequest struct {
|
||||
Path string `json:"path"`
|
||||
Depth uint32 `json:"depth"`
|
||||
}
|
||||
|
||||
type fileEntryResponse struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Type string `json:"type"`
|
||||
Size int64 `json:"size"`
|
||||
Mode uint32 `json:"mode"`
|
||||
Permissions string `json:"permissions"`
|
||||
Owner string `json:"owner"`
|
||||
Group string `json:"group"`
|
||||
ModifiedAt int64 `json:"modified_at"`
|
||||
SymlinkTarget *string `json:"symlink_target,omitempty"`
|
||||
}
|
||||
|
||||
type listDirResponse struct {
|
||||
Entries []fileEntryResponse `json:"entries"`
|
||||
}
|
||||
|
||||
type makeDirRequest struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type makeDirResponse struct {
|
||||
Entry fileEntryResponse `json:"entry"`
|
||||
}
|
||||
|
||||
type removeRequest struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// ListDir handles POST /v1/capsules/{id}/files/list.
|
||||
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req listDirRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.ListDir(ctx, connect.NewRequest(&pb.ListDirRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
Depth: req.Depth,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
entries := make([]fileEntryResponse, 0, len(resp.Msg.Entries))
|
||||
for _, e := range resp.Msg.Entries {
|
||||
entries = append(entries, fileEntryFromPB(e))
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, listDirResponse{Entries: entries})
|
||||
}
|
||||
|
||||
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
|
||||
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req makeDirRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.MakeDir(ctx, connect.NewRequest(&pb.MakeDirRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, makeDirResponse{Entry: fileEntryFromPB(resp.Msg.Entry)})
|
||||
}
|
||||
|
||||
// Remove handles POST /v1/capsules/{id}/files/remove.
|
||||
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req removeRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := agent.RemovePath(ctx, connect.NewRequest(&pb.RemovePathRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Path: req.Path,
|
||||
})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func fileEntryFromPB(e *pb.FileEntry) fileEntryResponse {
|
||||
if e == nil {
|
||||
return fileEntryResponse{}
|
||||
}
|
||||
resp := fileEntryResponse{
|
||||
Name: e.Name,
|
||||
Path: e.Path,
|
||||
Type: e.Type,
|
||||
Size: e.Size,
|
||||
Mode: e.Mode,
|
||||
Permissions: e.Permissions,
|
||||
Owner: e.Owner,
|
||||
Group: e.Group,
|
||||
ModifiedAt: e.ModifiedAt,
|
||||
}
|
||||
if e.SymlinkTarget != nil {
|
||||
resp.SymlinkTarget = e.SymlinkTarget
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@ -10,11 +10,11 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type hostHandler struct {
|
||||
|
||||
585
internal/api/handlers_me.go
Normal file
585
internal/api/handlers_me.go
Normal file
@ -0,0 +1,585 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
const (
|
||||
passwordResetKeyPrefix = "wrenn:password_reset:"
|
||||
passwordResetTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
type meHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
rdb *redis.Client
|
||||
jwtSecret []byte
|
||||
mailer email.Mailer
|
||||
oauthRegistry *oauth.Registry
|
||||
redirectURL string
|
||||
teamSvc *service.TeamService
|
||||
}
|
||||
|
||||
func newMeHandler(
|
||||
db *db.Queries,
|
||||
pool *pgxpool.Pool,
|
||||
rdb *redis.Client,
|
||||
jwtSecret []byte,
|
||||
mailer email.Mailer,
|
||||
registry *oauth.Registry,
|
||||
redirectURL string,
|
||||
teamSvc *service.TeamService,
|
||||
) *meHandler {
|
||||
return &meHandler{
|
||||
db: db,
|
||||
pool: pool,
|
||||
rdb: rdb,
|
||||
jwtSecret: jwtSecret,
|
||||
mailer: mailer,
|
||||
oauthRegistry: registry,
|
||||
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||
teamSvc: teamSvc,
|
||||
}
|
||||
}
|
||||
|
||||
type meResponse struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
HasPassword bool `json:"has_password"`
|
||||
Providers []string `json:"providers"`
|
||||
}
|
||||
|
||||
type updateNameRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type changePasswordRequest struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
ConfirmPassword string `json:"confirm_password"`
|
||||
}
|
||||
|
||||
type requestPasswordResetRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type confirmPasswordResetRequest struct {
|
||||
Token string `json:"token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type deleteAccountRequest struct {
|
||||
Confirmation string `json:"confirmation"`
|
||||
}
|
||||
|
||||
// GetMe handles GET /v1/me.
|
||||
func (h *meHandler) GetMe(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.db.GetOAuthProvidersByUserID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get providers")
|
||||
return
|
||||
}
|
||||
|
||||
providerNames := make([]string, 0, len(providers))
|
||||
for _, p := range providers {
|
||||
providerNames = append(providerNames, p.Provider)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, meResponse{
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
HasPassword: user.PasswordHash.Valid,
|
||||
Providers: providerNames,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateName handles PATCH /v1/me — updates the user's name and re-issues a JWT.
|
||||
func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
|
||||
var req updateNameRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
req.Name = strings.TrimSpace(req.Name)
|
||||
if req.Name == "" || len(req.Name) > 100 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "name must be between 1 and 100 characters")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateUserName(ctx, db.UpdateUserNameParams{
|
||||
ID: ac.UserID,
|
||||
Name: req.Name,
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update name")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
team, role, err := loginTeam(ctx, h.db, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get team")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, team.ID, user.Email, req.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: req.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangePassword handles POST /v1/me/password.
|
||||
// For users with a password: requires current_password + new_password.
|
||||
// For OAuth-only users: requires new_password + confirm_password.
|
||||
func (h *meHandler) ChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
|
||||
var req changePasswordRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
if user.PasswordHash.Valid {
|
||||
// Changing existing password — verify current.
|
||||
if req.CurrentPassword == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "current_password is required")
|
||||
return
|
||||
}
|
||||
if err := auth.CheckPassword(user.PasswordHash.String, req.CurrentPassword); err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "wrong_password", "current password is incorrect")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// OAuth user adding a password — confirm must match.
|
||||
if req.ConfirmPassword == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "confirm_password is required")
|
||||
return
|
||||
}
|
||||
if req.NewPassword != req.ConfirmPassword {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "passwords do not match")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.NewPassword) < 8 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateUserPassword(ctx, db.UpdateUserPasswordParams{
|
||||
ID: ac.UserID,
|
||||
PasswordHash: pgtype.Text{String: hash, Valid: true},
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update password")
|
||||
return
|
||||
}
|
||||
|
||||
isAdding := !user.PasswordHash.Valid
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
subject, message := "Your Wrenn password was changed", "Your account password was successfully updated. If you did not make this change, reset your password immediately."
|
||||
if isAdding {
|
||||
subject = "Password added to your Wrenn account"
|
||||
message = "A password has been added to your Wrenn account. You can now sign in with your email and password in addition to any connected OAuth providers."
|
||||
}
|
||||
if err := h.mailer.Send(sendCtx, user.Email, subject, email.EmailData{
|
||||
RecipientName: user.Name,
|
||||
Message: message,
|
||||
Closing: "If you didn't make this change, contact support immediately.",
|
||||
}); err != nil {
|
||||
slog.Warn("change password: failed to send notification", "email", user.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RequestPasswordReset handles POST /v1/me/password/reset (unauthenticated).
|
||||
// Always returns 200 to avoid leaking account existence.
|
||||
func (h *meHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
|
||||
var req requestPasswordResetRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
||||
if req.Email == "" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
user, err := h.db.GetUserByEmail(ctx, req.Email)
|
||||
if err != nil {
|
||||
// Don't leak whether the email exists.
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
if user.Status != "active" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
rawToken := generateResetToken()
|
||||
tokenHash := hashResetToken(rawToken)
|
||||
redisKey := passwordResetKeyPrefix + tokenHash
|
||||
|
||||
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(user.ID), passwordResetTTL).Err(); err != nil {
|
||||
slog.Error("password reset: failed to store token in redis", "error", err)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
resetURL := h.redirectURL + "/reset-password?token=" + rawToken
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := h.mailer.Send(sendCtx, user.Email, "Reset your Wrenn password", email.EmailData{
|
||||
RecipientName: user.Name,
|
||||
Message: "We received a request to reset your password. Click the button below to set a new password. This link expires in 15 minutes.",
|
||||
Button: &email.Button{Text: "Reset Password", URL: resetURL},
|
||||
Closing: "If you didn't request a password reset, you can safely ignore this email.",
|
||||
}); err != nil {
|
||||
slog.Error("password reset: failed to send email", "email", user.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ConfirmPasswordReset handles POST /v1/me/password/reset/confirm (unauthenticated).
|
||||
func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request) {
|
||||
var req confirmPasswordResetRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
|
||||
return
|
||||
}
|
||||
if len(req.NewPassword) < 8 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tokenHash := hashResetToken(req.Token)
|
||||
redisKey := passwordResetKeyPrefix + tokenHash
|
||||
|
||||
// GetDel atomically retrieves and removes the token in a single round-trip,
|
||||
// preventing concurrent requests from both consuming the same token.
|
||||
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
writeError(w, http.StatusBadRequest, "invalid_token", "reset token is invalid or has expired")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to verify token")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := id.ParseUserID(userIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "invalid stored user ID")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateUserPassword(ctx, db.UpdateUserPasswordParams{
|
||||
ID: userID,
|
||||
PasswordHash: pgtype.Text{String: hash, Valid: true},
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update password")
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := h.mailer.Send(sendCtx, user.Email, "Your Wrenn password was reset", email.EmailData{
|
||||
RecipientName: user.Name,
|
||||
Message: "Your password has been successfully reset. You can now sign in with your new password.",
|
||||
Closing: "If you didn't request this change, contact support immediately.",
|
||||
}); err != nil {
|
||||
slog.Warn("confirm password reset: failed to send notification", "email", user.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ConnectProvider handles GET /v1/me/providers/{provider}/connect.
|
||||
// Sets OAuth state + link cookies and returns the provider auth URL.
|
||||
func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
provider := chi.URLParam(r, "provider")
|
||||
|
||||
p, ok := h.oauthRegistry.Get(provider)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
|
||||
return
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate state")
|
||||
return
|
||||
}
|
||||
|
||||
mac := computeHMAC(h.jwtSecret, state)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state + ":" + mac,
|
||||
Path: "/",
|
||||
MaxAge: 600,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
userIDStr := id.FormatUserID(ac.UserID)
|
||||
linkMac := computeHMAC(h.jwtSecret, userIDStr)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_link_user_id",
|
||||
Value: userIDStr + ":" + linkMac,
|
||||
Path: "/",
|
||||
MaxAge: 600,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{"auth_url": p.AuthCodeURL(state)})
|
||||
}
|
||||
|
||||
// DisconnectProvider handles DELETE /v1/me/providers/{provider}.
|
||||
func (h *meHandler) DisconnectProvider(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
provider := chi.URLParam(r, "provider")
|
||||
ctx := r.Context()
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.db.GetOAuthProvidersByUserID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get providers")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the user will still have at least one login method after disconnecting.
|
||||
if !user.PasswordHash.Valid && len(providers) <= 1 {
|
||||
writeError(w, http.StatusBadRequest, "last_login_method", "cannot disconnect your only login method — add a password first")
|
||||
return
|
||||
}
|
||||
|
||||
// Check the provider is actually linked to this user.
|
||||
found := false
|
||||
for _, p := range providers {
|
||||
if p.Provider == provider {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
writeError(w, http.StatusNotFound, "not_found", "provider not connected")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteOAuthProvider(ctx, db.DeleteOAuthProviderParams{
|
||||
UserID: ac.UserID,
|
||||
Provider: provider,
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to disconnect provider")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// DeleteAccount handles DELETE /v1/me — soft-deletes the user's account.
|
||||
func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
|
||||
var req deleteAccountRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.EqualFold(strings.TrimSpace(req.Confirmation), user.Email) {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "confirmation does not match your email address")
|
||||
return
|
||||
}
|
||||
|
||||
teamsBlocking, err := h.db.CountUserOwnedTeamsWithOtherMembers(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to check team ownership")
|
||||
return
|
||||
}
|
||||
if teamsBlocking > 0 {
|
||||
writeError(w, http.StatusConflict, "owns_team_with_members",
|
||||
fmt.Sprintf("you own %d team(s) with other members — transfer ownership or remove members before deleting your account", teamsBlocking))
|
||||
return
|
||||
}
|
||||
|
||||
// Delete all teams the user solely owns (no other members).
|
||||
// Team deletion involves RPC calls (sandbox destruction) that cannot be
|
||||
// transactional, so we do those first as best-effort, then wrap the
|
||||
// DB-only cleanup in a transaction.
|
||||
soleTeams, err := h.db.ListSoleOwnedTeams(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list owned teams")
|
||||
return
|
||||
}
|
||||
for _, teamID := range soleTeams {
|
||||
if err := h.teamSvc.DeleteTeamInternal(ctx, teamID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error",
|
||||
fmt.Sprintf("failed to delete sole-owned team %s", id.FormatTeamID(teamID)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := h.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to start transaction")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
qtx := h.db.WithTx(tx)
|
||||
|
||||
if err := qtx.DeleteAPIKeysByCreator(ctx, ac.UserID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete user's API keys")
|
||||
return
|
||||
}
|
||||
|
||||
if err := qtx.SoftDeleteUser(ctx, ac.UserID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete account")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to commit account deletion")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)
|
||||
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := h.mailer.Send(sendCtx, user.Email, "Your Wrenn account has been deleted", email.EmailData{
|
||||
RecipientName: user.Name,
|
||||
Message: "Your Wrenn account has been deactivated and is scheduled for permanent deletion in 15 days. If this was a mistake, contact support before then to recover your account.",
|
||||
Closing: "Thank you for using Wrenn.",
|
||||
}); err != nil {
|
||||
slog.Warn("delete account: failed to send notification", "email", user.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func generateResetToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func hashResetToken(raw string) string {
|
||||
h := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
@ -9,10 +9,10 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -38,7 +38,7 @@ type metricsResponse struct {
|
||||
Points []metricPointResponse `json:"points"`
|
||||
}
|
||||
|
||||
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h.
|
||||
// GetMetrics handles GET /v1/capsules/{id}/metrics?range=10m|2h|24h.
|
||||
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
@ -16,10 +16,10 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
type oauthHandler struct {
|
||||
@ -137,6 +137,73 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(profile.Email))
|
||||
|
||||
// Check for a link operation initiated from the settings page.
|
||||
if linkCookie, err := r.Cookie("oauth_link_user_id"); err == nil && linkCookie.Value != "" {
|
||||
// Clear the link cookie immediately.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_link_user_id",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
settingsBase := h.redirectURL + "/dashboard/settings"
|
||||
|
||||
// Verify the HMAC to prevent cookie forgery.
|
||||
linkParts := strings.SplitN(linkCookie.Value, ":", 2)
|
||||
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.jwtSecret, linkParts[0])), []byte(linkParts[1])) {
|
||||
slog.Warn("oauth link: invalid or tampered link cookie")
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
userID, parseErr := id.ParseUserID(linkParts[0])
|
||||
if parseErr != nil {
|
||||
slog.Error("oauth link: invalid user ID in cookie", "error", parseErr)
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the GitHub account isn't already linked to a different user.
|
||||
existing, lookupErr := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
})
|
||||
if lookupErr == nil && existing.UserID != userID {
|
||||
slog.Warn("oauth link: provider already linked to another account", "provider", provider)
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=already_linked", http.StatusFound)
|
||||
return
|
||||
}
|
||||
if lookupErr == nil && existing.UserID == userID {
|
||||
// Already linked to this user — treat as success.
|
||||
http.Redirect(w, r, settingsBase+"?connected="+provider, http.StatusFound)
|
||||
return
|
||||
}
|
||||
if !errors.Is(lookupErr, pgx.ErrNoRows) {
|
||||
slog.Error("oauth link: db lookup failed", "error", lookupErr)
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=db_error", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
if insertErr := h.db.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}); insertErr != nil {
|
||||
slog.Error("oauth link: failed to insert provider", "error", insertErr)
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=db_error", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("oauth link: provider linked", "provider", provider, "user_id", id.FormatUserID(userID))
|
||||
http.Redirect(w, r, settingsBase+"?connected="+provider, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this OAuth identity already exists.
|
||||
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||
Provider: provider,
|
||||
@ -145,18 +212,29 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
if err == nil {
|
||||
// Existing OAuth user — log them in.
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("oauth login: user no longer exists", "user_id", existing.UserID)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if user.Status != "active" {
|
||||
slog.Warn("oauth login: account not active", "email", user.Email, "status", user.Status)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get team", "error", err)
|
||||
slog.Error("oauth login: failed to ensure team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
@ -172,13 +250,21 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// New OAuth identity — check for email collision.
|
||||
_, err = h.db.GetUserByEmail(ctx, email)
|
||||
existingUser, err := h.db.GetUserByEmail(ctx, email)
|
||||
if err == nil {
|
||||
// Email already taken by another account.
|
||||
redirectWithError(w, r, redirectBase, "email_taken")
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
if existingUser.Status == "inactive" {
|
||||
// Unactivated email signup — delete and let OAuth take over.
|
||||
if delErr := h.db.HardDeleteUser(ctx, existingUser.ID); delErr != nil {
|
||||
slog.Error("oauth: failed to delete inactive user", "error", delErr)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Email already taken by an active/disabled/deleted account.
|
||||
redirectWithError(w, r, redirectBase, "email_taken")
|
||||
return
|
||||
}
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Error("oauth: email check failed", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
@ -195,6 +281,15 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
qtx := h.db.WithTx(tx)
|
||||
|
||||
// The first user to sign up becomes a platform admin.
|
||||
userCount, err := qtx.CountUsers(ctx)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to count users", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
isFirstUser := userCount == 0
|
||||
|
||||
userID := id.NewUserID()
|
||||
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
|
||||
ID: userID,
|
||||
@ -238,6 +333,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if isFirstUser {
|
||||
if err := qtx.SetUserAdmin(ctx, db.SetUserAdminParams{ID: userID, IsAdmin: true}); err != nil {
|
||||
slog.Error("oauth: failed to set admin status", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := qtx.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
@ -255,7 +358,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", false)
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", isFirstUser)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
@ -279,18 +382,29 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
||||
return
|
||||
}
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("oauth: retry login: user no longer exists", "user_id", existing.UserID)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
if user.Status != "active" {
|
||||
slog.Warn("oauth: retry login: account not active", "email", user.Email, "status", user.Status)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get team", "error", err)
|
||||
slog.Error("oauth: retry login: failed to ensure team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
|
||||
298
internal/api/handlers_process.go
Normal file
298
internal/api/handlers_process.go
Normal file
@ -0,0 +1,298 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type processHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
|
||||
return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
// processResponse is a single entry in the process list.
|
||||
type processResponse struct {
|
||||
PID uint32 `json:"pid"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Cmd string `json:"cmd"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
}
|
||||
|
||||
// processListResponse wraps the list of processes.
|
||||
type processListResponse struct {
|
||||
Processes []processResponse `json:"processes"`
|
||||
}
|
||||
|
||||
// ListProcesses handles GET /v1/capsules/{id}/processes.
|
||||
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := agent.ListProcesses(ctx, connect.NewRequest(&pb.ListProcessesRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
procs := make([]processResponse, 0, len(resp.Msg.Processes))
|
||||
for _, p := range resp.Msg.Processes {
|
||||
procs = append(procs, processResponse{
|
||||
PID: p.Pid,
|
||||
Tag: p.Tag,
|
||||
Cmd: p.Cmd,
|
||||
Args: p.Args,
|
||||
})
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, processListResponse{Processes: procs})
|
||||
}
|
||||
|
||||
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
|
||||
// The selector can be a numeric PID or a string tag.
|
||||
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
selectorStr := chi.URLParam(r, "selector")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Build the kill request with PID or tag selector.
|
||||
killReq := &pb.KillProcessRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Signal: "SIGKILL",
|
||||
}
|
||||
if sig := r.URL.Query().Get("signal"); sig == "SIGTERM" {
|
||||
killReq.Signal = "SIGTERM"
|
||||
}
|
||||
|
||||
if pid, err := strconv.ParseUint(selectorStr, 10, 32); err == nil {
|
||||
killReq.Selector = &pb.KillProcessRequest_Pid{Pid: uint32(pid)}
|
||||
} else {
|
||||
killReq.Selector = &pb.KillProcessRequest_Tag{Tag: selectorStr}
|
||||
}
|
||||
|
||||
if _, err := agent.KillProcess(ctx, connect.NewRequest(killReq)); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// wsProcessOut is the JSON message sent to the WebSocket client.
|
||||
type wsProcessOut struct {
|
||||
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
|
||||
PID uint32 `json:"pid,omitempty"` // only for "start"
|
||||
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
||||
}
|
||||
|
||||
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
|
||||
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
selectorStr := chi.URLParam(r, "selector")
|
||||
ctx := r.Context()
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendProcessWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
|
||||
}
|
||||
|
||||
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Build the connect request with PID or tag selector.
|
||||
connectReq := &pb.ConnectProcessRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
}
|
||||
if pid, err := strconv.ParseUint(selectorStr, 10, 32); err == nil {
|
||||
connectReq.Selector = &pb.ConnectProcessRequest_Pid{Pid: uint32(pid)}
|
||||
} else {
|
||||
connectReq.Selector = &pb.ConnectProcessRequest_Tag{Tag: selectorStr}
|
||||
}
|
||||
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "failed to connect to process: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
// Listen for client disconnect in a goroutine.
|
||||
go func() {
|
||||
for {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Forward stream events to WebSocket.
|
||||
for stream.Receive() {
|
||||
resp := stream.Msg()
|
||||
switch ev := resp.Event.(type) {
|
||||
case *pb.ConnectProcessResponse_Start:
|
||||
writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})
|
||||
|
||||
case *pb.ConnectProcessResponse_Data:
|
||||
switch o := ev.Data.Output.(type) {
|
||||
case *pb.ExecStreamData_Stdout:
|
||||
writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
|
||||
case *pb.ExecStreamData_Stderr:
|
||||
writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
|
||||
}
|
||||
|
||||
case *pb.ConnectProcessResponse_End:
|
||||
exitCode := ev.End.ExitCode
|
||||
writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
if streamCtx.Err() == nil {
|
||||
sendProcessWSError(conn, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Update last active using a fresh context.
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendProcessWSError(conn *websocket.Conn, msg string) {
|
||||
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
|
||||
}
|
||||
400
internal/api/handlers_pty.go
Normal file
400
internal/api/handlers_pty.go
Normal file
@ -0,0 +1,400 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
const (
|
||||
ptyKeepaliveInterval = 30 * time.Second
|
||||
ptyDefaultCmd = "/bin/bash"
|
||||
ptyDefaultCols = 80
|
||||
ptyDefaultRows = 24
|
||||
)
|
||||
|
||||
type ptyHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
|
||||
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
// --- WebSocket message types ---
|
||||
|
||||
// wsPtyIn is the inbound message from the client.
|
||||
type wsPtyIn struct {
|
||||
Type string `json:"type"` // "start", "connect", "input", "resize", "kill"
|
||||
Cmd string `json:"cmd,omitempty"` // for "start"
|
||||
Args []string `json:"args,omitempty"` // for "start"
|
||||
Cols uint32 `json:"cols,omitempty"` // for "start", "resize"
|
||||
Rows uint32 `json:"rows,omitempty"` // for "start", "resize"
|
||||
Envs map[string]string `json:"envs,omitempty"` // for "start"
|
||||
Cwd string `json:"cwd,omitempty"` // for "start"
|
||||
User string `json:"user,omitempty"` // for "start"
|
||||
Tag string `json:"tag,omitempty"` // for "connect"
|
||||
Data string `json:"data,omitempty"` // for "input" (base64)
|
||||
}
|
||||
|
||||
// wsPtyOut is the outbound message to the client.
|
||||
type wsPtyOut struct {
|
||||
Type string `json:"type"` // "started", "output", "exit", "error"
|
||||
Tag string `json:"tag,omitempty"` // for "started"
|
||||
PID uint32 `json:"pid,omitempty"` // for "started"
|
||||
Data string `json:"data,omitempty"` // for "output" (base64), "error"
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // for "exit"
|
||||
Fatal bool `json:"fatal,omitempty"` // for "error"
|
||||
}
|
||||
|
||||
// wsWriter wraps a websocket.Conn with a mutex for concurrent writes.
|
||||
type wsWriter struct {
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (w *wsWriter) writeJSON(v any) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if err := w.conn.WriteJSON(v); err != nil {
|
||||
slog.Debug("pty websocket write error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// PtySession handles WS /v1/capsules/{id}/pty.
|
||||
func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
// API key auth is handled by middleware (sets context).
|
||||
// For browser JWT auth, we authenticate after upgrade via first WS message.
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ws := &wsWriter{conn: conn}
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ws := &wsWriter{conn: conn}
|
||||
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox not found", Fatal: true})
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox is not running (status: " + sb.Status + ")", Fatal: true})
|
||||
return
|
||||
}
|
||||
|
||||
// Read the first message to determine start vs connect.
|
||||
var firstMsg wsPtyIn
|
||||
if err := conn.ReadJSON(&firstMsg); err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to read first message: " + err.Error(), Fatal: true})
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox host is not reachable", Fatal: true})
|
||||
return
|
||||
}
|
||||
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
switch firstMsg.Type {
|
||||
case "start":
|
||||
h.handleStart(streamCtx, cancel, ws, agent, sandboxIDStr, firstMsg)
|
||||
case "connect":
|
||||
h.handleConnect(streamCtx, cancel, ws, agent, sandboxIDStr, firstMsg)
|
||||
default:
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
|
||||
}
|
||||
|
||||
// Update last active using a fresh context.
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ptyHandler) handleStart(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
ws *wsWriter,
|
||||
agent hostagentv1connect.HostAgentServiceClient,
|
||||
sandboxIDStr string,
|
||||
msg wsPtyIn,
|
||||
) {
|
||||
cmd := msg.Cmd
|
||||
if cmd == "" {
|
||||
cmd = ptyDefaultCmd
|
||||
}
|
||||
cols := msg.Cols
|
||||
if cols == 0 {
|
||||
cols = ptyDefaultCols
|
||||
}
|
||||
rows := msg.Rows
|
||||
if rows == 0 {
|
||||
rows = ptyDefaultRows
|
||||
}
|
||||
|
||||
tag := newPtyTag()
|
||||
|
||||
stream, err := agent.PtyAttach(ctx, connect.NewRequest(&pb.PtyAttachRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Tag: tag,
|
||||
Cmd: cmd,
|
||||
Args: msg.Args,
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
Envs: msg.Envs,
|
||||
Cwd: msg.Cwd,
|
||||
User: msg.User,
|
||||
}))
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to start pty: " + err.Error(), Fatal: true})
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
// Wait for the started event and forward it.
|
||||
if !stream.Receive() {
|
||||
if err := stream.Err(); err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "pty stream failed: " + err.Error(), Fatal: true})
|
||||
}
|
||||
return
|
||||
}
|
||||
resp := stream.Msg()
|
||||
started, ok := resp.Event.(*pb.PtyAttachResponse_Started)
|
||||
if !ok {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "expected started event from host agent", Fatal: true})
|
||||
return
|
||||
}
|
||||
ws.writeJSON(wsPtyOut{Type: "started", Tag: started.Started.Tag, PID: started.Started.Pid})
|
||||
|
||||
runPtyLoop(ctx, cancel, ws, stream, agent, sandboxIDStr, tag)
|
||||
}
|
||||
|
||||
func (h *ptyHandler) handleConnect(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
ws *wsWriter,
|
||||
agent hostagentv1connect.HostAgentServiceClient,
|
||||
sandboxIDStr string,
|
||||
msg wsPtyIn,
|
||||
) {
|
||||
if msg.Tag == "" {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "connect requires a 'tag' field", Fatal: true})
|
||||
return
|
||||
}
|
||||
|
||||
stream, err := agent.PtyAttach(ctx, connect.NewRequest(&pb.PtyAttachRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Tag: msg.Tag,
|
||||
}))
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to connect to pty: " + err.Error(), Fatal: true})
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
runPtyLoop(ctx, cancel, ws, stream, agent, sandboxIDStr, msg.Tag)
|
||||
}
|
||||
|
||||
// runPtyLoop drives the bidirectional communication between the WebSocket
|
||||
// and the host agent PTY stream.
|
||||
func runPtyLoop(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
ws *wsWriter,
|
||||
stream *connect.ServerStreamForClient[pb.PtyAttachResponse],
|
||||
agent hostagentv1connect.HostAgentServiceClient,
|
||||
sandboxID string,
|
||||
tag string,
|
||||
) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Output pump: read from Connect stream, write to WebSocket.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
|
||||
for stream.Receive() {
|
||||
resp := stream.Msg()
|
||||
switch ev := resp.Event.(type) {
|
||||
case *pb.PtyAttachResponse_Started:
|
||||
// Already handled before the loop for "start" mode.
|
||||
// For "connect" mode this won't appear.
|
||||
ws.writeJSON(wsPtyOut{Type: "started", Tag: ev.Started.Tag, PID: ev.Started.Pid})
|
||||
|
||||
case *pb.PtyAttachResponse_Output:
|
||||
ws.writeJSON(wsPtyOut{
|
||||
Type: "output",
|
||||
Data: base64.StdEncoding.EncodeToString(ev.Output.Data),
|
||||
})
|
||||
|
||||
case *pb.PtyAttachResponse_Exited:
|
||||
exitCode := ev.Exited.ExitCode
|
||||
ws.writeJSON(wsPtyOut{Type: "exit", ExitCode: &exitCode})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil && ctx.Err() == nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: err.Error()})
|
||||
}
|
||||
}()
|
||||
|
||||
// Input pump: read from WebSocket, dispatch to host agent.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
_, raw, err := ws.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var msg wsPtyIn
|
||||
if json.Unmarshal(raw, &msg) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use a background context for unary RPCs so they complete
|
||||
// even if the stream context is being cancelled.
|
||||
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
switch msg.Type {
|
||||
case "input":
|
||||
data, err := base64.StdEncoding.DecodeString(msg.Data)
|
||||
if err != nil {
|
||||
rpcCancel()
|
||||
continue
|
||||
}
|
||||
if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{
|
||||
SandboxId: sandboxID,
|
||||
Tag: tag,
|
||||
Data: data,
|
||||
})); err != nil {
|
||||
slog.Debug("pty send input error", "error", err)
|
||||
}
|
||||
|
||||
case "resize":
|
||||
cols := msg.Cols
|
||||
rows := msg.Rows
|
||||
if cols > 0 && rows > 0 {
|
||||
if _, err := agent.PtyResize(rpcCtx, connect.NewRequest(&pb.PtyResizeRequest{
|
||||
SandboxId: sandboxID,
|
||||
Tag: tag,
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
})); err != nil {
|
||||
slog.Debug("pty resize error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
case "kill":
|
||||
if _, err := agent.PtyKill(rpcCtx, connect.NewRequest(&pb.PtyKillRequest{
|
||||
SandboxId: sandboxID,
|
||||
Tag: tag,
|
||||
})); err != nil {
|
||||
slog.Debug("pty kill error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
rpcCancel()
|
||||
}
|
||||
}()
|
||||
|
||||
// Keepalive pump: send periodic pings to prevent idle WS closure.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ticker := time.NewTicker(ptyKeepaliveInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ws.writeJSON(wsPtyOut{Type: "ping"})
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// newPtyTag returns a PTY session tag: "pty-" + 8 random hex chars.
|
||||
func newPtyTag() string {
|
||||
return "pty-" + id.NewPtyTag()
|
||||
}
|
||||
@ -7,11 +7,11 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type sandboxHandler struct {
|
||||
@ -31,18 +31,19 @@ type createSandboxRequest struct {
|
||||
}
|
||||
|
||||
type sandboxResponse struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Template string `json:"template"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
GuestIP string `json:"guest_ip,omitempty"`
|
||||
HostIP string `json:"host_ip,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
LastActiveAt *string `json:"last_active_at,omitempty"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Template string `json:"template"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
GuestIP string `json:"guest_ip,omitempty"`
|
||||
HostIP string `json:"host_ip,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
LastActiveAt *string `json:"last_active_at,omitempty"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
@ -56,6 +57,12 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
GuestIP: sb.GuestIp,
|
||||
HostIP: sb.HostIp,
|
||||
}
|
||||
if len(sb.Metadata) > 0 {
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(sb.Metadata, &meta); err == nil && len(meta) > 0 {
|
||||
resp.Metadata = meta
|
||||
}
|
||||
}
|
||||
if sb.CreatedAt.Valid {
|
||||
resp.CreatedAt = sb.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
@ -73,7 +80,7 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/sandboxes.
|
||||
// Create handles POST /v1/capsules.
|
||||
func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSandboxRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
@ -104,7 +111,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/sandboxes.
|
||||
// List handles GET /v1/capsules.
|
||||
func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
sandboxes, err := h.svc.List(r.Context(), ac.TeamID)
|
||||
@ -121,7 +128,7 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/sandboxes/{id}.
|
||||
// Get handles GET /v1/capsules/{id}.
|
||||
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
@ -141,7 +148,7 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Pause handles POST /v1/sandboxes/{id}/pause.
|
||||
// Pause handles POST /v1/capsules/{id}/pause.
|
||||
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
@ -163,7 +170,7 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Resume handles POST /v1/sandboxes/{id}/resume.
|
||||
// Resume handles POST /v1/capsules/{id}/resume.
|
||||
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
@ -185,7 +192,7 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Ping handles POST /v1/sandboxes/{id}/ping.
|
||||
// Ping handles POST /v1/capsules/{id}/ping.
|
||||
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
@ -205,7 +212,7 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Destroy handles DELETE /v1/sandboxes/{id}.
|
||||
// Destroy handles DELETE /v1/capsules/{id}.
|
||||
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
@ -13,14 +13,14 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/validate"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
@ -38,8 +38,8 @@ func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *life
|
||||
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
|
||||
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
|
||||
// and ignore NotFound errors.
|
||||
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error {
|
||||
hosts, err := h.db.ListActiveHosts(ctx)
|
||||
func deleteSnapshotBroadcast(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
|
||||
hosts, err := queries.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
@ -47,7 +47,7 @@ func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, t
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
}
|
||||
agent, err := h.pool.GetForHost(host)
|
||||
agent, err := pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@ -69,13 +69,14 @@ type createSnapshotRequest struct {
|
||||
}
|
||||
|
||||
type snapshotResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
VCPUs *int32 `json:"vcpus,omitempty"`
|
||||
MemoryMB *int32 `json:"memory_mb,omitempty"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Platform bool `json:"platform"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
VCPUs *int32 `json:"vcpus,omitempty"`
|
||||
MemoryMB *int32 `json:"memory_mb,omitempty"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Platform bool `json:"platform"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func templateToResponse(t db.Template) snapshotResponse {
|
||||
@ -94,6 +95,12 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
if t.CreatedAt.Valid {
|
||||
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if len(t.Metadata) > 0 {
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(t.Metadata, &meta); err == nil && len(meta) > 0 {
|
||||
resp.Metadata = meta
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
@ -126,7 +133,6 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
overwrite := r.URL.Query().Get("overwrite") == "true"
|
||||
|
||||
// Check for global name collision.
|
||||
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
@ -135,20 +141,10 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Check if name already exists for this team.
|
||||
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
if !overwrite {
|
||||
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
|
||||
return
|
||||
}
|
||||
// Delete old snapshot files from all hosts before removing the DB record.
|
||||
if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
|
||||
return
|
||||
}
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
writeError(w, http.StatusConflict, "template_name_taken",
|
||||
"snapshot name already exists; delete the existing snapshot first to reuse this name")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify sandbox exists, belongs to team, and is running or paused.
|
||||
@ -210,13 +206,16 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
DefaultUser: "root",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: sb.Metadata,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
@ -277,7 +276,7 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
|
||||
return
|
||||
}
|
||||
|
||||
@ -5,8 +5,8 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type statsHandler struct {
|
||||
@ -43,7 +43,7 @@ type statsResponse struct {
|
||||
Series statsSeriesResponse `json:"series"`
|
||||
}
|
||||
|
||||
// GetStats handles GET /v1/sandboxes/stats?range=5m|1h|6h|24h|30d
|
||||
// GetStats handles GET /v1/capsules/stats?range=5m|1h|6h|24h|30d
|
||||
func (h *statsHandler) GetStats(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -9,20 +11,22 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type teamHandler struct {
|
||||
svc *service.TeamService
|
||||
audit *audit.AuditLogger
|
||||
svc *service.TeamService
|
||||
audit *audit.AuditLogger
|
||||
mailer email.Mailer
|
||||
}
|
||||
|
||||
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger) *teamHandler {
|
||||
return &teamHandler{svc: svc, audit: al}
|
||||
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer) *teamHandler {
|
||||
return &teamHandler{svc: svc, audit: al, mailer: mailer}
|
||||
}
|
||||
|
||||
// teamResponse is the JSON shape for a team.
|
||||
@ -131,6 +135,15 @@ func (h *teamHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := h.mailer.Send(context.Background(), ac.Email, "Your team has been created", email.EmailData{
|
||||
RecipientName: ac.Name,
|
||||
Message: fmt.Sprintf("Your team \"%s\" has been created on Wrenn. You can now invite members and start creating sandboxes under this team.", req.Name),
|
||||
}); err != nil {
|
||||
slog.Warn("failed to send team created email", "email", ac.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
writeJSON(w, http.StatusCreated, teamWithRoleResponse{
|
||||
teamResponse: teamToResponse(team.Team),
|
||||
Role: team.Role,
|
||||
@ -279,6 +292,21 @@ func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
|
||||
if parseErr == nil {
|
||||
h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
|
||||
}
|
||||
|
||||
go func() {
|
||||
team, err := h.svc.GetTeam(context.Background(), teamID)
|
||||
teamName := "a team"
|
||||
if err == nil {
|
||||
teamName = team.Name
|
||||
}
|
||||
if err := h.mailer.Send(context.Background(), member.Email, "You've been added to a team on Wrenn", email.EmailData{
|
||||
RecipientName: member.Name,
|
||||
Message: fmt.Sprintf("%s has added you to the team \"%s\" on Wrenn.", ac.Name, teamName),
|
||||
}); err != nil {
|
||||
slog.Warn("failed to send team invitation email", "email", member.Email, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
|
||||
}
|
||||
|
||||
@ -388,3 +416,87 @@ func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// AdminListTeams handles GET /v1/admin/teams?page=1
|
||||
// Returns a paginated list of all teams with member counts, owner info, and active sandbox counts.
|
||||
func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
|
||||
page := 1
|
||||
if p := r.URL.Query().Get("page"); p != "" {
|
||||
if _, err := fmt.Sscanf(p, "%d", &page); err != nil || page < 1 {
|
||||
page = 1
|
||||
}
|
||||
}
|
||||
const perPage = 100
|
||||
offset := int32((page - 1) * perPage)
|
||||
|
||||
teams, total, err := h.svc.AdminListTeams(r.Context(), perPage, offset)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
type adminTeamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
IsByoc bool `json:"is_byoc"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
DeletedAt *string `json:"deleted_at"`
|
||||
MemberCount int32 `json:"member_count"`
|
||||
OwnerName string `json:"owner_name"`
|
||||
OwnerEmail string `json:"owner_email"`
|
||||
ActiveSandboxCount int32 `json:"active_sandbox_count"`
|
||||
ChannelCount int32 `json:"channel_count"`
|
||||
}
|
||||
|
||||
resp := make([]adminTeamResponse, len(teams))
|
||||
for i, t := range teams {
|
||||
r := adminTeamResponse{
|
||||
ID: id.FormatTeamID(t.ID),
|
||||
Name: t.Name,
|
||||
Slug: t.Slug,
|
||||
IsByoc: t.IsByoc,
|
||||
CreatedAt: t.CreatedAt.Format(time.RFC3339),
|
||||
MemberCount: t.MemberCount,
|
||||
OwnerName: t.OwnerName,
|
||||
OwnerEmail: t.OwnerEmail,
|
||||
ActiveSandboxCount: t.ActiveSandboxCount,
|
||||
ChannelCount: t.ChannelCount,
|
||||
}
|
||||
if t.DeletedAt != nil {
|
||||
s := t.DeletedAt.Format(time.RFC3339)
|
||||
r.DeletedAt = &s
|
||||
}
|
||||
resp[i] = r
|
||||
}
|
||||
|
||||
totalPages := (total + perPage - 1) / perPage
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"teams": resp,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": perPage,
|
||||
"total_pages": totalPages,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminDeleteTeam handles DELETE /v1/admin/teams/{id}
|
||||
// Soft-deletes a team and destroys all its active sandboxes.
|
||||
func (h *teamHandler) AdminDeleteTeam(w http.ResponseWriter, r *http.Request) {
|
||||
teamIDStr := chi.URLParam(r, "id")
|
||||
|
||||
teamID, err := id.ParseTeamID(teamIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.AdminDeleteTeam(r.Context(), teamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@ -1,22 +1,27 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type usersHandler struct {
|
||||
db *db.Queries
|
||||
db *db.Queries
|
||||
svc *service.UserService
|
||||
}
|
||||
|
||||
func newUsersHandler(db *db.Queries) *usersHandler {
|
||||
return &usersHandler{db: db}
|
||||
func newUsersHandler(db *db.Queries, svc *service.UserService) *usersHandler {
|
||||
return &usersHandler{db: db, svc: svc}
|
||||
}
|
||||
|
||||
// Search handles GET /v1/users/search?email=<prefix>
|
||||
@ -50,3 +55,96 @@ func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// AdminListUsers handles GET /v1/admin/users?page=1
|
||||
// Returns a paginated list of all users with team counts.
|
||||
func (h *usersHandler) AdminListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
page := 1
|
||||
if p := r.URL.Query().Get("page"); p != "" {
|
||||
if _, err := fmt.Sscanf(p, "%d", &page); err != nil || page < 1 {
|
||||
page = 1
|
||||
}
|
||||
}
|
||||
const perPage = 100
|
||||
offset := int32((page - 1) * perPage)
|
||||
|
||||
users, total, err := h.svc.AdminListUsers(r.Context(), perPage, offset)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
type adminUserResponse struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
TeamsJoined int32 `json:"teams_joined"`
|
||||
TeamsOwned int32 `json:"teams_owned"`
|
||||
}
|
||||
|
||||
resp := make([]adminUserResponse, len(users))
|
||||
for i, u := range users {
|
||||
resp[i] = adminUserResponse{
|
||||
ID: id.FormatUserID(u.ID),
|
||||
Email: u.Email,
|
||||
Name: u.Name,
|
||||
IsAdmin: u.IsAdmin,
|
||||
Status: u.Status,
|
||||
CreatedAt: u.CreatedAt.Format(time.RFC3339),
|
||||
TeamsJoined: u.TeamsJoined,
|
||||
TeamsOwned: u.TeamsOwned,
|
||||
}
|
||||
}
|
||||
|
||||
totalPages := (total + perPage - 1) / perPage
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"users": resp,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": perPage,
|
||||
"total_pages": totalPages,
|
||||
})
|
||||
}
|
||||
|
||||
// SetUserActive handles PUT /v1/admin/users/{id}/active
|
||||
// Enables or disables a user account. Admins cannot deactivate themselves.
|
||||
func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
userIDStr := chi.URLParam(r, "id")
|
||||
|
||||
userID, err := id.ParseUserID(userIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if ac.UserID == userID && !req.Active {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "cannot deactivate your own account")
|
||||
return
|
||||
}
|
||||
|
||||
newStatus := "active"
|
||||
if !req.Active {
|
||||
newStatus = "disabled"
|
||||
}
|
||||
|
||||
if err := h.svc.SetUserStatus(r.Context(), userID, newStatus); err != nil {
|
||||
httpStatus, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, httpStatus, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
109
internal/api/helpers_ws.go
Normal file
109
internal/api/helpers_ws.go
Normal file
@ -0,0 +1,109 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// isWebSocketUpgrade returns true if the request is a WebSocket upgrade.
|
||||
func isWebSocketUpgrade(r *http.Request) bool {
|
||||
return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
|
||||
}
|
||||
|
||||
// ctxKeyAdminWS is a context key for flagging admin WS routes.
|
||||
type ctxKeyAdminWS struct{}
|
||||
|
||||
// setAdminWSFlag marks the context as an admin WebSocket route.
|
||||
func setAdminWSFlag(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
|
||||
}
|
||||
|
||||
// isAdminWSRoute checks if the request context was marked as admin WS.
|
||||
func isAdminWSRoute(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// wsAuthMsg is the first message a browser WS client sends to authenticate.
|
||||
type wsAuthMsg struct {
|
||||
Type string `json:"type"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
|
||||
// authenticated context. The caller must send this as the first message after
|
||||
// connecting.
|
||||
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
var msg wsAuthMsg
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Time{}) // clear deadline
|
||||
|
||||
if msg.Type != "auth" || msg.Token == "" {
|
||||
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
|
||||
}
|
||||
|
||||
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
|
||||
}
|
||||
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if user.Status != "active" {
|
||||
return auth.AuthContext{}, fmt.Errorf("account deactivated")
|
||||
}
|
||||
|
||||
return auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
|
||||
// returning an AuthContext with the platform team ID.
|
||||
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, err
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if !user.IsAdmin {
|
||||
return auth.AuthContext{}, fmt.Errorf("admin access required")
|
||||
}
|
||||
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
return ac, nil
|
||||
}
|
||||
@ -8,10 +8,10 @@ import (
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
)
|
||||
|
||||
// MetricsSampler records per-team sandbox resource usage to
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
type errorResponse struct {
|
||||
@ -50,8 +50,12 @@ func agentErrToHTTP(err error) (int, string, string) {
|
||||
return http.StatusNotFound, "not_found", err.Error()
|
||||
case connect.CodeInvalidArgument:
|
||||
return http.StatusBadRequest, "invalid_request", err.Error()
|
||||
case connect.CodeFailedPrecondition:
|
||||
case connect.CodeFailedPrecondition, connect.CodeAlreadyExists:
|
||||
return http.StatusConflict, "conflict", err.Error()
|
||||
case connect.CodePermissionDenied:
|
||||
return http.StatusForbidden, "forbidden", err.Error()
|
||||
case connect.CodeUnimplemented:
|
||||
return http.StatusNotImplemented, "agent_error", err.Error()
|
||||
default:
|
||||
return http.StatusBadGateway, "agent_error", err.Error()
|
||||
}
|
||||
@ -90,21 +94,25 @@ func serviceErrToHTTP(err error) (int, string, string) {
|
||||
}
|
||||
|
||||
// Map well-known service error patterns.
|
||||
// Return generic messages for most cases to avoid leaking internal details.
|
||||
switch {
|
||||
case strings.Contains(msg, "not found"):
|
||||
return http.StatusNotFound, "not_found", msg
|
||||
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", msg
|
||||
return http.StatusNotFound, "not_found", "resource not found"
|
||||
case strings.Contains(msg, "not running"):
|
||||
return http.StatusConflict, "invalid_state", "resource is not running"
|
||||
case strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", "resource is not paused"
|
||||
case strings.Contains(msg, "conflict:"):
|
||||
return http.StatusConflict, "conflict", msg
|
||||
return http.StatusConflict, "conflict", strings.TrimPrefix(msg, "conflict: ")
|
||||
case strings.Contains(msg, "forbidden"):
|
||||
return http.StatusForbidden, "forbidden", msg
|
||||
return http.StatusForbidden, "forbidden", "forbidden"
|
||||
case strings.Contains(msg, "invalid or expired"):
|
||||
return http.StatusUnauthorized, "unauthorized", msg
|
||||
return http.StatusUnauthorized, "unauthorized", "invalid or expired credentials"
|
||||
case strings.Contains(msg, "invalid"):
|
||||
return http.StatusBadRequest, "invalid_request", msg
|
||||
return http.StatusBadRequest, "invalid_request", "invalid request"
|
||||
default:
|
||||
return http.StatusInternalServerError, "internal_error", msg
|
||||
slog.Error("unhandled service error", "error", err)
|
||||
return http.StatusInternalServerError, "internal_error", "an internal error occurred"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -3,19 +3,47 @@ package api
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// injectPlatformTeam overwrites the AuthContext's TeamID with the platform
|
||||
// sentinel UUID. This lets existing team-scoped handlers (exec, files, pty,
|
||||
// metrics) work unchanged under admin routes. Must run after requireAdmin.
|
||||
func injectPlatformTeam() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := auth.FromContext(r.Context()); !ok {
|
||||
// No auth context yet (WS upgrade); handler will inject platform team after WS auth.
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
ctx := auth.WithAuthContext(r.Context(), ac)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// requireAdmin validates that the authenticated user is a platform admin.
|
||||
// Must run after requireJWT (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the JWT is_admin claim is for UI only;
|
||||
// the DB is the source of truth for admin access.
|
||||
// WebSocket upgrade requests without auth context are passed through —
|
||||
// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin.
|
||||
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
if isWebSocketUpgrade(r) {
|
||||
ctx := r.Context()
|
||||
ctx = setAdminWSFlag(ctx)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
@ -5,9 +5,9 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
|
||||
@ -38,9 +38,12 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
return
|
||||
}
|
||||
|
||||
// Try JWT bearer token.
|
||||
// Try JWT bearer token from Authorization header.
|
||||
tokenStr := ""
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
}
|
||||
if tokenStr != "" {
|
||||
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
|
||||
@ -59,6 +62,18 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user is still active in the database.
|
||||
user, err := queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
|
||||
return
|
||||
}
|
||||
if user.Status != "active" {
|
||||
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
@ -70,7 +85,15 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
return
|
||||
}
|
||||
|
||||
// WebSocket upgrade requests may not carry auth headers (browsers
|
||||
// cannot set custom headers on WS connections). Pass through —
|
||||
// the WS handler authenticates via the first message after upgrade.
|
||||
if isWebSocketUpgrade(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3,8 +3,8 @@ package api
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireHostToken validates the X-Host-Token header containing a host JWT,
|
||||
|
||||
@ -1,25 +1,37 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
|
||||
// signature and expiry, and stamps UserID + TeamID + Email into the request context.
|
||||
func requireJWT(secret []byte) func(http.Handler) http.Handler {
|
||||
// requireJWT validates a JWT from the Authorization: Bearer header.
|
||||
// It also verifies the user is still active in the database.
|
||||
// WebSocket upgrade requests without an Authorization header are passed through
|
||||
// — WS handlers authenticate via the first message after upgrade.
|
||||
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(header, "Bearer ") {
|
||||
var tokenStr string
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
}
|
||||
if tokenStr == "" {
|
||||
// WebSocket upgrade requests may not have an Authorization header
|
||||
// (browsers cannot set custom headers on WS connections). Let them
|
||||
// through — the handler authenticates via the first WS message.
|
||||
if isWebSocketUpgrade(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := auth.VerifyJWT(secret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
@ -37,6 +49,18 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user is still active in the database.
|
||||
user, err := queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
|
||||
return
|
||||
}
|
||||
if user.Status != "active" {
|
||||
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -9,14 +9,16 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/scheduler"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/service"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
//go:embed openapi.yaml
|
||||
@ -26,9 +28,12 @@ var openapiYAML []byte
|
||||
type Server struct {
|
||||
router chi.Router
|
||||
BuildSvc *service.BuildService
|
||||
version string
|
||||
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
|
||||
// Extensions are called after core routes are registered, allowing enterprise
|
||||
// or third-party code to add routes and middleware.
|
||||
func New(
|
||||
queries *db.Queries,
|
||||
pool *lifecycle.HostClientPool,
|
||||
@ -41,6 +46,10 @@ func New(
|
||||
ca *auth.CA,
|
||||
al *audit.AuditLogger,
|
||||
channelSvc *channels.Service,
|
||||
mailer email.Mailer,
|
||||
extensions []cpextension.Extension,
|
||||
sctx cpextension.ServerContext,
|
||||
version string,
|
||||
) *Server {
|
||||
r := chi.NewRouter()
|
||||
r.Use(requestLogger())
|
||||
@ -51,27 +60,39 @@ func New(
|
||||
templateSvc := &service.TemplateService{DB: queries}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
||||
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
||||
userSvc := &service.UserService{DB: queries, SandboxSvc: sandboxSvc}
|
||||
auditSvc := &service.AuditService{DB: queries}
|
||||
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
||||
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
|
||||
|
||||
sandbox := newSandboxHandler(sandboxSvc, al)
|
||||
exec := newExecHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool, jwtSecret)
|
||||
files := newFilesHandler(queries, pool)
|
||||
filesStream := newFilesStreamHandler(queries, pool)
|
||||
fsH := newFSHandler(queries, pool)
|
||||
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
|
||||
authH := newAuthHandler(queries, pgPool, jwtSecret)
|
||||
authH := newAuthHandler(queries, pgPool, jwtSecret, mailer, rdb, oauthRedirectURL)
|
||||
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc, al)
|
||||
hostH := newHostHandler(hostSvc, queries, al)
|
||||
teamH := newTeamHandler(teamSvc, al)
|
||||
usersH := newUsersHandler(queries)
|
||||
teamH := newTeamHandler(teamSvc, al, mailer)
|
||||
usersH := newUsersHandler(queries, userSvc)
|
||||
auditH := newAuditHandler(auditSvc)
|
||||
statsH := newStatsHandler(statsSvc)
|
||||
metricsH := newSandboxMetricsHandler(queries, pool)
|
||||
buildH := newBuildHandler(buildSvc, queries, pool)
|
||||
channelH := newChannelHandler(channelSvc, al)
|
||||
ptyH := newPtyHandler(queries, pool, jwtSecret)
|
||||
processH := newProcessHandler(queries, pool, jwtSecret)
|
||||
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
|
||||
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
|
||||
|
||||
// Health check.
|
||||
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintf(w, `{"status":"ok","version":%q}`, version)
|
||||
})
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
@ -80,15 +101,31 @@ func New(
|
||||
// Unauthenticated auth endpoints.
|
||||
r.Post("/v1/auth/signup", authH.Signup)
|
||||
r.Post("/v1/auth/login", authH.Login)
|
||||
r.Post("/v1/auth/activate", authH.Activate)
|
||||
r.Get("/auth/oauth/{provider}", oauthH.Redirect)
|
||||
r.Get("/auth/oauth/{provider}/callback", oauthH.Callback)
|
||||
|
||||
// Unauthenticated: password reset request and confirmation.
|
||||
r.Post("/v1/me/password/reset", meH.RequestPasswordReset)
|
||||
r.Post("/v1/me/password/reset/confirm", meH.ConfirmPasswordReset)
|
||||
|
||||
// JWT-authenticated: self-service account management.
|
||||
r.Route("/v1/me", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Get("/", meH.GetMe)
|
||||
r.Patch("/", meH.UpdateName)
|
||||
r.Post("/password", meH.ChangePassword)
|
||||
r.Get("/providers/{provider}/connect", meH.ConnectProvider)
|
||||
r.Delete("/providers/{provider}", meH.DisconnectProvider)
|
||||
r.Delete("/", meH.DeleteAccount)
|
||||
})
|
||||
|
||||
// JWT-authenticated: switch active team.
|
||||
r.With(requireJWT(jwtSecret)).Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
r.With(requireJWT(jwtSecret, queries)).Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
|
||||
// JWT-authenticated: API key management.
|
||||
r.Route("/v1/api-keys", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Post("/", apiKeys.Create)
|
||||
r.Get("/", apiKeys.List)
|
||||
r.Delete("/{id}", apiKeys.Delete)
|
||||
@ -96,7 +133,7 @@ func New(
|
||||
|
||||
// JWT-authenticated: team management.
|
||||
r.Route("/v1/teams", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Get("/", teamH.List)
|
||||
r.Post("/", teamH.Create)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
@ -112,10 +149,12 @@ func New(
|
||||
})
|
||||
|
||||
// JWT-authenticated: user search (for add-member UI).
|
||||
r.With(requireJWT(jwtSecret)).Get("/v1/users/search", usersH.Search)
|
||||
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
|
||||
|
||||
// Sandbox lifecycle: accepts API key or JWT bearer token.
|
||||
r.Route("/v1/sandboxes", func(r chi.Router) {
|
||||
// Capsule lifecycle: accepts API key or JWT bearer token.
|
||||
// WebSocket upgrade requests without auth headers are passed through by
|
||||
// requireAPIKeyOrJWT — the WS handlers authenticate via first message.
|
||||
r.Route("/v1/capsules", func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Post("/", sandbox.Create)
|
||||
r.Get("/", sandbox.List)
|
||||
@ -133,7 +172,14 @@ func New(
|
||||
r.Post("/files/read", files.Download)
|
||||
r.Post("/files/stream/write", filesStream.StreamUpload)
|
||||
r.Post("/files/stream/read", filesStream.StreamDownload)
|
||||
r.Post("/files/list", fsH.ListDir)
|
||||
r.Post("/files/mkdir", fsH.MakeDir)
|
||||
r.Post("/files/remove", fsH.Remove)
|
||||
r.Get("/metrics", metricsH.GetMetrics)
|
||||
r.Get("/pty", ptyH.PtySession)
|
||||
r.Get("/processes", processH.ListProcesses)
|
||||
r.Delete("/processes/{selector}", processH.KillProcess)
|
||||
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
|
||||
})
|
||||
})
|
||||
|
||||
@ -158,7 +204,7 @@ func New(
|
||||
|
||||
// JWT-authenticated: host CRUD and tags.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Post("/", hostH.Create)
|
||||
r.Get("/", hostH.List)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
@ -175,7 +221,7 @@ func New(
|
||||
|
||||
// JWT-authenticated: notification channels.
|
||||
r.Route("/v1/channels", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Post("/", channelH.Create)
|
||||
r.Get("/", channelH.List)
|
||||
r.Post("/test", channelH.Test)
|
||||
@ -188,22 +234,51 @@ func New(
|
||||
})
|
||||
|
||||
// JWT-authenticated: audit log.
|
||||
r.With(requireJWT(jwtSecret)).Get("/v1/audit-logs", auditH.List)
|
||||
r.With(requireJWT(jwtSecret, queries)).Get("/v1/audit-logs", auditH.List)
|
||||
|
||||
// Platform admin routes — require JWT + DB-validated admin status.
|
||||
r.Route("/v1/admin", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Get("/teams", teamH.AdminListTeams)
|
||||
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
|
||||
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
|
||||
r.Get("/users", usersH.AdminListUsers)
|
||||
r.Put("/users/{id}/active", usersH.SetUserActive)
|
||||
r.Get("/templates", buildH.ListTemplates)
|
||||
r.Delete("/templates/{name}", buildH.DeleteTemplate)
|
||||
r.Post("/builds", buildH.Create)
|
||||
r.Get("/builds", buildH.List)
|
||||
r.Get("/builds/{id}", buildH.Get)
|
||||
r.Post("/builds/{id}/cancel", buildH.Cancel)
|
||||
r.Post("/capsules", adminCapsules.Create)
|
||||
r.Get("/capsules", adminCapsules.List)
|
||||
r.Route("/capsules/{id}", func(r chi.Router) {
|
||||
r.Use(injectPlatformTeam())
|
||||
r.Get("/", adminCapsules.Get)
|
||||
r.Delete("/", adminCapsules.Destroy)
|
||||
r.Post("/snapshot", adminCapsules.Snapshot)
|
||||
r.Post("/exec", exec.Exec)
|
||||
r.Get("/exec/stream", execStream.ExecStream)
|
||||
r.Post("/files/write", files.Upload)
|
||||
r.Post("/files/read", files.Download)
|
||||
r.Post("/files/list", fsH.ListDir)
|
||||
r.Post("/files/mkdir", fsH.MakeDir)
|
||||
r.Post("/files/remove", fsH.Remove)
|
||||
r.Get("/metrics", metricsH.GetMetrics)
|
||||
r.Get("/pty", ptyH.PtySession)
|
||||
r.Get("/processes", processH.ListProcesses)
|
||||
r.Delete("/processes/{selector}", processH.KillProcess)
|
||||
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
|
||||
})
|
||||
})
|
||||
|
||||
return &Server{router: r, BuildSvc: buildSvc}
|
||||
// Let extensions register their routes after all core routes.
|
||||
for _, ext := range extensions {
|
||||
ext.RegisterRoutes(r, sctx)
|
||||
}
|
||||
|
||||
return &Server{router: r, BuildSvc: buildSvc, version: version}
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler.
|
||||
@ -211,6 +286,11 @@ func (s *Server) Handler() http.Handler {
|
||||
return s.router
|
||||
}
|
||||
|
||||
// Router returns the underlying chi.Router for direct access.
|
||||
func (s *Server) Router() chi.Router {
|
||||
return s.router
|
||||
}
|
||||
|
||||
func serveOpenAPI(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/yaml")
|
||||
_, _ = w.Write(openapiYAML)
|
||||
@ -223,7 +303,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrenn Sandbox API</title>
|
||||
<title>Wrenn API</title>
|
||||
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui.css" integrity="sha384-rcbEi6xgdPk0iWkAQzT2F3FeBJXdG+ydrawGlfHAFIZG7wU6aKbQaRewysYpmrlW" crossorigin="anonymous">
|
||||
<style>
|
||||
body { margin: 0; background: #fafafa; }
|
||||
|
||||
Reference in New Issue
Block a user