Add streaming exec and file transfer endpoints
Add WebSocket-based streaming exec endpoint and streaming file upload/download endpoints to the control plane API. Includes new host agent RPC methods (ExecStream, StreamWriteFile, StreamReadFile), envd client streaming support, and OpenAPI spec updates.
This commit is contained in:
169
internal/api/handlers_exec_stream.go
Normal file
169
internal/api/handlers_exec_stream.go
Normal file
@ -0,0 +1,169 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, agent: agent}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// wsStartMsg is the first message the client sends to start a process.
|
||||
type wsStartMsg struct {
|
||||
Type string `json:"type"` // "start"
|
||||
Cmd string `json:"cmd"`
|
||||
Args []string `json:"args"`
|
||||
}
|
||||
|
||||
// wsStopMsg is sent by the client to stop the process.
|
||||
type wsStopMsg struct {
|
||||
Type string `json:"type"` // "stop"
|
||||
}
|
||||
|
||||
// wsOutMsg is sent by the server for process events.
|
||||
type wsOutMsg struct {
|
||||
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
|
||||
PID uint32 `json:"pid,omitempty"` // only for "start"
|
||||
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
||||
}
|
||||
|
||||
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
|
||||
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sb, err := h.db.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Read the start message.
|
||||
var startMsg wsStartMsg
|
||||
if err := conn.ReadJSON(&startMsg); err != nil {
|
||||
sendWSError(conn, "failed to read start message: "+err.Error())
|
||||
return
|
||||
}
|
||||
if startMsg.Type != "start" || startMsg.Cmd == "" {
|
||||
sendWSError(conn, "first message must be type 'start' with a 'cmd' field")
|
||||
return
|
||||
}
|
||||
|
||||
// Open streaming exec to host agent.
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
Cmd: startMsg.Cmd,
|
||||
Args: startMsg.Args,
|
||||
}))
|
||||
if err != nil {
|
||||
sendWSError(conn, "failed to start exec stream: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
// Listen for stop messages from the client in a goroutine.
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
var parsed struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if json.Unmarshal(msg, &parsed) == nil && parsed.Type == "stop" {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Forward stream events to WebSocket.
|
||||
for stream.Receive() {
|
||||
resp := stream.Msg()
|
||||
switch ev := resp.Event.(type) {
|
||||
case *pb.ExecStreamResponse_Start:
|
||||
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
|
||||
|
||||
case *pb.ExecStreamResponse_Data:
|
||||
switch o := ev.Data.Output.(type) {
|
||||
case *pb.ExecStreamData_Stdout:
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
|
||||
case *pb.ExecStreamData_Stderr:
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
|
||||
}
|
||||
|
||||
case *pb.ExecStreamResponse_End:
|
||||
exitCode := ev.End.ExitCode
|
||||
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
// Only send if the connection is still alive (not a normal close).
|
||||
if streamCtx.Err() == nil {
|
||||
sendWSError(conn, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Update last active using a fresh context (the request context may be cancelled).
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendWSError(conn *websocket.Conn, msg string) {
|
||||
writeWSJSON(conn, wsOutMsg{Type: "error", Data: msg})
|
||||
}
|
||||
|
||||
func writeWSJSON(conn *websocket.Conn, v any) {
|
||||
if err := conn.WriteJSON(v); err != nil {
|
||||
slog.Debug("websocket write error", "error", err)
|
||||
}
|
||||
}
|
||||
194
internal/api/handlers_files_stream.go
Normal file
194
internal/api/handlers_files_stream.go
Normal file
@ -0,0 +1,194 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type filesStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
|
||||
return &filesStreamHandler{db: db, agent: agent}
|
||||
}
|
||||
|
||||
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sb, err := h.db.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse boundary from Content-Type without buffering the body.
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil || params["boundary"] == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "expected multipart/form-data with boundary")
|
||||
return
|
||||
}
|
||||
|
||||
// Read parts manually from the multipart stream.
|
||||
mr := multipart.NewReader(r.Body, params["boundary"])
|
||||
|
||||
var filePath string
|
||||
var filePart *multipart.Part
|
||||
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "failed to parse multipart")
|
||||
return
|
||||
}
|
||||
switch part.FormName() {
|
||||
case "path":
|
||||
data, _ := io.ReadAll(part)
|
||||
filePath = string(data)
|
||||
case "file":
|
||||
filePart = part
|
||||
}
|
||||
if filePath != "" && filePart != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if filePath == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path field is required")
|
||||
return
|
||||
}
|
||||
if filePart == nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "file field is required")
|
||||
return
|
||||
}
|
||||
defer filePart.Close()
|
||||
|
||||
// Open client-streaming RPC to host agent.
|
||||
stream := h.agent.WriteFileStream(ctx)
|
||||
|
||||
// Send metadata first.
|
||||
if err := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Meta{
|
||||
Meta: &pb.WriteFileStreamMeta{
|
||||
SandboxId: sandboxID,
|
||||
Path: filePath,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusBadGateway, "agent_error", "failed to send file metadata")
|
||||
return
|
||||
}
|
||||
|
||||
// Stream file content in 64KB chunks directly from the multipart part.
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
n, err := filePart.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
if sendErr := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Chunk{Chunk: chunk},
|
||||
}); sendErr != nil {
|
||||
writeError(w, http.StatusBadGateway, "agent_error", "failed to stream file chunk")
|
||||
return
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "read_error", "failed to read uploaded file")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Close and receive response.
|
||||
if _, err := stream.CloseAndReceive(); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
|
||||
// Accepts JSON body with path, streams file content back without buffering.
|
||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sb, err := h.db.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req readFileRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Open server-streaming RPC to host agent.
|
||||
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
|
||||
flusher, canFlush := w.(http.Flusher)
|
||||
for stream.Receive() {
|
||||
chunk := stream.Msg().Chunk
|
||||
if len(chunk) > 0 {
|
||||
if _, err := w.Write(chunk); err != nil {
|
||||
return
|
||||
}
|
||||
if canFlush {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
// Headers already sent, nothing we can do but log.
|
||||
// The client will see a truncated response.
|
||||
}
|
||||
}
|
||||
@ -1,8 +1,11 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@ -61,6 +64,10 @@ func requestLogger() func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func decodeJSON(r *http.Request, v any) error {
|
||||
return json.NewDecoder(r.Body).Decode(v)
|
||||
}
|
||||
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
@ -70,3 +77,18 @@ func (w *statusWriter) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker, required for WebSocket upgrade.
|
||||
func (w *statusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("underlying ResponseWriter does not implement http.Hijacker")
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher, required for streaming responses.
|
||||
func (w *statusWriter) Flush() {
|
||||
if fl, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
@ -239,6 +239,143 @@ paths:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/sandboxes/{id}/exec/stream:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
get:
|
||||
summary: Stream command execution via WebSocket
|
||||
operationId: execStream
|
||||
description: |
|
||||
Opens a WebSocket connection for streaming command execution.
|
||||
|
||||
**Client sends** (first message to start the process):
|
||||
```json
|
||||
{"type": "start", "cmd": "tail", "args": ["-f", "/var/log/syslog"]}
|
||||
```
|
||||
|
||||
**Client sends** (to stop the process):
|
||||
```json
|
||||
{"type": "stop"}
|
||||
```
|
||||
|
||||
**Server sends** (process events as they arrive):
|
||||
```json
|
||||
{"type": "start", "pid": 1234}
|
||||
{"type": "stdout", "data": "line of output\n"}
|
||||
{"type": "stderr", "data": "warning message\n"}
|
||||
{"type": "exit", "exit_code": 0}
|
||||
{"type": "error", "data": "description of error"}
|
||||
```
|
||||
|
||||
The connection closes automatically after the process exits.
|
||||
responses:
|
||||
"101":
|
||||
description: WebSocket upgrade
|
||||
"404":
|
||||
description: Sandbox not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"409":
|
||||
description: Sandbox not running
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/sandboxes/{id}/files/stream/write:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
post:
|
||||
summary: Upload a file (streaming)
|
||||
operationId: streamUploadFile
|
||||
description: |
|
||||
Streams file content to the sandbox without buffering in memory.
|
||||
Suitable for large files. Uses the same multipart/form-data format
|
||||
as the non-streaming upload endpoint.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
multipart/form-data:
|
||||
schema:
|
||||
type: object
|
||||
required: [path, file]
|
||||
properties:
|
||||
path:
|
||||
type: string
|
||||
description: Absolute destination path inside the sandbox
|
||||
file:
|
||||
type: string
|
||||
format: binary
|
||||
description: File content
|
||||
responses:
|
||||
"204":
|
||||
description: File uploaded
|
||||
"404":
|
||||
description: Sandbox not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"409":
|
||||
description: Sandbox not running
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/sandboxes/{id}/files/stream/read:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
post:
|
||||
summary: Download a file (streaming)
|
||||
operationId: streamDownloadFile
|
||||
description: |
|
||||
Streams file content from the sandbox without buffering in memory.
|
||||
Suitable for large files. Returns raw bytes with chunked transfer encoding.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ReadFileRequest"
|
||||
responses:
|
||||
"200":
|
||||
description: File content streamed in chunks
|
||||
content:
|
||||
application/octet-stream:
|
||||
schema:
|
||||
type: string
|
||||
format: binary
|
||||
"404":
|
||||
description: Sandbox or file not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"409":
|
||||
description: Sandbox not running
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
components:
|
||||
schemas:
|
||||
CreateSandboxRequest:
|
||||
|
||||
@ -26,7 +26,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *
|
||||
|
||||
sandbox := newSandboxHandler(queries, agent)
|
||||
exec := newExecHandler(queries, agent)
|
||||
execStream := newExecStreamHandler(queries, agent)
|
||||
files := newFilesHandler(queries, agent)
|
||||
filesStream := newFilesStreamHandler(queries, agent)
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
@ -41,10 +43,13 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *
|
||||
r.Get("/", sandbox.Get)
|
||||
r.Delete("/", sandbox.Destroy)
|
||||
r.Post("/exec", exec.Exec)
|
||||
r.Get("/exec/stream", execStream.ExecStream)
|
||||
r.Post("/pause", sandbox.Pause)
|
||||
r.Post("/resume", sandbox.Resume)
|
||||
r.Post("/files/write", files.Upload)
|
||||
r.Post("/files/read", files.Download)
|
||||
r.Post("/files/stream/write", filesStream.StreamUpload)
|
||||
r.Post("/files/stream/read", filesStream.StreamDownload)
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -42,6 +42,11 @@ func New(hostIP string) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
// BaseURL returns the HTTP base URL for reaching envd.
|
||||
func (c *Client) BaseURL() string {
|
||||
return c.base
|
||||
}
|
||||
|
||||
// ExecResult holds the output of a command execution.
|
||||
type ExecResult struct {
|
||||
Stdout []byte
|
||||
@ -110,6 +115,83 @@ func (c *Client) Exec(ctx context.Context, cmd string, args ...string) (*ExecRes
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecStreamEvent represents a single event from a streaming exec.
|
||||
type ExecStreamEvent struct {
|
||||
Type string // "start", "stdout", "stderr", "end"
|
||||
PID uint32
|
||||
Data []byte
|
||||
ExitCode int32
|
||||
Error string
|
||||
}
|
||||
|
||||
// ExecStream runs a command inside the sandbox and returns a channel of output events.
|
||||
// The channel is closed when the process ends or the context is cancelled.
|
||||
func (c *Client) ExecStream(ctx context.Context, cmd string, args ...string) (<-chan ExecStreamEvent, error) {
|
||||
stdin := false
|
||||
req := connect.NewRequest(&envdpb.StartRequest{
|
||||
Process: &envdpb.ProcessConfig{
|
||||
Cmd: cmd,
|
||||
Args: args,
|
||||
},
|
||||
Stdin: &stdin,
|
||||
})
|
||||
|
||||
stream, err := c.process.Start(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start process: %w", err)
|
||||
}
|
||||
|
||||
ch := make(chan ExecStreamEvent, 16)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer stream.Close()
|
||||
|
||||
for stream.Receive() {
|
||||
msg := stream.Msg()
|
||||
if msg.Event == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var ev ExecStreamEvent
|
||||
event := msg.Event.GetEvent()
|
||||
switch e := event.(type) {
|
||||
case *envdpb.ProcessEvent_Start:
|
||||
ev = ExecStreamEvent{Type: "start", PID: e.Start.GetPid()}
|
||||
|
||||
case *envdpb.ProcessEvent_Data:
|
||||
output := e.Data.GetOutput()
|
||||
switch o := output.(type) {
|
||||
case *envdpb.ProcessEvent_DataEvent_Stdout:
|
||||
ev = ExecStreamEvent{Type: "stdout", Data: o.Stdout}
|
||||
case *envdpb.ProcessEvent_DataEvent_Stderr:
|
||||
ev = ExecStreamEvent{Type: "stderr", Data: o.Stderr}
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_End:
|
||||
ev = ExecStreamEvent{Type: "end", ExitCode: e.End.GetExitCode()}
|
||||
if e.End.Error != nil {
|
||||
ev.Error = e.End.GetError()
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_Keepalive:
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case ch <- ev:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil && err != io.EOF {
|
||||
slog.Debug("exec stream error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// WriteFile writes content to a file inside the sandbox via envd's REST endpoint.
|
||||
// envd expects POST /files?path=...&username=root with multipart/form-data (field name "file").
|
||||
func (c *Client) WriteFile(ctx context.Context, path string, content []byte) error {
|
||||
|
||||
@ -3,6 +3,11 @@ package hostagent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
@ -135,6 +140,209 @@ func (s *Server) ReadFile(
|
||||
return connect.NewResponse(&pb.ReadFileResponse{Content: content}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ExecStream(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ExecStreamRequest],
|
||||
stream *connect.ServerStream[pb.ExecStreamResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
// Only apply a timeout if explicitly requested; streaming execs may be long-running.
|
||||
execCtx := ctx
|
||||
if msg.TimeoutSec > 0 {
|
||||
var cancel context.CancelFunc
|
||||
execCtx, cancel = context.WithTimeout(ctx, time.Duration(msg.TimeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
events, err := s.mgr.ExecStream(execCtx, msg.SandboxId, msg.Cmd, msg.Args...)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("exec stream: %w", err))
|
||||
}
|
||||
|
||||
for ev := range events {
|
||||
var resp pb.ExecStreamResponse
|
||||
switch ev.Type {
|
||||
case "start":
|
||||
resp.Event = &pb.ExecStreamResponse_Start{
|
||||
Start: &pb.ExecStreamStart{Pid: ev.PID},
|
||||
}
|
||||
case "stdout":
|
||||
resp.Event = &pb.ExecStreamResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stdout{Stdout: ev.Data},
|
||||
},
|
||||
}
|
||||
case "stderr":
|
||||
resp.Event = &pb.ExecStreamResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stderr{Stderr: ev.Data},
|
||||
},
|
||||
}
|
||||
case "end":
|
||||
resp.Event = &pb.ExecStreamResponse_End{
|
||||
End: &pb.ExecStreamEnd{
|
||||
ExitCode: ev.ExitCode,
|
||||
Error: ev.Error,
|
||||
},
|
||||
}
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) WriteFileStream(
|
||||
ctx context.Context,
|
||||
stream *connect.ClientStream[pb.WriteFileStreamRequest],
|
||||
) (*connect.Response[pb.WriteFileStreamResponse], error) {
|
||||
// First message must contain metadata.
|
||||
if !stream.Receive() {
|
||||
if err := stream.Err(); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("empty stream"))
|
||||
}
|
||||
|
||||
first := stream.Msg()
|
||||
meta := first.GetMeta()
|
||||
if meta == nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("first message must contain metadata"))
|
||||
}
|
||||
|
||||
client, err := s.mgr.GetClient(meta.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
// Use io.Pipe to stream chunks into a multipart body for envd's REST endpoint.
|
||||
pr, pw := io.Pipe()
|
||||
mpWriter := multipart.NewWriter(pw)
|
||||
|
||||
// Write multipart data in a goroutine.
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
part, err := mpWriter.CreateFormFile("file", "upload")
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("create multipart: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
for stream.Receive() {
|
||||
chunk := stream.Msg().GetChunk()
|
||||
if len(chunk) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := part.Write(chunk); err != nil {
|
||||
errCh <- fmt.Errorf("write chunk: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stream.Err(); err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
mpWriter.Close()
|
||||
errCh <- nil
|
||||
}()
|
||||
|
||||
// Send the streaming multipart body to envd.
|
||||
base := client.BaseURL()
|
||||
u := fmt.Sprintf("%s/files?%s", base, url.Values{
|
||||
"path": {meta.Path},
|
||||
"username": {"root"},
|
||||
}.Encode())
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, u, pr)
|
||||
if err != nil {
|
||||
pw.CloseWithError(err)
|
||||
<-errCh
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err))
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
pw.CloseWithError(err)
|
||||
<-errCh
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write file stream: %w", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Wait for the writer goroutine.
|
||||
if writerErr := <-errCh; writerErr != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, writerErr)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("envd write: status %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
slog.Debug("streaming file write complete", "sandbox_id", meta.SandboxId, "path", meta.Path)
|
||||
return connect.NewResponse(&pb.WriteFileStreamResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ReadFileStream(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ReadFileStreamRequest],
|
||||
stream *connect.ServerStream[pb.ReadFileStreamResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
base := client.BaseURL()
|
||||
u := fmt.Sprintf("%s/files?%s", base, url.Values{
|
||||
"path": {msg.Path},
|
||||
"username": {"root"},
|
||||
}.Encode())
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err))
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("read file stream: %w", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("envd read: status %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Stream file content in 64KB chunks.
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
if sendErr := stream.Send(&pb.ReadFileStreamResponse{Chunk: chunk}); sendErr != nil {
|
||||
return sendErr
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("read body: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) ListSandboxes(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ListSandboxesRequest],
|
||||
|
||||
@ -262,6 +262,24 @@ func (m *Manager) Exec(ctx context.Context, sandboxID string, cmd string, args .
|
||||
return sb.client.Exec(ctx, cmd, args...)
|
||||
}
|
||||
|
||||
// ExecStream runs a command inside a sandbox and returns a channel of streaming events.
|
||||
func (m *Manager) ExecStream(ctx context.Context, sandboxID string, cmd string, args ...string) (<-chan envdclient.ExecStreamEvent, error) {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sb.Status != models.StatusRunning {
|
||||
return nil, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
sb.LastActiveAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
return sb.client.ExecStream(ctx, cmd, args...)
|
||||
}
|
||||
|
||||
// List returns all sandboxes.
|
||||
func (m *Manager) List() []models.Sandbox {
|
||||
m.mu.RLock()
|
||||
|
||||
Reference in New Issue
Block a user