1
0
forked from wrenn/wrenn

Fix concurrency, security, and correctness issues across backend and frontend

- C1: Add sync.RWMutex to vm.Manager to protect concurrent vms map access
- H1: Fix IP arithmetic overflow in network slot addressing (byte truncation)
- H5: Fix MultiplexedChannel.Fork() TOCTOU race (move exited check inside lock)
- H8: Remove snapshot overwrite — return template_name_taken conflict instead
- H9: Wrap DeleteAccount DB ops in a transaction, make team deletion fatal
- H10: Sanitize serviceErrToHTTP to stop leaking internal error messages
- H11: Add deleted_at IS NULL to GetUserByEmail/GetUserByID queries
- H12: Add id DESC to audit log composite index for cursor pagination
- H15: Delete dead AuthModal.svelte component
- H17: Move JWT from WebSocket URL query param to first WS message
- H18: Fix $derived to $derived.by in FilesTab breadcrumbs
This commit is contained in:
2026-04-16 06:11:42 +06:00
parent ed2222c80c
commit 9ea847923c
39 changed files with 532 additions and 380 deletions

View File

@ -20,12 +20,13 @@ import (
)
type execStreamHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
var upgrader = websocket.Upgrader{
@ -51,7 +52,6 @@ type wsOutMsg struct {
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
@ -59,13 +59,31 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
return
}
@ -76,6 +94,20 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
}
defer conn.Close()
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
}
func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
// Read the start message.
var startMsg wsStartMsg
if err := conn.ReadJSON(&startMsg); err != nil {

View File

@ -512,6 +512,9 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
}
// Delete all teams the user solely owns (no other members).
// Team deletion involves RPC calls (sandbox destruction) that cannot be
// transactional, so we do those first as best-effort, then wrap the
// DB-only cleanup in a transaction.
soleTeams, err := h.db.ListSoleOwnedTeams(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list owned teams")
@ -519,20 +522,36 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
}
for _, teamID := range soleTeams {
if err := h.teamSvc.DeleteTeamInternal(ctx, teamID); err != nil {
slog.Warn("account delete: failed to delete sole-owned team",
"team_id", id.FormatTeamID(teamID), "error", err)
writeError(w, http.StatusInternalServerError, "db_error",
fmt.Sprintf("failed to delete sole-owned team %s", id.FormatTeamID(teamID)))
return
}
}
if err := h.db.DeleteAPIKeysByCreator(ctx, ac.UserID); err != nil {
slog.Warn("account delete: failed to delete user's API keys", "error", err)
tx, err := h.pool.Begin(ctx)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to start transaction")
return
}
defer tx.Rollback(ctx)
qtx := h.db.WithTx(tx)
if err := qtx.DeleteAPIKeysByCreator(ctx, ac.UserID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete user's API keys")
return
}
if err := h.db.SoftDeleteUser(ctx, ac.UserID); err != nil {
if err := qtx.SoftDeleteUser(ctx, ac.UserID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete account")
return
}
if err := tx.Commit(ctx); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to commit account deletion")
return
}
slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)
go func() {

View File

@ -20,12 +20,13 @@ import (
)
type processHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
}
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool) *processHandler {
return &processHandler{db: db, pool: pool}
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
// processResponse is a single entry in the process list.
@ -158,7 +159,6 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
sandboxIDStr := chi.URLParam(r, "id")
selectorStr := chi.URLParam(r, "selector")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
@ -166,19 +166,31 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendProcessWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
return
}
@ -189,6 +201,26 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
}
defer conn.Close()
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
}
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendProcessWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendProcessWSError(conn, "sandbox host is not reachable")
return
}
// Build the connect request with PID or tag selector.
connectReq := &pb.ConnectProcessRequest{
SandboxId: sandboxIDStr,

View File

@ -30,12 +30,13 @@ const (
)
type ptyHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
}
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool) *ptyHandler {
return &ptyHandler{db: db, pool: pool}
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
// --- WebSocket message types ---
@ -82,7 +83,6 @@ func (w *wsWriter) writeJSON(v any) {
func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
@ -90,13 +90,34 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
// API key auth is handled by middleware (sets context).
// For browser JWT auth, we authenticate after upgrade via first WS message.
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
return
}
defer conn.Close()
ws := &wsWriter{conn: conn}
var wsAC auth.AuthContext
if isAdminWSRoute(ctx) {
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
return
}
ac = wsAC
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
return
}
@ -108,6 +129,19 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
defer conn.Close()
ws := &wsWriter{conn: conn}
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
}
func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox not found", Fatal: true})
return
}
if sb.Status != "running" {
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox is not running (status: " + sb.Status + ")", Fatal: true})
return
}
// Read the first message to determine start vs connect.
var firstMsg wsPtyIn

