1
0
forked from wrenn/wrenn

Merge branch 'dev' into chore/hardening

This commit is contained in:
2026-04-16 12:58:48 +00:00
52 changed files with 795 additions and 708 deletions

View File

@ -20,12 +20,13 @@ import (
)
type execStreamHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
var upgrader = websocket.Upgrader{
@ -51,7 +52,6 @@ type wsOutMsg struct {
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
@ -59,13 +59,31 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
return
}
@ -76,6 +94,20 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
}
defer conn.Close()
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
}
func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
// Read the start message.
var startMsg wsStartMsg
if err := conn.ReadJSON(&startMsg); err != nil {

View File

@ -512,6 +512,9 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
}
// Delete all teams the user solely owns (no other members).
// Team deletion involves RPC calls (sandbox destruction) that cannot be
// transactional, so we do those first as best-effort, then wrap the
// DB-only cleanup in a transaction.
soleTeams, err := h.db.ListSoleOwnedTeams(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list owned teams")
@ -519,16 +522,36 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
}
for _, teamID := range soleTeams {
if err := h.teamSvc.DeleteTeamInternal(ctx, teamID); err != nil {
slog.Warn("account delete: failed to delete sole-owned team",
"team_id", id.FormatTeamID(teamID), "error", err)
writeError(w, http.StatusInternalServerError, "db_error",
fmt.Sprintf("failed to delete sole-owned team %s", id.FormatTeamID(teamID)))
return
}
}
if err := h.db.SoftDeleteUser(ctx, ac.UserID); err != nil {
tx, err := h.pool.Begin(ctx)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to start transaction")
return
}
defer tx.Rollback(ctx)
qtx := h.db.WithTx(tx)
if err := qtx.DeleteAPIKeysByCreator(ctx, ac.UserID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete user's API keys")
return
}
if err := qtx.SoftDeleteUser(ctx, ac.UserID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete account")
return
}
if err := tx.Commit(ctx); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to commit account deletion")
return
}
slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)
go func() {

View File

@ -212,6 +212,11 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
if err == nil {
// Existing OAuth user — log them in.
user, err := h.db.GetUserByID(ctx, existing.UserID)
if errors.Is(err, pgx.ErrNoRows) {
slog.Warn("oauth login: user no longer exists", "user_id", existing.UserID)
redirectWithError(w, r, redirectBase, "account_deactivated")
return
}
if err != nil {
slog.Error("oauth login: failed to get user", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
@ -222,13 +227,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
redirectWithError(w, r, redirectBase, "account_deactivated")
return
}
team, role, err := loginTeam(ctx, h.db, user.ID)
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
if err != nil {
slog.Error("oauth login: failed to get team", "error", err)
slog.Error("oauth login: failed to ensure team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
slog.Error("oauth login: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
@ -376,6 +382,11 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
return
}
user, err := h.db.GetUserByID(ctx, existing.UserID)
if errors.Is(err, pgx.ErrNoRows) {
slog.Warn("oauth: retry login: user no longer exists", "user_id", existing.UserID)
redirectWithError(w, r, redirectBase, "account_deactivated")
return
}
if err != nil {
slog.Error("oauth: retry login: failed to get user", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
@ -386,13 +397,14 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
redirectWithError(w, r, redirectBase, "account_deactivated")
return
}
team, role, err := loginTeam(ctx, h.db, user.ID)
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
if err != nil {
slog.Error("oauth: retry login: failed to get team", "error", err)
slog.Error("oauth: retry login: failed to ensure team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")

View File

@ -20,12 +20,13 @@ import (
)
type processHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
}
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool) *processHandler {
return &processHandler{db: db, pool: pool}
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
// processResponse is a single entry in the process list.
@ -158,7 +159,6 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
sandboxIDStr := chi.URLParam(r, "id")
selectorStr := chi.URLParam(r, "selector")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
@ -166,19 +166,31 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendProcessWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
return
}
@ -189,6 +201,26 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
}
defer conn.Close()
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
}
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendProcessWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendProcessWSError(conn, "sandbox host is not reachable")
return
}
// Build the connect request with PID or tag selector.
connectReq := &pb.ConnectProcessRequest{
SandboxId: sandboxIDStr,

View File

@ -30,12 +30,13 @@ const (
)
type ptyHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
}
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool) *ptyHandler {
return &ptyHandler{db: db, pool: pool}
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
// --- WebSocket message types ---
@ -82,7 +83,6 @@ func (w *wsWriter) writeJSON(v any) {
func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
@ -90,13 +90,34 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
// API key auth is handled by middleware (sets context).
// For browser JWT auth, we authenticate after upgrade via first WS message.
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
return
}
defer conn.Close()
ws := &wsWriter{conn: conn}
var wsAC auth.AuthContext
if isAdminWSRoute(ctx) {
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
return
}
ac = wsAC
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
return
}
@ -108,6 +129,19 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
defer conn.Close()
ws := &wsWriter{conn: conn}
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
}
func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox not found", Fatal: true})
return
}
if sb.Status != "running" {
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox is not running (status: " + sb.Status + ")", Fatal: true})
return
}
// Read the first message to determine start vs connect.
var firstMsg wsPtyIn

