Add streaming exec and file transfer endpoints
Add WebSocket-based streaming exec endpoint and streaming file upload/download endpoints to the control plane API. Includes new host agent RPC methods (ExecStream, StreamWriteFile, StreamReadFile), envd client streaming support, and OpenAPI spec updates.
This commit is contained in:
169
internal/api/handlers_exec_stream.go
Normal file
169
internal/api/handlers_exec_stream.go
Normal file
@ -0,0 +1,169 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, agent: agent}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// wsStartMsg is the first message the client sends to start a process.
|
||||
type wsStartMsg struct {
|
||||
Type string `json:"type"` // "start"
|
||||
Cmd string `json:"cmd"`
|
||||
Args []string `json:"args"`
|
||||
}
|
||||
|
||||
// wsStopMsg is sent by the client to stop the process.
|
||||
type wsStopMsg struct {
|
||||
Type string `json:"type"` // "stop"
|
||||
}
|
||||
|
||||
// wsOutMsg is sent by the server for process events.
|
||||
type wsOutMsg struct {
|
||||
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
|
||||
PID uint32 `json:"pid,omitempty"` // only for "start"
|
||||
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
||||
}
|
||||
|
||||
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
|
||||
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sb, err := h.db.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Read the start message.
|
||||
var startMsg wsStartMsg
|
||||
if err := conn.ReadJSON(&startMsg); err != nil {
|
||||
sendWSError(conn, "failed to read start message: "+err.Error())
|
||||
return
|
||||
}
|
||||
if startMsg.Type != "start" || startMsg.Cmd == "" {
|
||||
sendWSError(conn, "first message must be type 'start' with a 'cmd' field")
|
||||
return
|
||||
}
|
||||
|
||||
// Open streaming exec to host agent.
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
Cmd: startMsg.Cmd,
|
||||
Args: startMsg.Args,
|
||||
}))
|
||||
if err != nil {
|
||||
sendWSError(conn, "failed to start exec stream: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
// Listen for stop messages from the client in a goroutine.
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
var parsed struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if json.Unmarshal(msg, &parsed) == nil && parsed.Type == "stop" {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Forward stream events to WebSocket.
|
||||
for stream.Receive() {
|
||||
resp := stream.Msg()
|
||||
switch ev := resp.Event.(type) {
|
||||
case *pb.ExecStreamResponse_Start:
|
||||
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
|
||||
|
||||
case *pb.ExecStreamResponse_Data:
|
||||
switch o := ev.Data.Output.(type) {
|
||||
case *pb.ExecStreamData_Stdout:
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
|
||||
case *pb.ExecStreamData_Stderr:
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
|
||||
}
|
||||
|
||||
case *pb.ExecStreamResponse_End:
|
||||
exitCode := ev.End.ExitCode
|
||||
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
// Only send if the connection is still alive (not a normal close).
|
||||
if streamCtx.Err() == nil {
|
||||
sendWSError(conn, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Update last active using a fresh context (the request context may be cancelled).
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendWSError(conn *websocket.Conn, msg string) {
|
||||
writeWSJSON(conn, wsOutMsg{Type: "error", Data: msg})
|
||||
}
|
||||
|
||||
func writeWSJSON(conn *websocket.Conn, v any) {
|
||||
if err := conn.WriteJSON(v); err != nil {
|
||||
slog.Debug("websocket write error", "error", err)
|
||||
}
|
||||
}
|
||||
194
internal/api/handlers_files_stream.go
Normal file
194
internal/api/handlers_files_stream.go
Normal file
@ -0,0 +1,194 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type filesStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
|
||||
return &filesStreamHandler{db: db, agent: agent}
|
||||
}
|
||||
|
||||
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sb, err := h.db.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse boundary from Content-Type without buffering the body.
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil || params["boundary"] == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "expected multipart/form-data with boundary")
|
||||
return
|
||||
}
|
||||
|
||||
// Read parts manually from the multipart stream.
|
||||
mr := multipart.NewReader(r.Body, params["boundary"])
|
||||
|
||||
var filePath string
|
||||
var filePart *multipart.Part
|
||||
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "failed to parse multipart")
|
||||
return
|
||||
}
|
||||
switch part.FormName() {
|
||||
case "path":
|
||||
data, _ := io.ReadAll(part)
|
||||
filePath = string(data)
|
||||
case "file":
|
||||
filePart = part
|
||||
}
|
||||
if filePath != "" && filePart != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if filePath == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path field is required")
|
||||
return
|
||||
}
|
||||
if filePart == nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "file field is required")
|
||||
return
|
||||
}
|
||||
defer filePart.Close()
|
||||
|
||||
// Open client-streaming RPC to host agent.
|
||||
stream := h.agent.WriteFileStream(ctx)
|
||||
|
||||
// Send metadata first.
|
||||
if err := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Meta{
|
||||
Meta: &pb.WriteFileStreamMeta{
|
||||
SandboxId: sandboxID,
|
||||
Path: filePath,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusBadGateway, "agent_error", "failed to send file metadata")
|
||||
return
|
||||
}
|
||||
|
||||
// Stream file content in 64KB chunks directly from the multipart part.
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
n, err := filePart.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
if sendErr := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Chunk{Chunk: chunk},
|
||||
}); sendErr != nil {
|
||||
writeError(w, http.StatusBadGateway, "agent_error", "failed to stream file chunk")
|
||||
return
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "read_error", "failed to read uploaded file")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Close and receive response.
|
||||
if _, err := stream.CloseAndReceive(); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
|
||||
// Accepts JSON body with path, streams file content back without buffering.
|
||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sb, err := h.db.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req readFileRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Open server-streaming RPC to host agent.
|
||||
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
|
||||
flusher, canFlush := w.(http.Flusher)
|
||||
for stream.Receive() {
|
||||
chunk := stream.Msg().Chunk
|
||||
if len(chunk) > 0 {
|
||||
if _, err := w.Write(chunk); err != nil {
|
||||
return
|
||||
}
|
||||
if canFlush {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
// Headers already sent, nothing we can do but log.
|
||||
// The client will see a truncated response.
|
||||
}
|
||||
}
|
||||
@ -1,8 +1,11 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@ -61,6 +64,10 @@ func requestLogger() func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func decodeJSON(r *http.Request, v any) error {
|
||||
return json.NewDecoder(r.Body).Decode(v)
|
||||
}
|
||||
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
@ -70,3 +77,18 @@ func (w *statusWriter) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker, required for WebSocket upgrade.
|
||||
func (w *statusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("underlying ResponseWriter does not implement http.Hijacker")
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher, required for streaming responses.
|
||||
func (w *statusWriter) Flush() {
|
||||
if fl, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
@ -239,6 +239,143 @@ paths:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/sandboxes/{id}/exec/stream:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
get:
|
||||
summary: Stream command execution via WebSocket
|
||||
operationId: execStream
|
||||
description: |
|
||||
Opens a WebSocket connection for streaming command execution.
|
||||
|
||||
**Client sends** (first message to start the process):
|
||||
```json
|
||||
{"type": "start", "cmd": "tail", "args": ["-f", "/var/log/syslog"]}
|
||||
```
|
||||
|
||||
**Client sends** (to stop the process):
|
||||
```json
|
||||
{"type": "stop"}
|
||||
```
|
||||
|
||||
**Server sends** (process events as they arrive):
|
||||
```json
|
||||
{"type": "start", "pid": 1234}
|
||||
{"type": "stdout", "data": "line of output\n"}
|
||||
{"type": "stderr", "data": "warning message\n"}
|
||||
{"type": "exit", "exit_code": 0}
|
||||
{"type": "error", "data": "description of error"}
|
||||
```
|
||||
|
||||
The connection closes automatically after the process exits.
|
||||
responses:
|
||||
"101":
|
||||
description: WebSocket upgrade
|
||||
"404":
|
||||
description: Sandbox not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"409":
|
||||
description: Sandbox not running
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/sandboxes/{id}/files/stream/write:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
post:
|
||||
summary: Upload a file (streaming)
|
||||
operationId: streamUploadFile
|
||||
description: |
|
||||
Streams file content to the sandbox without buffering in memory.
|
||||
Suitable for large files. Uses the same multipart/form-data format
|
||||
as the non-streaming upload endpoint.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
multipart/form-data:
|
||||
schema:
|
||||
type: object
|
||||
required: [path, file]
|
||||
properties:
|
||||
path:
|
||||
type: string
|
||||
description: Absolute destination path inside the sandbox
|
||||
file:
|
||||
type: string
|
||||
format: binary
|
||||
description: File content
|
||||
responses:
|
||||
"204":
|
||||
description: File uploaded
|
||||
"404":
|
||||
description: Sandbox not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"409":
|
||||
description: Sandbox not running
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/sandboxes/{id}/files/stream/read:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
post:
|
||||
summary: Download a file (streaming)
|
||||
operationId: streamDownloadFile
|
||||
description: |
|
||||
Streams file content from the sandbox without buffering in memory.
|
||||
Suitable for large files. Returns raw bytes with chunked transfer encoding.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ReadFileRequest"
|
||||
responses:
|
||||
"200":
|
||||
description: File content streamed in chunks
|
||||
content:
|
||||
application/octet-stream:
|
||||
schema:
|
||||
type: string
|
||||
format: binary
|
||||
"404":
|
||||
description: Sandbox or file not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"409":
|
||||
description: Sandbox not running
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
components:
|
||||
schemas:
|
||||
CreateSandboxRequest:
|
||||
|
||||
@ -26,7 +26,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *
|
||||
|
||||
sandbox := newSandboxHandler(queries, agent)
|
||||
exec := newExecHandler(queries, agent)
|
||||
execStream := newExecStreamHandler(queries, agent)
|
||||
files := newFilesHandler(queries, agent)
|
||||
filesStream := newFilesStreamHandler(queries, agent)
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
@ -41,10 +43,13 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *
|
||||
r.Get("/", sandbox.Get)
|
||||
r.Delete("/", sandbox.Destroy)
|
||||
r.Post("/exec", exec.Exec)
|
||||
r.Get("/exec/stream", execStream.ExecStream)
|
||||
r.Post("/pause", sandbox.Pause)
|
||||
r.Post("/resume", sandbox.Resume)
|
||||
r.Post("/files/write", files.Upload)
|
||||
r.Post("/files/read", files.Download)
|
||||
r.Post("/files/stream/write", filesStream.StreamUpload)
|
||||
r.Post("/files/stream/read", filesStream.StreamDownload)
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user