View File

@ -133,7 +133,6 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ac := auth.MustFromContext(ctx)
overwrite := r.URL.Query().Get("overwrite") == "true"
// Check for global name collision.
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
@ -142,20 +141,10 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Check if name already exists for this team.
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if !overwrite {
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
return
}
// Delete old snapshot files from all hosts before removing the DB record.
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
return
}
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
writeError(w, http.StatusConflict, "template_name_taken",
"snapshot name already exists; delete the existing snapshot first to reuse this name")
return
}
// Verify sandbox exists, belongs to team, and is running or paused.

109
internal/api/helpers_ws.go Normal file
View File

@ -0,0 +1,109 @@
package api
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// isWebSocketUpgrade returns true if the request is a WebSocket upgrade.
func isWebSocketUpgrade(r *http.Request) bool {
return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
}
// ctxKeyAdminWS is a context key for flagging admin WS routes.
type ctxKeyAdminWS struct{}
// setAdminWSFlag marks the context as an admin WebSocket route.
func setAdminWSFlag(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
}
// isAdminWSRoute checks if the request context was marked as admin WS.
func isAdminWSRoute(ctx context.Context) bool {
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
return v
}
// wsAuthMsg is the first message a browser WS client sends to authenticate.
type wsAuthMsg struct {
Type string `json:"type"`
Token string `json:"token"`
}
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
// authenticated context. The caller must send this as the first message after
// connecting.
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var msg wsAuthMsg
if err := conn.ReadJSON(&msg); err != nil {
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
}
conn.SetReadDeadline(time.Time{}) // clear deadline
if msg.Type != "auth" || msg.Token == "" {
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
}
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
}
user, err := queries.GetUserByID(ctx, userID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if user.Status != "active" {
return auth.AuthContext{}, fmt.Errorf("account deactivated")
}
return auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
}, nil
}
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
// returning an AuthContext with the platform team ID.
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
if err != nil {
return auth.AuthContext{}, err
}
user, err := queries.GetUserByID(ctx, ac.UserID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if !user.IsAdmin {
return auth.AuthContext{}, fmt.Errorf("admin access required")
}
ac.TeamID = id.PlatformTeamID
return ac, nil
}

View File

@ -94,21 +94,25 @@ func serviceErrToHTTP(err error) (int, string, string) {
}
// Map well-known service error patterns.
// Return generic messages for most cases to avoid leaking internal details.
switch {
case strings.Contains(msg, "not found"):
return http.StatusNotFound, "not_found", msg
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
return http.StatusConflict, "invalid_state", msg
return http.StatusNotFound, "not_found", "resource not found"
case strings.Contains(msg, "not running"):
return http.StatusConflict, "invalid_state", "resource is not running"
case strings.Contains(msg, "not paused"):
return http.StatusConflict, "invalid_state", "resource is not paused"
case strings.Contains(msg, "conflict:"):
return http.StatusConflict, "conflict", msg
return http.StatusConflict, "conflict", strings.TrimPrefix(msg, "conflict: ")
case strings.Contains(msg, "forbidden"):
return http.StatusForbidden, "forbidden", msg
return http.StatusForbidden, "forbidden", "forbidden"
case strings.Contains(msg, "invalid or expired"):
return http.StatusUnauthorized, "unauthorized", msg
return http.StatusUnauthorized, "unauthorized", "invalid or expired credentials"
case strings.Contains(msg, "invalid"):
return http.StatusBadRequest, "invalid_request", msg
return http.StatusBadRequest, "invalid_request", "invalid request"
default:
return http.StatusInternalServerError, "internal_error", msg
slog.Error("unhandled service error", "error", err)
return http.StatusInternalServerError, "internal_error", "an internal error occurred"
}
}

