forked from wrenn/wrenn
refactor: eliminate DRY violations across control plane and host agent
Extract shared helpers to consolidate repeated patterns: - requireRunningSandbox: sandbox lookup + running check (10 call sites) - upgradeAndAuthenticate: WS upgrade + JWT/API-key auth (3 handlers) - updateLastActive: last_active_at update with background context (5 sites) - attachCowAndCreate: cow loop attach + dmsetup create (devicemapper) - issueRegistrationToken: token gen + Redis + audit (host service) - ErrNotFound sentinel: replaces string matching in hostagent server Also merges duplicate wsProcessOut/wsOutMsg types into one. Net: -208 lines, zero behavior change.
This commit is contained in:
@ -3,10 +3,17 @@ package api
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||||
|
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||||
)
|
)
|
||||||
@ -20,3 +27,82 @@ func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.Host
|
|||||||
}
|
}
|
||||||
return pool.GetForHost(host)
|
return pool.GetForHost(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// requireRunningSandbox parses the sandbox ID from the URL, looks it up by team,
|
||||||
|
// and verifies it is running. On failure it writes the appropriate HTTP error and
|
||||||
|
// returns false.
|
||||||
|
func requireRunningSandbox(w http.ResponseWriter, r *http.Request, queries *db.Queries, teamID pgtype.UUID) (db.Sandbox, pgtype.UUID, string, bool) {
|
||||||
|
sandboxIDStr := chi.URLParam(r, "id")
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||||
|
return db.Sandbox{}, pgtype.UUID{}, "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
sb, err := queries.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||||
|
return db.Sandbox{}, pgtype.UUID{}, "", false
|
||||||
|
}
|
||||||
|
if sb.Status != "running" {
|
||||||
|
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||||
|
return db.Sandbox{}, pgtype.UUID{}, "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb, sandboxID, sandboxIDStr, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// upgradeAndAuthenticate upgrades the HTTP connection to WebSocket and resolves
|
||||||
|
// the auth context — either from middleware (API key) or from the first WS message (JWT).
|
||||||
|
// Returns the connection and auth context, or an error if authentication fails.
|
||||||
|
// The caller is responsible for closing the returned connection.
|
||||||
|
func upgradeAndAuthenticate(w http.ResponseWriter, r *http.Request, jwtSecret []byte, queries *db.Queries) (*websocket.Conn, auth.AuthContext, error) {
|
||||||
|
ctx := r.Context()
|
||||||
|
ac, hasAuth := auth.FromContext(ctx)
|
||||||
|
|
||||||
|
if hasAuth {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return conn, ac, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wsAC auth.AuthContext
|
||||||
|
var authErr error
|
||||||
|
if isAdminWSRoute(ctx) {
|
||||||
|
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, jwtSecret, queries)
|
||||||
|
} else {
|
||||||
|
wsAC, authErr = wsAuthenticate(ctx, conn, jwtSecret, queries)
|
||||||
|
}
|
||||||
|
if authErr != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, auth.AuthContext{}, fmt.Errorf("authentication failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, wsAC, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateLastActive updates the sandbox last_active_at timestamp.
|
||||||
|
// Uses a background context with timeout for streaming handlers where
|
||||||
|
// the request context may already be cancelled.
|
||||||
|
func updateLastActive(queries *db.Queries, sandboxID pgtype.UUID, sandboxIDStr string) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := queries.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||||
|
ID: sandboxID,
|
||||||
|
LastActiveAt: pgtype.Timestamptz{
|
||||||
|
Time: time.Now(),
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -3,14 +3,11 @@ package api
|
|||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"connectrpc.com/connect"
|
"connectrpc.com/connect"
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
|
||||||
|
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||||
@ -58,23 +55,11 @@ type backgroundExecResponse struct {
|
|||||||
|
|
||||||
// Exec handles POST /v1/capsules/{id}/exec.
|
// Exec handles POST /v1/capsules/{id}/exec.
|
||||||
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, sandboxID, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,15 +101,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||||
ID: sandboxID,
|
|
||||||
LastActiveAt: pgtype.Timestamptz{
|
|
||||||
Time: time.Now(),
|
|
||||||
Valid: true,
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
|
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
|
||||||
SandboxID: sandboxIDStr,
|
SandboxID: sandboxIDStr,
|
||||||
@ -151,16 +128,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
|
|
||||||
// Update last active.
|
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
|
||||||
ID: sandboxID,
|
|
||||||
LastActiveAt: pgtype.Timestamptz{
|
|
||||||
Time: time.Now(),
|
|
||||||
Valid: true,
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use base64 encoding if output contains non-UTF-8 bytes.
|
// Use base64 encoding if output contains non-UTF-8 bytes.
|
||||||
stdout := resp.Msg.Stdout
|
stdout := resp.Msg.Stdout
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"connectrpc.com/connect"
|
"connectrpc.com/connect"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
@ -59,37 +58,9 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db)
|
||||||
ac, hasAuth := auth.FromContext(ctx)
|
|
||||||
|
|
||||||
if !hasAuth {
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("websocket upgrade failed", "error", err)
|
slog.Error("websocket upgrade/auth failed", "error", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
var wsAC auth.AuthContext
|
|
||||||
var authErr error
|
|
||||||
if isAdminWSRoute(ctx) {
|
|
||||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
|
||||||
} else {
|
|
||||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
|
||||||
}
|
|
||||||
if authErr != nil {
|
|
||||||
sendWSError(conn, "authentication failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ac = wsAC
|
|
||||||
|
|
||||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("websocket upgrade failed", "error", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
@ -186,18 +157,7 @@ func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.C
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last active using a fresh context (the request context may be cancelled).
|
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer updateCancel()
|
|
||||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
|
||||||
ID: sandboxID,
|
|
||||||
LastActiveAt: pgtype.Timestamptz{
|
|
||||||
Time: time.Now(),
|
|
||||||
Valid: true,
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendWSError(conn *websocket.Conn, msg string) {
|
func sendWSError(conn *websocket.Conn, msg string) {
|
||||||
|
|||||||
@ -7,11 +7,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"connectrpc.com/connect"
|
"connectrpc.com/connect"
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
|
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||||
)
|
)
|
||||||
@ -30,23 +28,11 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
|
|||||||
// - "path" text field: absolute destination path inside the sandbox
|
// - "path" text field: absolute destination path inside the sandbox
|
||||||
// - "file" file field: binary content to write
|
// - "file" file field: binary content to write
|
||||||
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,23 +94,11 @@ type readFileRequest struct {
|
|||||||
// Download handles POST /v1/capsules/{id}/files/read.
|
// Download handles POST /v1/capsules/{id}/files/read.
|
||||||
// Accepts JSON body with path, returns raw file content with Content-Disposition.
|
// Accepts JSON body with path, returns raw file content with Content-Disposition.
|
||||||
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -8,11 +8,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"connectrpc.com/connect"
|
"connectrpc.com/connect"
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
|
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||||
)
|
)
|
||||||
@ -30,23 +28,11 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
|
|||||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||||
// Streams file content directly from the request body to the host agent without buffering.
|
// Streams file content directly from the request body to the host agent without buffering.
|
||||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,23 +139,11 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
|||||||
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
|
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
|
||||||
// Accepts JSON body with path, streams file content back without buffering.
|
// Accepts JSON body with path, streams file content back without buffering.
|
||||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -4,11 +4,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"connectrpc.com/connect"
|
"connectrpc.com/connect"
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
|
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
|
||||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||||
)
|
)
|
||||||
@ -58,23 +56,11 @@ type removeRequest struct {
|
|||||||
|
|
||||||
// ListDir handles POST /v1/capsules/{id}/files/list.
|
// ListDir handles POST /v1/capsules/{id}/files/list.
|
||||||
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
|
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,23 +101,11 @@ func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
|
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
|
||||||
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
|
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -166,23 +140,11 @@ func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Remove handles POST /v1/capsules/{id}/files/remove.
|
// Remove handles POST /v1/capsules/{id}/files/remove.
|
||||||
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
|
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
|
|
||||||
"connectrpc.com/connect"
|
"connectrpc.com/connect"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
@ -44,23 +43,11 @@ type processListResponse struct {
|
|||||||
|
|
||||||
// ListProcesses handles GET /v1/capsules/{id}/processes.
|
// ListProcesses handles GET /v1/capsules/{id}/processes.
|
||||||
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
|
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,24 +82,12 @@ func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
|
|||||||
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
|
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
|
||||||
// The selector can be a numeric PID or a string tag.
|
// The selector can be a numeric PID or a string tag.
|
||||||
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
|
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
|
||||||
selectorStr := chi.URLParam(r, "selector")
|
selectorStr := chi.URLParam(r, "selector")
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ac := auth.MustFromContext(ctx)
|
ac := auth.MustFromContext(ctx)
|
||||||
|
|
||||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sb.Status != "running" {
|
|
||||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,14 +121,6 @@ func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// wsProcessOut is the JSON message sent to the WebSocket client.
|
|
||||||
type wsProcessOut struct {
|
|
||||||
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
|
|
||||||
PID uint32 `json:"pid,omitempty"` // only for "start"
|
|
||||||
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
|
|
||||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
|
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
|
||||||
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
|
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
|
||||||
sandboxIDStr := chi.URLParam(r, "id")
|
sandboxIDStr := chi.URLParam(r, "id")
|
||||||
@ -166,37 +133,9 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db)
|
||||||
ac, hasAuth := auth.FromContext(ctx)
|
|
||||||
|
|
||||||
if !hasAuth {
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
slog.Error("process stream websocket upgrade/auth failed", "error", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
var wsAC auth.AuthContext
|
|
||||||
var authErr error
|
|
||||||
if isAdminWSRoute(ctx) {
|
|
||||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
|
||||||
} else {
|
|
||||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
|
||||||
}
|
|
||||||
if authErr != nil {
|
|
||||||
sendProcessWSError(conn, "authentication failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ac = wsAC
|
|
||||||
|
|
||||||
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
@ -207,17 +146,17 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
|
|||||||
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
|
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
|
||||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendProcessWSError(conn, "sandbox not found")
|
sendWSError(conn, "sandbox not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if sb.Status != "running" {
|
if sb.Status != "running" {
|
||||||
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendProcessWSError(conn, "sandbox host is not reachable")
|
sendWSError(conn, "sandbox host is not reachable")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,7 +175,7 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
|
|||||||
|
|
||||||
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
|
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendProcessWSError(conn, "failed to connect to process: "+err.Error())
|
sendWSError(conn, "failed to connect to process: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
@ -257,42 +196,27 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
|
|||||||
resp := stream.Msg()
|
resp := stream.Msg()
|
||||||
switch ev := resp.Event.(type) {
|
switch ev := resp.Event.(type) {
|
||||||
case *pb.ConnectProcessResponse_Start:
|
case *pb.ConnectProcessResponse_Start:
|
||||||
writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})
|
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
|
||||||
|
|
||||||
case *pb.ConnectProcessResponse_Data:
|
case *pb.ConnectProcessResponse_Data:
|
||||||
switch o := ev.Data.Output.(type) {
|
switch o := ev.Data.Output.(type) {
|
||||||
case *pb.ExecStreamData_Stdout:
|
case *pb.ExecStreamData_Stdout:
|
||||||
writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
|
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
|
||||||
case *pb.ExecStreamData_Stderr:
|
case *pb.ExecStreamData_Stderr:
|
||||||
writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
|
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
|
||||||
}
|
}
|
||||||
|
|
||||||
case *pb.ConnectProcessResponse_End:
|
case *pb.ConnectProcessResponse_End:
|
||||||
exitCode := ev.End.ExitCode
|
exitCode := ev.End.ExitCode
|
||||||
writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
|
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stream.Err(); err != nil {
|
if err := stream.Err(); err != nil {
|
||||||
if streamCtx.Err() == nil {
|
if streamCtx.Err() == nil {
|
||||||
sendProcessWSError(conn, err.Error())
|
sendWSError(conn, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last active using a fresh context.
|
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer updateCancel()
|
|
||||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
|
||||||
ID: sandboxID,
|
|
||||||
LastActiveAt: pgtype.Timestamptz{
|
|
||||||
Time: time.Now(),
|
|
||||||
Valid: true,
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendProcessWSError(conn *websocket.Conn, msg string) {
|
|
||||||
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -90,40 +90,9 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// API key auth is handled by middleware (sets context).
|
conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db)
|
||||||
// For browser JWT auth, we authenticate after upgrade via first WS message.
|
|
||||||
ac, hasAuth := auth.FromContext(ctx)
|
|
||||||
|
|
||||||
if !hasAuth {
|
|
||||||
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("pty websocket upgrade failed", "error", err)
|
slog.Error("pty websocket upgrade/auth failed", "error", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
ws := &wsWriter{conn: conn}
|
|
||||||
|
|
||||||
var wsAC auth.AuthContext
|
|
||||||
if isAdminWSRoute(ctx) {
|
|
||||||
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
|
||||||
} else {
|
|
||||||
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ac = wsAC
|
|
||||||
|
|
||||||
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("pty websocket upgrade failed", "error", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
@ -168,18 +137,7 @@ func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *webs
|
|||||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
|
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last active using a fresh context.
|
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer updateCancel()
|
|
||||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
|
||||||
ID: sandboxID,
|
|
||||||
LastActiveAt: pgtype.Timestamptz{
|
|
||||||
Time: time.Now(),
|
|
||||||
Valid: true,
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ptyHandler) handleStart(
|
func (h *ptyHandler) handleStart(
|
||||||
|
|||||||
@ -109,6 +109,31 @@ type SnapshotDevice struct {
|
|||||||
CowLoopDev string // loop device for the CoW file
|
CowLoopDev string // loop device for the CoW file
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// attachCowAndCreate attaches a CoW file as a loop device, creates the
|
||||||
|
// dm-snapshot target, and returns the assembled SnapshotDevice. On failure
|
||||||
|
// it detaches the CoW loop device before returning.
|
||||||
|
func attachCowAndCreate(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) {
|
||||||
|
cowLoopDev, err := losetupCreateRW(cowPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("losetup cow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sectors := originSizeBytes / 512
|
||||||
|
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
|
||||||
|
if detachErr := losetupDetachRetry(cowLoopDev); detachErr != nil {
|
||||||
|
slog.Error("cow losetup detach failed during cleanup, loop device leaked", "device", cowLoopDev, "error", detachErr)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("dmsetup create: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SnapshotDevice{
|
||||||
|
Name: name,
|
||||||
|
DevicePath: "/dev/mapper/" + name,
|
||||||
|
CowPath: cowPath,
|
||||||
|
CowLoopDev: cowLoopDev,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// CreateSnapshot sets up a new dm-snapshot device.
|
// CreateSnapshot sets up a new dm-snapshot device.
|
||||||
//
|
//
|
||||||
// It creates a sparse CoW file, attaches it as a loop device, and creates
|
// It creates a sparse CoW file, attaches it as a loop device, and creates
|
||||||
@ -117,45 +142,24 @@ type SnapshotDevice struct {
|
|||||||
//
|
//
|
||||||
// The origin loop device must already exist (from LoopRegistry.Acquire).
|
// The origin loop device must already exist (from LoopRegistry.Acquire).
|
||||||
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) {
|
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) {
|
||||||
// Create sparse CoW file. The logical size limits how many blocks can be
|
|
||||||
// modified; because the file is sparse, only written blocks use real disk.
|
|
||||||
if err := createSparseFile(cowPath, cowSizeBytes); err != nil {
|
if err := createSparseFile(cowPath, cowSizeBytes); err != nil {
|
||||||
return nil, fmt.Errorf("create cow file: %w", err)
|
return nil, fmt.Errorf("create cow file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cowLoopDev, err := losetupCreateRW(cowPath)
|
dev, err := attachCowAndCreate(name, originLoopDev, cowPath, originSizeBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
os.Remove(cowPath)
|
os.Remove(cowPath)
|
||||||
return nil, fmt.Errorf("losetup cow: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// The dm-snapshot virtual device size must match the origin — the snapshot
|
|
||||||
// target maps 1:1 onto origin sectors. The CoW file just needs enough
|
|
||||||
// space to store all modified blocks (it's sparse, so 20GB costs nothing).
|
|
||||||
sectors := originSizeBytes / 512
|
|
||||||
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
|
|
||||||
if detachErr := losetupDetachRetry(cowLoopDev); detachErr != nil {
|
|
||||||
slog.Error("cow losetup detach failed during cleanup, loop device leaked", "device", cowLoopDev, "error", detachErr)
|
|
||||||
}
|
|
||||||
os.Remove(cowPath)
|
|
||||||
return nil, fmt.Errorf("dmsetup create: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
devPath := "/dev/mapper/" + name
|
|
||||||
|
|
||||||
slog.Info("dm-snapshot created",
|
slog.Info("dm-snapshot created",
|
||||||
"name", name,
|
"name", name,
|
||||||
"device", devPath,
|
"device", dev.DevicePath,
|
||||||
"origin", originLoopDev,
|
"origin", originLoopDev,
|
||||||
"cow", cowPath,
|
"cow", cowPath,
|
||||||
)
|
)
|
||||||
|
|
||||||
return &SnapshotDevice{
|
return dev, nil
|
||||||
Name: name,
|
|
||||||
DevicePath: devPath,
|
|
||||||
CowPath: cowPath,
|
|
||||||
CowLoopDev: cowLoopDev,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RestoreSnapshot re-attaches a dm-snapshot from an existing persistent CoW file.
|
// RestoreSnapshot re-attaches a dm-snapshot from an existing persistent CoW file.
|
||||||
@ -171,34 +175,19 @@ func RestoreSnapshot(ctx context.Context, name, originLoopDev, cowPath string, o
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cowLoopDev, err := losetupCreateRW(cowPath)
|
dev, err := attachCowAndCreate(name, originLoopDev, cowPath, originSizeBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("losetup cow: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sectors := originSizeBytes / 512
|
|
||||||
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
|
|
||||||
if detachErr := losetupDetachRetry(cowLoopDev); detachErr != nil {
|
|
||||||
slog.Error("cow losetup detach failed during cleanup, loop device leaked", "device", cowLoopDev, "error", detachErr)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("dmsetup create: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
devPath := "/dev/mapper/" + name
|
|
||||||
|
|
||||||
slog.Info("dm-snapshot restored",
|
slog.Info("dm-snapshot restored",
|
||||||
"name", name,
|
"name", name,
|
||||||
"device", devPath,
|
"device", dev.DevicePath,
|
||||||
"origin", originLoopDev,
|
"origin", originLoopDev,
|
||||||
"cow", cowPath,
|
"cow", cowPath,
|
||||||
)
|
)
|
||||||
|
|
||||||
return &SnapshotDevice{
|
return dev, nil
|
||||||
Name: name,
|
|
||||||
DevicePath: devPath,
|
|
||||||
CowPath: cowPath,
|
|
||||||
CowLoopDev: cowLoopDev,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSnapshot tears down a dm-snapshot device and its CoW loop device.
|
// RemoveSnapshot tears down a dm-snapshot device and its CoW loop device.
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package hostagent
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@ -193,7 +194,7 @@ func (s *Server) PingSandbox(
|
|||||||
req *connect.Request[pb.PingSandboxRequest],
|
req *connect.Request[pb.PingSandboxRequest],
|
||||||
) (*connect.Response[pb.PingSandboxResponse], error) {
|
) (*connect.Response[pb.PingSandboxResponse], error) {
|
||||||
if err := s.mgr.Ping(req.Msg.SandboxId); err != nil {
|
if err := s.mgr.Ping(req.Msg.SandboxId); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if errors.Is(err, sandbox.ErrNotFound) {
|
||||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||||
}
|
}
|
||||||
return nil, connect.NewError(connect.CodeFailedPrecondition, err)
|
return nil, connect.NewError(connect.CodeFailedPrecondition, err)
|
||||||
@ -590,7 +591,7 @@ func (s *Server) GetSandboxMetrics(
|
|||||||
|
|
||||||
points, err := s.mgr.GetMetrics(msg.SandboxId, msg.Range)
|
points, err := s.mgr.GetMetrics(msg.SandboxId, msg.Range)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if errors.Is(err, sandbox.ErrNotFound) {
|
||||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||||
}
|
}
|
||||||
if strings.Contains(err.Error(), "invalid range") {
|
if strings.Contains(err.Error(), "invalid range") {
|
||||||
@ -608,7 +609,7 @@ func (s *Server) FlushSandboxMetrics(
|
|||||||
) (*connect.Response[pb.FlushSandboxMetricsResponse], error) {
|
) (*connect.Response[pb.FlushSandboxMetricsResponse], error) {
|
||||||
pts10m, pts2h, pts24h, err := s.mgr.FlushMetrics(req.Msg.SandboxId)
|
pts10m, pts2h, pts24h, err := s.mgr.FlushMetrics(req.Msg.SandboxId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if errors.Is(err, sandbox.ErrNotFound) {
|
||||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||||
}
|
}
|
||||||
return nil, connect.NewError(connect.CodeInternal, err)
|
return nil, connect.NewError(connect.CodeInternal, err)
|
||||||
@ -761,7 +762,7 @@ func (s *Server) StartBackground(
|
|||||||
|
|
||||||
pid, err := s.mgr.StartBackground(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Envs, msg.Cwd)
|
pid, err := s.mgr.StartBackground(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Envs, msg.Cwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if errors.Is(err, sandbox.ErrNotFound) {
|
||||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||||
}
|
}
|
||||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("start background: %w", err))
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("start background: %w", err))
|
||||||
@ -779,7 +780,7 @@ func (s *Server) ListProcesses(
|
|||||||
) (*connect.Response[pb.ListProcessesResponse], error) {
|
) (*connect.Response[pb.ListProcessesResponse], error) {
|
||||||
procs, err := s.mgr.ListProcesses(ctx, req.Msg.SandboxId)
|
procs, err := s.mgr.ListProcesses(ctx, req.Msg.SandboxId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if errors.Is(err, sandbox.ErrNotFound) {
|
||||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||||
}
|
}
|
||||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list processes: %w", err))
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list processes: %w", err))
|
||||||
@ -830,7 +831,7 @@ func (s *Server) KillProcess(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := s.mgr.KillProcess(ctx, msg.SandboxId, pid, tag, signal); err != nil {
|
if err := s.mgr.KillProcess(ctx, msg.SandboxId, pid, tag, signal); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if errors.Is(err, sandbox.ErrNotFound) {
|
||||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||||
}
|
}
|
||||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("kill process: %w", err))
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("kill process: %w", err))
|
||||||
@ -859,7 +860,7 @@ func (s *Server) ConnectProcess(
|
|||||||
|
|
||||||
events, err := s.mgr.ConnectProcess(ctx, msg.SandboxId, pid, tag)
|
events, err := s.mgr.ConnectProcess(ctx, msg.SandboxId, pid, tag)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if errors.Is(err, sandbox.ErrNotFound) {
|
||||||
return connect.NewError(connect.CodeNotFound, err)
|
return connect.NewError(connect.CodeNotFound, err)
|
||||||
}
|
}
|
||||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("connect process: %w", err))
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("connect process: %w", err))
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package sandbox
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
@ -24,6 +25,9 @@ import (
|
|||||||
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrNotFound is returned when a sandbox is not present in the in-memory map.
|
||||||
|
var ErrNotFound = errors.New("sandbox not found")
|
||||||
|
|
||||||
// Config holds the paths and defaults for the sandbox manager.
|
// Config holds the paths and defaults for the sandbox manager.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
WrennDir string // root directory (e.g. /var/lib/wrenn); all sub-paths derived via layout package
|
WrennDir string // root directory (e.g. /var/lib/wrenn); all sub-paths derived via layout package
|
||||||
@ -904,7 +908,7 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, t
|
|||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, fmt.Errorf("sandbox %s not found", sandboxID)
|
return 0, fmt.Errorf("%w: %s", ErrNotFound, sandboxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush guest page cache to disk before stopping the VM. Without this,
|
// Flush guest page cache to disk before stopping the VM. Without this,
|
||||||
@ -1395,7 +1399,7 @@ func (m *Manager) Ping(sandboxID string) error {
|
|||||||
|
|
||||||
sb, ok := m.boxes[sandboxID]
|
sb, ok := m.boxes[sandboxID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("sandbox not found: %s", sandboxID)
|
return fmt.Errorf("%w: %s", ErrNotFound, sandboxID)
|
||||||
}
|
}
|
||||||
if sb.Status != models.StatusRunning {
|
if sb.Status != models.StatusRunning {
|
||||||
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||||
@ -1421,7 +1425,7 @@ func (m *Manager) get(sandboxID string) (*sandboxState, error) {
|
|||||||
|
|
||||||
sb, ok := m.boxes[sandboxID]
|
sb, ok := m.boxes[sandboxID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("sandbox not found: %s", sandboxID)
|
return nil, fmt.Errorf("%w: %s", ErrNotFound, sandboxID)
|
||||||
}
|
}
|
||||||
return sb, nil
|
return sb, nil
|
||||||
}
|
}
|
||||||
@ -1731,7 +1735,7 @@ func (m *Manager) GetMetrics(sandboxID, rangeTier string) ([]MetricPoint, error)
|
|||||||
sb, ok := m.boxes[sandboxID]
|
sb, ok := m.boxes[sandboxID]
|
||||||
m.mu.RUnlock()
|
m.mu.RUnlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("sandbox not found: %s", sandboxID)
|
return nil, fmt.Errorf("%w: %s", ErrNotFound, sandboxID)
|
||||||
}
|
}
|
||||||
if sb.ring == nil {
|
if sb.ring == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -1784,7 +1788,7 @@ func (m *Manager) FlushMetrics(sandboxID string) (pts10m, pts2h, pts24h []Metric
|
|||||||
sb, ok := m.boxes[sandboxID]
|
sb, ok := m.boxes[sandboxID]
|
||||||
m.mu.RUnlock()
|
m.mu.RUnlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, nil, fmt.Errorf("sandbox not found: %s", sandboxID)
|
return nil, nil, nil, fmt.Errorf("%w: %s", ErrNotFound, sandboxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.stopSampler(sb)
|
m.stopSampler(sb)
|
||||||
|
|||||||
@ -94,6 +94,31 @@ type regTokenPayload struct {
|
|||||||
|
|
||||||
const regTokenTTL = time.Hour
|
const regTokenTTL = time.Hour
|
||||||
|
|
||||||
|
func (s *HostService) issueRegistrationToken(ctx context.Context, hostID, createdBy pgtype.UUID) (string, error) {
|
||||||
|
token := id.NewRegistrationToken()
|
||||||
|
tokenID := id.NewHostTokenID()
|
||||||
|
|
||||||
|
payload, _ := json.Marshal(regTokenPayload{
|
||||||
|
HostID: id.FormatHostID(hostID),
|
||||||
|
TokenID: id.FormatHostTokenID(tokenID),
|
||||||
|
})
|
||||||
|
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||||
|
return "", fmt.Errorf("store registration token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
|
||||||
|
ID: tokenID,
|
||||||
|
HostID: hostID,
|
||||||
|
CreatedBy: createdBy,
|
||||||
|
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||||
|
}); err != nil {
|
||||||
|
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
|
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
|
||||||
func requireAdminOrOwner(role string) error {
|
func requireAdminOrOwner(role string) error {
|
||||||
if role == "owner" || role == "admin" {
|
if role == "owner" || role == "admin" {
|
||||||
@ -159,26 +184,9 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
|||||||
return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
|
return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate registration token and store in Redis + Postgres audit trail.
|
token, err := s.issueRegistrationToken(ctx, hostID, p.RequestingUserID)
|
||||||
token := id.NewRegistrationToken()
|
if err != nil {
|
||||||
tokenID := id.NewHostTokenID()
|
return HostCreateResult{}, err
|
||||||
|
|
||||||
payload, _ := json.Marshal(regTokenPayload{
|
|
||||||
HostID: id.FormatHostID(hostID),
|
|
||||||
TokenID: id.FormatHostTokenID(tokenID),
|
|
||||||
})
|
|
||||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
|
||||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
|
|
||||||
ID: tokenID,
|
|
||||||
HostID: hostID,
|
|
||||||
CreatedBy: p.RequestingUserID,
|
|
||||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
|
||||||
}); err != nil {
|
|
||||||
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||||
@ -218,25 +226,9 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
token := id.NewRegistrationToken()
|
token, err := s.issueRegistrationToken(ctx, hostID, userID)
|
||||||
tokenID := id.NewHostTokenID()
|
if err != nil {
|
||||||
|
return HostCreateResult{}, err
|
||||||
payload, _ := json.Marshal(regTokenPayload{
|
|
||||||
HostID: id.FormatHostID(hostID),
|
|
||||||
TokenID: id.FormatHostTokenID(tokenID),
|
|
||||||
})
|
|
||||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
|
||||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
|
|
||||||
ID: tokenID,
|
|
||||||
HostID: hostID,
|
|
||||||
CreatedBy: userID,
|
|
||||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
|
||||||
}); err != nil {
|
|
||||||
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||||
|
|||||||
Reference in New Issue
Block a user