View File

@ -133,7 +133,6 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ac := auth.MustFromContext(ctx)
overwrite := r.URL.Query().Get("overwrite") == "true"
// Check for global name collision.
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
@ -142,20 +141,10 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Check if name already exists for this team.
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if !overwrite {
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
return
}
// Delete old snapshot files from all hosts before removing the DB record.
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
return
}
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
writeError(w, http.StatusConflict, "template_name_taken",
"snapshot name already exists; delete the existing snapshot first to reuse this name")
return
}
// Verify sandbox exists, belongs to team, and is running or paused.

109
internal/api/helpers_ws.go Normal file
View File

@ -0,0 +1,109 @@
package api
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// isWebSocketUpgrade returns true if the request is a WebSocket upgrade.
func isWebSocketUpgrade(r *http.Request) bool {
return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
}
// ctxKeyAdminWS is a context key for flagging admin WS routes.
type ctxKeyAdminWS struct{}
// setAdminWSFlag marks the context as an admin WebSocket route.
func setAdminWSFlag(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
}
// isAdminWSRoute checks if the request context was marked as admin WS.
func isAdminWSRoute(ctx context.Context) bool {
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
return v
}
// wsAuthMsg is the first message a browser WS client sends to authenticate.
type wsAuthMsg struct {
Type string `json:"type"`
Token string `json:"token"`
}
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
// authenticated context. The caller must send this as the first message after
// connecting.
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var msg wsAuthMsg
if err := conn.ReadJSON(&msg); err != nil {
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
}
conn.SetReadDeadline(time.Time{}) // clear deadline
if msg.Type != "auth" || msg.Token == "" {
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
}
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
}
user, err := queries.GetUserByID(ctx, userID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if user.Status != "active" {
return auth.AuthContext{}, fmt.Errorf("account deactivated")
}
return auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
}, nil
}
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
// returning an AuthContext with the platform team ID.
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
if err != nil {
return auth.AuthContext{}, err
}
user, err := queries.GetUserByID(ctx, ac.UserID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if !user.IsAdmin {
return auth.AuthContext{}, fmt.Errorf("admin access required")
}
ac.TeamID = id.PlatformTeamID
return ac, nil
}

View File

@ -94,21 +94,25 @@ func serviceErrToHTTP(err error) (int, string, string) {
}
// Map well-known service error patterns.
// Return generic messages for most cases to avoid leaking internal details.
switch {
case strings.Contains(msg, "not found"):
return http.StatusNotFound, "not_found", msg
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
return http.StatusConflict, "invalid_state", msg
return http.StatusNotFound, "not_found", "resource not found"
case strings.Contains(msg, "not running"):
return http.StatusConflict, "invalid_state", "resource is not running"
case strings.Contains(msg, "not paused"):
return http.StatusConflict, "invalid_state", "resource is not paused"
case strings.Contains(msg, "conflict:"):
return http.StatusConflict, "conflict", msg
return http.StatusConflict, "conflict", strings.TrimPrefix(msg, "conflict: ")
case strings.Contains(msg, "forbidden"):
return http.StatusForbidden, "forbidden", msg
return http.StatusForbidden, "forbidden", "forbidden"
case strings.Contains(msg, "invalid or expired"):
return http.StatusUnauthorized, "unauthorized", msg
return http.StatusUnauthorized, "unauthorized", "invalid or expired credentials"
case strings.Contains(msg, "invalid"):
return http.StatusBadRequest, "invalid_request", msg
return http.StatusBadRequest, "invalid_request", "invalid request"
default:
return http.StatusInternalServerError, "internal_error", msg
slog.Error("unhandled service error", "error", err)
return http.StatusInternalServerError, "internal_error", "an internal error occurred"
}
}