View File

@ -14,6 +14,11 @@ import (
func injectPlatformTeam() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, ok := auth.FromContext(r.Context()); !ok {
// No auth context yet (WS upgrade); handler will inject platform team after WS auth.
next.ServeHTTP(w, r)
return
}
ac := auth.MustFromContext(r.Context())
ac.TeamID = id.PlatformTeamID
ctx := auth.WithAuthContext(r.Context(), ac)
@ -26,11 +31,19 @@ func injectPlatformTeam() func(http.Handler) http.Handler {
// Must run after requireJWT (depends on AuthContext being present).
// Re-validates against the DB — the JWT is_admin claim is for UI only;
// the DB is the source of truth for admin access.
// WebSocket upgrade requests without auth context are passed through —
// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin.
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ac, ok := auth.FromContext(r.Context())
if !ok {
if isWebSocketUpgrade(r) {
ctx := r.Context()
ctx = setAdminWSFlag(ctx)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
return
}

View File

@ -38,12 +38,10 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
// Try JWT bearer token (header or query param for WebSocket).
// Try JWT bearer token from Authorization header.
tokenStr := ""
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr = strings.TrimPrefix(header, "Bearer ")
} else if t := r.URL.Query().Get("token"); t != "" {
tokenStr = t
}
if tokenStr != "" {
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
@ -87,7 +85,15 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
// WebSocket upgrade requests may not carry auth headers (browsers
// cannot set custom headers on WS connections). Pass through —
// the WS handler authenticates via the first message after upgrade.
if isWebSocketUpgrade(r) {
next.ServeHTTP(w, r)
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
})
}
}
}

View File

