forked from wrenn/wrenn
v0.1.0 (#17)
This commit is contained in:
@ -12,20 +12,21 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool}
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@ -47,11 +48,10 @@ type wsOutMsg struct {
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
||||
}
|
||||
|
||||
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
|
||||
// ExecStream handles WS /v1/capsules/{id}/exec/stream.
|
||||
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
@ -59,13 +59,31 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -76,6 +94,20 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
sendWSError(conn, "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
// Read the start message.
|
||||
var startMsg wsStartMsg
|
||||
if err := conn.ReadJSON(&startMsg); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user