forked from wrenn/wrenn
Merge branch 'dev' into chore/hardening
This commit is contained in:
@ -20,12 +20,13 @@ import (
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool}
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@ -51,7 +52,6 @@ type wsOutMsg struct {
|
||||
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
@ -59,13 +59,31 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -76,6 +94,20 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
sendWSError(conn, "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
// Read the start message.
|
||||
var startMsg wsStartMsg
|
||||
if err := conn.ReadJSON(&startMsg); err != nil {
|
||||
|
||||
@ -512,6 +512,9 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Delete all teams the user solely owns (no other members).
|
||||
// Team deletion involves RPC calls (sandbox destruction) that cannot be
|
||||
// transactional, so we do those first as best-effort, then wrap the
|
||||
// DB-only cleanup in a transaction.
|
||||
soleTeams, err := h.db.ListSoleOwnedTeams(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list owned teams")
|
||||
@ -519,16 +522,36 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
for _, teamID := range soleTeams {
|
||||
if err := h.teamSvc.DeleteTeamInternal(ctx, teamID); err != nil {
|
||||
slog.Warn("account delete: failed to delete sole-owned team",
|
||||
"team_id", id.FormatTeamID(teamID), "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error",
|
||||
fmt.Sprintf("failed to delete sole-owned team %s", id.FormatTeamID(teamID)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.db.SoftDeleteUser(ctx, ac.UserID); err != nil {
|
||||
tx, err := h.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to start transaction")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
qtx := h.db.WithTx(tx)
|
||||
|
||||
if err := qtx.DeleteAPIKeysByCreator(ctx, ac.UserID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete user's API keys")
|
||||
return
|
||||
}
|
||||
|
||||
if err := qtx.SoftDeleteUser(ctx, ac.UserID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete account")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to commit account deletion")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)
|
||||
|
||||
go func() {
|
||||
|
||||
@ -212,6 +212,11 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
if err == nil {
|
||||
// Existing OAuth user — log them in.
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("oauth login: user no longer exists", "user_id", existing.UserID)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
@ -222,13 +227,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get team", "error", err)
|
||||
slog.Error("oauth login: failed to ensure team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
@ -376,6 +382,11 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
||||
return
|
||||
}
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("oauth: retry login: user no longer exists", "user_id", existing.UserID)
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
@ -386,13 +397,14 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||
return
|
||||
}
|
||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
||||
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get team", "error", err)
|
||||
slog.Error("oauth: retry login: failed to ensure team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
|
||||
@ -20,12 +20,13 @@ import (
|
||||
)
|
||||
|
||||
type processHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool) *processHandler {
|
||||
return &processHandler{db: db, pool: pool}
|
||||
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
|
||||
return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
// processResponse is a single entry in the process list.
|
||||
@ -158,7 +159,6 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
selectorStr := chi.URLParam(r, "selector")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
@ -166,19 +166,31 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendProcessWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -189,6 +201,26 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
|
||||
}
|
||||
|
||||
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Build the connect request with PID or tag selector.
|
||||
connectReq := &pb.ConnectProcessRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
|
||||
@ -30,12 +30,13 @@ const (
|
||||
)
|
||||
|
||||
type ptyHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool) *ptyHandler {
|
||||
return &ptyHandler{db: db, pool: pool}
|
||||
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
|
||||
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
// --- WebSocket message types ---
|
||||
@ -82,7 +83,6 @@ func (w *wsWriter) writeJSON(v any) {
|
||||
func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
@ -90,13 +90,34 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
// API key auth is handled by middleware (sets context).
|
||||
// For browser JWT auth, we authenticate after upgrade via first WS message.
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ws := &wsWriter{conn: conn}
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -108,6 +129,19 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
|
||||
defer conn.Close()
|
||||
|
||||
ws := &wsWriter{conn: conn}
|
||||
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox not found", Fatal: true})
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox is not running (status: " + sb.Status + ")", Fatal: true})
|
||||
return
|
||||
}
|
||||
|
||||
// Read the first message to determine start vs connect.
|
||||
var firstMsg wsPtyIn
|
||||
|
||||
@ -133,7 +133,6 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
overwrite := r.URL.Query().Get("overwrite") == "true"
|
||||
|
||||
// Check for global name collision.
|
||||
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
@ -142,20 +141,10 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Check if name already exists for this team.
|
||||
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
if !overwrite {
|
||||
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
|
||||
return
|
||||
}
|
||||
// Delete old snapshot files from all hosts before removing the DB record.
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
|
||||
return
|
||||
}
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
writeError(w, http.StatusConflict, "template_name_taken",
|
||||
"snapshot name already exists; delete the existing snapshot first to reuse this name")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify sandbox exists, belongs to team, and is running or paused.
|
||||
|
||||
109
internal/api/helpers_ws.go
Normal file
109
internal/api/helpers_ws.go
Normal file
@ -0,0 +1,109 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// isWebSocketUpgrade returns true if the request is a WebSocket upgrade.
|
||||
func isWebSocketUpgrade(r *http.Request) bool {
|
||||
return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
|
||||
}
|
||||
|
||||
// ctxKeyAdminWS is a context key for flagging admin WS routes.
|
||||
type ctxKeyAdminWS struct{}
|
||||
|
||||
// setAdminWSFlag marks the context as an admin WebSocket route.
|
||||
func setAdminWSFlag(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
|
||||
}
|
||||
|
||||
// isAdminWSRoute checks if the request context was marked as admin WS.
|
||||
func isAdminWSRoute(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// wsAuthMsg is the first message a browser WS client sends to authenticate.
|
||||
type wsAuthMsg struct {
|
||||
Type string `json:"type"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
|
||||
// authenticated context. The caller must send this as the first message after
|
||||
// connecting.
|
||||
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
var msg wsAuthMsg
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Time{}) // clear deadline
|
||||
|
||||
if msg.Type != "auth" || msg.Token == "" {
|
||||
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
|
||||
}
|
||||
|
||||
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
|
||||
}
|
||||
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if user.Status != "active" {
|
||||
return auth.AuthContext{}, fmt.Errorf("account deactivated")
|
||||
}
|
||||
|
||||
return auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
|
||||
// returning an AuthContext with the platform team ID.
|
||||
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, err
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if !user.IsAdmin {
|
||||
return auth.AuthContext{}, fmt.Errorf("admin access required")
|
||||
}
|
||||
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
return ac, nil
|
||||
}
|
||||
@ -94,21 +94,25 @@ func serviceErrToHTTP(err error) (int, string, string) {
|
||||
}
|
||||
|
||||
// Map well-known service error patterns.
|
||||
// Return generic messages for most cases to avoid leaking internal details.
|
||||
switch {
|
||||
case strings.Contains(msg, "not found"):
|
||||
return http.StatusNotFound, "not_found", msg
|
||||
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", msg
|
||||
return http.StatusNotFound, "not_found", "resource not found"
|
||||
case strings.Contains(msg, "not running"):
|
||||
return http.StatusConflict, "invalid_state", "resource is not running"
|
||||
case strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", "resource is not paused"
|
||||
case strings.Contains(msg, "conflict:"):
|
||||
return http.StatusConflict, "conflict", msg
|
||||
return http.StatusConflict, "conflict", strings.TrimPrefix(msg, "conflict: ")
|
||||
case strings.Contains(msg, "forbidden"):
|
||||
return http.StatusForbidden, "forbidden", msg
|
||||
return http.StatusForbidden, "forbidden", "forbidden"
|
||||
case strings.Contains(msg, "invalid or expired"):
|
||||
return http.StatusUnauthorized, "unauthorized", msg
|
||||
return http.StatusUnauthorized, "unauthorized", "invalid or expired credentials"
|
||||
case strings.Contains(msg, "invalid"):
|
||||
return http.StatusBadRequest, "invalid_request", msg
|
||||
return http.StatusBadRequest, "invalid_request", "invalid request"
|
||||
default:
|
||||
return http.StatusInternalServerError, "internal_error", msg
|
||||
slog.Error("unhandled service error", "error", err)
|
||||
return http.StatusInternalServerError, "internal_error", "an internal error occurred"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -14,6 +14,11 @@ import (
|
||||
func injectPlatformTeam() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := auth.FromContext(r.Context()); !ok {
|
||||
// No auth context yet (WS upgrade); handler will inject platform team after WS auth.
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
ctx := auth.WithAuthContext(r.Context(), ac)
|
||||
@ -26,11 +31,19 @@ func injectPlatformTeam() func(http.Handler) http.Handler {
|
||||
// Must run after requireJWT (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the JWT is_admin claim is for UI only;
|
||||
// the DB is the source of truth for admin access.
|
||||
// WebSocket upgrade requests without auth context are passed through —
|
||||
// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin.
|
||||
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
if isWebSocketUpgrade(r) {
|
||||
ctx := r.Context()
|
||||
ctx = setAdminWSFlag(ctx)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
@ -38,12 +38,10 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
return
|
||||
}
|
||||
|
||||
// Try JWT bearer token (header or query param for WebSocket).
|
||||
// Try JWT bearer token from Authorization header.
|
||||
tokenStr := ""
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
} else if t := r.URL.Query().Get("token"); t != "" {
|
||||
tokenStr = t
|
||||
}
|
||||
if tokenStr != "" {
|
||||
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
|
||||
@ -87,7 +85,15 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
|
||||
return
|
||||
}
|
||||
|
||||
// WebSocket upgrade requests may not carry auth headers (browsers
|
||||
// cannot set custom headers on WS connections). Pass through —
|
||||
// the WS handler authenticates via the first message after upgrade.
|
||||
if isWebSocketUpgrade(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -10,19 +10,25 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireJWT validates a JWT from the Authorization: Bearer header or the
|
||||
// ?token= query parameter (for WebSocket connections that cannot send headers).
|
||||
// requireJWT validates a JWT from the Authorization: Bearer header.
|
||||
// It also verifies the user is still active in the database.
|
||||
// WebSocket upgrade requests without an Authorization header are passed through
|
||||
// — WS handlers authenticate via the first message after upgrade.
|
||||
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var tokenStr string
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
} else if t := r.URL.Query().Get("token"); t != "" {
|
||||
tokenStr = t
|
||||
}
|
||||
if tokenStr == "" {
|
||||
// WebSocket upgrade requests may not have an Authorization header
|
||||
// (browsers cannot set custom headers on WS connections). Let them
|
||||
// through — the handler authenticates via the first WS message.
|
||||
if isWebSocketUpgrade(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
|
||||
return
|
||||
}
|
||||
|
||||
@ -60,14 +60,14 @@ func New(
|
||||
templateSvc := &service.TemplateService{DB: queries}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
||||
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
||||
userSvc := &service.UserService{DB: queries}
|
||||
userSvc := &service.UserService{DB: queries, SandboxSvc: sandboxSvc}
|
||||
auditSvc := &service.AuditService{DB: queries}
|
||||
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
||||
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
|
||||
|
||||
sandbox := newSandboxHandler(sandboxSvc, al)
|
||||
exec := newExecHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool, jwtSecret)
|
||||
files := newFilesHandler(queries, pool)
|
||||
filesStream := newFilesStreamHandler(queries, pool)
|
||||
fsH := newFSHandler(queries, pool)
|
||||
@ -83,8 +83,8 @@ func New(
|
||||
metricsH := newSandboxMetricsHandler(queries, pool)
|
||||
buildH := newBuildHandler(buildSvc, queries, pool)
|
||||
channelH := newChannelHandler(channelSvc, al)
|
||||
ptyH := newPtyHandler(queries, pool)
|
||||
processH := newProcessHandler(queries, pool)
|
||||
ptyH := newPtyHandler(queries, pool, jwtSecret)
|
||||
processH := newProcessHandler(queries, pool, jwtSecret)
|
||||
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
|
||||
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
|
||||
|
||||
@ -152,6 +152,8 @@ func New(
|
||||
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
|
||||
|
||||
// Capsule lifecycle: accepts API key or JWT bearer token.
|
||||
// WebSocket upgrade requests without auth headers are passed through by
|
||||
// requireAPIKeyOrJWT — the WS handlers authenticate via first message.
|
||||
r.Route("/v1/capsules", func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Post("/", sandbox.Create)
|
||||
|
||||
@ -131,26 +131,31 @@ type Slot struct {
|
||||
}
|
||||
|
||||
// NewSlot computes the addressing for the given slot index (1-based).
|
||||
// Index must be in [1, 32767] so that veth offset (index*2) fits in 16 bits.
|
||||
func NewSlot(index int) *Slot {
|
||||
if index < 1 || index > 32767 {
|
||||
panic(fmt.Sprintf("slot index %d out of range [1, 32767]", index))
|
||||
}
|
||||
|
||||
hostBaseIP := net.ParseIP(hostBase).To4()
|
||||
vrtBaseIP := net.ParseIP(vrtBase).To4()
|
||||
|
||||
hostIP := make(net.IP, 4)
|
||||
copy(hostIP, hostBaseIP)
|
||||
hostIP[2] += byte(index >> 8)
|
||||
hostIP[3] += byte(index & 0xFF)
|
||||
hostIP[2] = hostBaseIP[2] + byte(index>>8)
|
||||
hostIP[3] = hostBaseIP[3] + byte(index&0xFF)
|
||||
|
||||
vethOffset := index * vrtAddressesPerSlot
|
||||
vethIP := make(net.IP, 4)
|
||||
copy(vethIP, vrtBaseIP)
|
||||
vethIP[2] += byte(vethOffset >> 8)
|
||||
vethIP[3] += byte(vethOffset & 0xFF)
|
||||
vethIP[2] = vrtBaseIP[2] + byte(vethOffset>>8)
|
||||
vethIP[3] = vrtBaseIP[3] + byte(vethOffset&0xFF)
|
||||
|
||||
vpeerOffset := vethOffset + 1
|
||||
vpeerIP := make(net.IP, 4)
|
||||
copy(vpeerIP, vrtBaseIP)
|
||||
vpeerIP[2] += byte(vpeerOffset >> 8)
|
||||
vpeerIP[3] += byte(vpeerOffset & 0xFF)
|
||||
vpeerIP[2] = vrtBaseIP[2] + byte(vpeerOffset>>8)
|
||||
vpeerIP[3] = vrtBaseIP[3] + byte(vpeerOffset&0xFF)
|
||||
|
||||
return &Slot{
|
||||
Index: index,
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -17,6 +18,7 @@ type VM struct {
|
||||
|
||||
// Manager handles the lifecycle of Firecracker microVMs.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
// vms tracks running VMs by sandbox ID.
|
||||
vms map[string]*VM
|
||||
}
|
||||
@ -84,7 +86,9 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
|
||||
client: client,
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.vms[cfg.SandboxID] = vm
|
||||
m.mu.Unlock()
|
||||
|
||||
slog.Info("VM started successfully", "sandbox", cfg.SandboxID)
|
||||
|
||||
@ -126,7 +130,9 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
|
||||
|
||||
// Pause pauses a running VM.
|
||||
func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
|
||||
m.mu.RLock()
|
||||
vm, ok := m.vms[sandboxID]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
@ -141,7 +147,9 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
|
||||
|
||||
// Resume resumes a paused VM.
|
||||
func (m *Manager) Resume(ctx context.Context, sandboxID string) error {
|
||||
m.mu.RLock()
|
||||
vm, ok := m.vms[sandboxID]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
@ -156,10 +164,14 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string) error {
|
||||
|
||||
// Destroy stops and cleans up a VM.
|
||||
func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
|
||||
m.mu.Lock()
|
||||
vm, ok := m.vms[sandboxID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
delete(m.vms, sandboxID)
|
||||
m.mu.Unlock()
|
||||
|
||||
slog.Info("destroying VM", "sandbox", sandboxID)
|
||||
|
||||
@ -171,8 +183,6 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
|
||||
// Clean up the API socket.
|
||||
os.Remove(vm.Config.SocketPath)
|
||||
|
||||
delete(m.vms, sandboxID)
|
||||
|
||||
slog.Info("VM destroyed", "sandbox", sandboxID)
|
||||
return nil
|
||||
}
|
||||
@ -180,7 +190,9 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
|
||||
// Snapshot creates a VM snapshot. The VM must already be paused.
|
||||
// snapshotType is "Full" (all memory) or "Diff" (only dirty pages since last resume).
|
||||
func (m *Manager) Snapshot(ctx context.Context, sandboxID, snapPath, memPath, snapshotType string) error {
|
||||
m.mu.RLock()
|
||||
vm, ok := m.vms[sandboxID]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
@ -263,7 +275,9 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath
|
||||
client: client,
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.vms[cfg.SandboxID] = vm
|
||||
m.mu.Unlock()
|
||||
|
||||
slog.Info("VM restored from snapshot", "sandbox", cfg.SandboxID)
|
||||
return vm, nil
|
||||
@ -277,7 +291,9 @@ func (v *VM) PID() int {
|
||||
|
||||
// Get returns a running VM by sandbox ID.
|
||||
func (m *Manager) Get(sandboxID string) (*VM, bool) {
|
||||
m.mu.RLock()
|
||||
vm, ok := m.vms[sandboxID]
|
||||
m.mu.RUnlock()
|
||||
return vm, ok
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user