@ -10,19 +10,25 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// requireJWT validates a JWT from the Authorization: Bearer header or the
// ?token= query parameter (for WebSocket connections that cannot send headers).
// requireJWT validates a JWT from the Authorization: Bearer header.
// It also verifies the user is still active in the database.
// WebSocket upgrade requests without an Authorization header are passed through
// — WS handlers authenticate via the first message after upgrade.
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var tokenStr string
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr = strings.TrimPrefix(header, "Bearer ")
} else if t := r.URL.Query().Get("token"); t != "" {
tokenStr = t
}
if tokenStr == "" {
// WebSocket upgrade requests may not have an Authorization header
// (browsers cannot set custom headers on WS connections). Let them
// through — the handler authenticates via the first WS message.
if isWebSocketUpgrade(r) {
next.ServeHTTP(w, r)
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
return
}

View File

@ -65,7 +65,7 @@ func New(
sandbox := newSandboxHandler(sandboxSvc, al)
exec := newExecHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool, jwtSecret)
files := newFilesHandler(queries, pool)
filesStream := newFilesStreamHandler(queries, pool)
fsH := newFSHandler(queries, pool)
@ -81,8 +81,8 @@ func New(
metricsH := newSandboxMetricsHandler(queries, pool)
buildH := newBuildHandler(buildSvc, queries, pool)
channelH := newChannelHandler(channelSvc, al)
ptyH := newPtyHandler(queries, pool)
processH := newProcessHandler(queries, pool)
ptyH := newPtyHandler(queries, pool, jwtSecret)
processH := newProcessHandler(queries, pool, jwtSecret)
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
@ -144,6 +144,8 @@ func New(
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
// Capsule lifecycle: accepts API key or JWT bearer token.
// WebSocket upgrade requests without auth headers are passed through by
// requireAPIKeyOrJWT — the WS handlers authenticate via first message.
r.Route("/v1/capsules", func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Post("/", sandbox.Create)

View File

@ -131,26 +131,31 @@ type Slot struct {
}
// NewSlot computes the addressing for the given slot index (1-based).
// Index must be in [1, 32767] so that veth offset (index*2) fits in 16 bits.
func NewSlot(index int) *Slot {
if index < 1 || index > 32767 {
panic(fmt.Sprintf("slot index %d out of range [1, 32767]", index))
}
hostBaseIP := net.ParseIP(hostBase).To4()
vrtBaseIP := net.ParseIP(vrtBase).To4()
hostIP := make(net.IP, 4)
copy(hostIP, hostBaseIP)
hostIP[2] += byte(index >> 8)
hostIP[3] += byte(index & 0xFF)
hostIP[2] = hostBaseIP[2] + byte(index>>8)
hostIP[3] = hostBaseIP[3] + byte(index&0xFF)
vethOffset := index * vrtAddressesPerSlot
vethIP := make(net.IP, 4)
copy(vethIP, vrtBaseIP)
vethIP[2] += byte(vethOffset >> 8)
vethIP[3] += byte(vethOffset & 0xFF)
vethIP[2] = vrtBaseIP[2] + byte(vethOffset>>8)
vethIP[3] = vrtBaseIP[3] + byte(vethOffset&0xFF)
vpeerOffset := vethOffset + 1
vpeerIP := make(net.IP, 4)
copy(vpeerIP, vrtBaseIP)
vpeerIP[2] += byte(vpeerOffset >> 8)
vpeerIP[3] += byte(vpeerOffset & 0xFF)
vpeerIP[2] = vrtBaseIP[2] + byte(vpeerOffset>>8)
vpeerIP[3] = vrtBaseIP[3] + byte(vpeerOffset&0xFF)
return &Slot{
Index: index,

View File

@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"os"
"sync"
"time"
)
@ -17,6 +18,7 @@ type VM struct {
// Manager handles the lifecycle of Firecracker microVMs.
type Manager struct {
mu sync.RWMutex
// vms tracks running VMs by sandbox ID.
vms map[string]*VM
}
@ -84,7 +86,9 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
client: client,
}
m.mu.Lock()
m.vms[cfg.SandboxID] = vm
m.mu.Unlock()
slog.Info("VM started successfully", "sandbox", cfg.SandboxID)
@ -126,7 +130,9 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
// Pause pauses a running VM.
func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
m.mu.RLock()
vm, ok := m.vms[sandboxID]
m.mu.RUnlock()
if !ok {
return fmt.Errorf("VM not found: %s", sandboxID)
}
@ -141,7 +147,9 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
// Resume resumes a paused VM.
func (m *Manager) Resume(ctx context.Context, sandboxID string) error {
m.mu.RLock()
vm, ok := m.vms[sandboxID]
m.mu.RUnlock()
if !ok {
return fmt.Errorf("VM not found: %s", sandboxID)
}
@ -156,10 +164,14 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string) error {
// Destroy stops and cleans up a VM.
func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
m.mu.Lock()
vm, ok := m.vms[sandboxID]
if !ok {
m.mu.Unlock()
return fmt.Errorf("VM not found: %s", sandboxID)
}
delete(m.vms, sandboxID)
m.mu.Unlock()
slog.Info("destroying VM", "sandbox", sandboxID)
@ -171,8 +183,6 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
// Clean up the API socket.
os.Remove(vm.Config.SocketPath)
delete(m.vms, sandboxID)
slog.Info("VM destroyed", "sandbox", sandboxID)
return nil
}
@ -180,7 +190,9 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
// Snapshot creates a VM snapshot. The VM must already be paused.
// snapshotType is "Full" (all memory) or "Diff" (only dirty pages since last resume).
func (m *Manager) Snapshot(ctx context.Context, sandboxID, snapPath, memPath, snapshotType string) error {
m.mu.RLock()
vm, ok := m.vms[sandboxID]
m.mu.RUnlock()
if !ok {
return fmt.Errorf("VM not found: %s", sandboxID)
}
@ -263,7 +275,9 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath
client: client,
}
m.mu.Lock()
m.vms[cfg.SandboxID] = vm
m.mu.Unlock()
slog.Info("VM restored from snapshot", "sandbox", cfg.SandboxID)
return vm, nil
@ -277,7 +291,9 @@ func (v *VM) PID() int {
// Get returns a running VM by sandbox ID.
func (m *Manager) Get(sandboxID string) (*VM, bool) {
m.mu.RLock()
vm, ok := m.vms[sandboxID]
m.mu.RUnlock()
return vm, ok
}