View File

@ -14,6 +14,11 @@ import (
func injectPlatformTeam() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, ok := auth.FromContext(r.Context()); !ok {
// No auth context yet (WS upgrade); handler will inject platform team after WS auth.
next.ServeHTTP(w, r)
return
}
ac := auth.MustFromContext(r.Context())
ac.TeamID = id.PlatformTeamID
ctx := auth.WithAuthContext(r.Context(), ac)
@ -26,11 +31,19 @@ func injectPlatformTeam() func(http.Handler) http.Handler {
// Must run after requireJWT (depends on AuthContext being present).
// Re-validates against the DB — the JWT is_admin claim is for UI only;
// the DB is the source of truth for admin access.
// WebSocket upgrade requests without auth context are passed through —
// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin.
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ac, ok := auth.FromContext(r.Context())
if !ok {
if isWebSocketUpgrade(r) {
ctx := r.Context()
ctx = setAdminWSFlag(ctx)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
return
}

View File

@ -38,12 +38,10 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
// Try JWT bearer token (header or query param for WebSocket).
// Try JWT bearer token from Authorization header.
tokenStr := ""
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr = strings.TrimPrefix(header, "Bearer ")
} else if t := r.URL.Query().Get("token"); t != "" {
tokenStr = t
}
if tokenStr != "" {
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
@ -87,7 +85,15 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
// WebSocket upgrade requests may not carry auth headers (browsers
// cannot set custom headers on WS connections). Pass through —
// the WS handler authenticates via the first message after upgrade.
if isWebSocketUpgrade(r) {
next.ServeHTTP(w, r)
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
})
}
}
}

View File

@ -10,19 +10,25 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// requireJWT validates a JWT from the Authorization: Bearer header or the
// ?token= query parameter (for WebSocket connections that cannot send headers).
// requireJWT validates a JWT from the Authorization: Bearer header.
// It also verifies the user is still active in the database.
// WebSocket upgrade requests without an Authorization header are passed through
// — WS handlers authenticate via the first message after upgrade.
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var tokenStr string
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr = strings.TrimPrefix(header, "Bearer ")
} else if t := r.URL.Query().Get("token"); t != "" {
tokenStr = t
}
if tokenStr == "" {
// WebSocket upgrade requests may not have an Authorization header
// (browsers cannot set custom headers on WS connections). Let them
// through — the handler authenticates via the first WS message.
if isWebSocketUpgrade(r) {
next.ServeHTTP(w, r)
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
return
}

View File

@ -60,14 +60,14 @@ func New(
templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
userSvc := &service.UserService{DB: queries}
userSvc := &service.UserService{DB: queries, SandboxSvc: sandboxSvc}
auditSvc := &service.AuditService{DB: queries}
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
sandbox := newSandboxHandler(sandboxSvc, al)
exec := newExecHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool, jwtSecret)
files := newFilesHandler(queries, pool)
filesStream := newFilesStreamHandler(queries, pool)
fsH := newFSHandler(queries, pool)
@ -83,8 +83,8 @@ func New(
metricsH := newSandboxMetricsHandler(queries, pool)
buildH := newBuildHandler(buildSvc, queries, pool)
channelH := newChannelHandler(channelSvc, al)
ptyH := newPtyHandler(queries, pool)
processH := newProcessHandler(queries, pool)
ptyH := newPtyHandler(queries, pool, jwtSecret)
processH := newProcessHandler(queries, pool, jwtSecret)
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
@ -152,6 +152,8 @@ func New(
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
// Capsule lifecycle: accepts API key or JWT bearer token.
// WebSocket upgrade requests without auth headers are passed through by
// requireAPIKeyOrJWT — the WS handlers authenticate via first message.
r.Route("/v1/capsules", func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Post("/", sandbox.Create)