Add streaming exec and file transfer endpoints
Add WebSocket-based streaming exec endpoint and streaming file upload/download endpoints to the control plane API. Includes new host agent RPC methods (ExecStream, StreamWriteFile, StreamReadFile), envd client streaming support, and OpenAPI spec updates.
This commit is contained in:
@ -3,6 +3,11 @@ package hostagent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
@ -135,6 +140,209 @@ func (s *Server) ReadFile(
|
||||
return connect.NewResponse(&pb.ReadFileResponse{Content: content}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ExecStream(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ExecStreamRequest],
|
||||
stream *connect.ServerStream[pb.ExecStreamResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
// Only apply a timeout if explicitly requested; streaming execs may be long-running.
|
||||
execCtx := ctx
|
||||
if msg.TimeoutSec > 0 {
|
||||
var cancel context.CancelFunc
|
||||
execCtx, cancel = context.WithTimeout(ctx, time.Duration(msg.TimeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
events, err := s.mgr.ExecStream(execCtx, msg.SandboxId, msg.Cmd, msg.Args...)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("exec stream: %w", err))
|
||||
}
|
||||
|
||||
for ev := range events {
|
||||
var resp pb.ExecStreamResponse
|
||||
switch ev.Type {
|
||||
case "start":
|
||||
resp.Event = &pb.ExecStreamResponse_Start{
|
||||
Start: &pb.ExecStreamStart{Pid: ev.PID},
|
||||
}
|
||||
case "stdout":
|
||||
resp.Event = &pb.ExecStreamResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stdout{Stdout: ev.Data},
|
||||
},
|
||||
}
|
||||
case "stderr":
|
||||
resp.Event = &pb.ExecStreamResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stderr{Stderr: ev.Data},
|
||||
},
|
||||
}
|
||||
case "end":
|
||||
resp.Event = &pb.ExecStreamResponse_End{
|
||||
End: &pb.ExecStreamEnd{
|
||||
ExitCode: ev.ExitCode,
|
||||
Error: ev.Error,
|
||||
},
|
||||
}
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) WriteFileStream(
|
||||
ctx context.Context,
|
||||
stream *connect.ClientStream[pb.WriteFileStreamRequest],
|
||||
) (*connect.Response[pb.WriteFileStreamResponse], error) {
|
||||
// First message must contain metadata.
|
||||
if !stream.Receive() {
|
||||
if err := stream.Err(); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("empty stream"))
|
||||
}
|
||||
|
||||
first := stream.Msg()
|
||||
meta := first.GetMeta()
|
||||
if meta == nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("first message must contain metadata"))
|
||||
}
|
||||
|
||||
client, err := s.mgr.GetClient(meta.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
// Use io.Pipe to stream chunks into a multipart body for envd's REST endpoint.
|
||||
pr, pw := io.Pipe()
|
||||
mpWriter := multipart.NewWriter(pw)
|
||||
|
||||
// Write multipart data in a goroutine.
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
part, err := mpWriter.CreateFormFile("file", "upload")
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("create multipart: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
for stream.Receive() {
|
||||
chunk := stream.Msg().GetChunk()
|
||||
if len(chunk) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := part.Write(chunk); err != nil {
|
||||
errCh <- fmt.Errorf("write chunk: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stream.Err(); err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
mpWriter.Close()
|
||||
errCh <- nil
|
||||
}()
|
||||
|
||||
// Send the streaming multipart body to envd.
|
||||
base := client.BaseURL()
|
||||
u := fmt.Sprintf("%s/files?%s", base, url.Values{
|
||||
"path": {meta.Path},
|
||||
"username": {"root"},
|
||||
}.Encode())
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, u, pr)
|
||||
if err != nil {
|
||||
pw.CloseWithError(err)
|
||||
<-errCh
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err))
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
pw.CloseWithError(err)
|
||||
<-errCh
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write file stream: %w", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Wait for the writer goroutine.
|
||||
if writerErr := <-errCh; writerErr != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, writerErr)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("envd write: status %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
slog.Debug("streaming file write complete", "sandbox_id", meta.SandboxId, "path", meta.Path)
|
||||
return connect.NewResponse(&pb.WriteFileStreamResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ReadFileStream(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ReadFileStreamRequest],
|
||||
stream *connect.ServerStream[pb.ReadFileStreamResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
base := client.BaseURL()
|
||||
u := fmt.Sprintf("%s/files?%s", base, url.Values{
|
||||
"path": {msg.Path},
|
||||
"username": {"root"},
|
||||
}.Encode())
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err))
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("read file stream: %w", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("envd read: status %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Stream file content in 64KB chunks.
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
if sendErr := stream.Send(&pb.ReadFileStreamResponse{Chunk: chunk}); sendErr != nil {
|
||||
return sendErr
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("read body: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) ListSandboxes(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ListSandboxesRequest],
|
||||
|
||||
Reference in New Issue
Block a user