forked from wrenn/wrenn
Add interactive PTY terminal sessions for sandboxes
Wire envd's existing PTY process capabilities through the full stack:
hostagent proto (4 new RPCs: PtyAttach, PtySendInput, PtyResize, PtyKill),
envdclient, sandbox manager, and a new WebSocket endpoint at
GET /v1/sandboxes/{id}/pty with bidirectional JSON message protocol.
Sessions use tag-based identity for disconnect/reconnect support,
base64-encoded PTY data for binary safety, and a 120s inactivity timeout.
This commit is contained in:
405
internal/api/handlers_pty.go
Normal file
405
internal/api/handlers_pty.go
Normal file
@ -0,0 +1,405 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/db"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/id"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
const (
|
||||
ptyInactivityTimeout = 120 * time.Second
|
||||
ptyKeepaliveInterval = 30 * time.Second
|
||||
ptyDefaultCmd = "/bin/bash"
|
||||
ptyDefaultCols = 80
|
||||
ptyDefaultRows = 24
|
||||
)
|
||||
|
||||
type ptyHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool) *ptyHandler {
|
||||
return &ptyHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// --- WebSocket message types ---
|
||||
|
||||
// wsPtyIn is the inbound message from the client.
|
||||
type wsPtyIn struct {
|
||||
Type string `json:"type"` // "start", "connect", "input", "resize", "kill"
|
||||
Cmd string `json:"cmd,omitempty"` // for "start"
|
||||
Args []string `json:"args,omitempty"` // for "start"
|
||||
Cols uint32 `json:"cols,omitempty"` // for "start", "resize"
|
||||
Rows uint32 `json:"rows,omitempty"` // for "start", "resize"
|
||||
Envs map[string]string `json:"envs,omitempty"` // for "start"
|
||||
Cwd string `json:"cwd,omitempty"` // for "start"
|
||||
User string `json:"user,omitempty"` // for "start"
|
||||
Tag string `json:"tag,omitempty"` // for "connect"
|
||||
Data string `json:"data,omitempty"` // for "input" (base64)
|
||||
}
|
||||
|
||||
// wsPtyOut is the outbound message to the client.
|
||||
type wsPtyOut struct {
|
||||
Type string `json:"type"` // "started", "output", "exit", "error"
|
||||
Tag string `json:"tag,omitempty"` // for "started"
|
||||
PID uint32 `json:"pid,omitempty"` // for "started"
|
||||
Data string `json:"data,omitempty"` // for "output" (base64), "error"
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // for "exit"
|
||||
Fatal bool `json:"fatal,omitempty"` // for "error"
|
||||
}
|
||||
|
||||
// wsWriter wraps a websocket.Conn with a mutex for concurrent writes.
|
||||
type wsWriter struct {
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (w *wsWriter) writeJSON(v any) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if err := w.conn.WriteJSON(v); err != nil {
|
||||
slog.Debug("pty websocket write error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// PtySession handles WS /v1/sandboxes/{id}/pty.
|
||||
func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ws := &wsWriter{conn: conn}
|
||||
|
||||
// Read the first message to determine start vs connect.
|
||||
var firstMsg wsPtyIn
|
||||
if err := conn.ReadJSON(&firstMsg); err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to read first message: " + err.Error(), Fatal: true})
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox host is not reachable", Fatal: true})
|
||||
return
|
||||
}
|
||||
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
switch firstMsg.Type {
|
||||
case "start":
|
||||
h.handleStart(streamCtx, cancel, ws, agent, sandboxIDStr, firstMsg)
|
||||
case "connect":
|
||||
h.handleConnect(streamCtx, cancel, ws, agent, sandboxIDStr, firstMsg)
|
||||
default:
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
|
||||
}
|
||||
|
||||
// Update last active using a fresh context.
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ptyHandler) handleStart(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
ws *wsWriter,
|
||||
agent hostagentv1connect.HostAgentServiceClient,
|
||||
sandboxIDStr string,
|
||||
msg wsPtyIn,
|
||||
) {
|
||||
cmd := msg.Cmd
|
||||
if cmd == "" {
|
||||
cmd = ptyDefaultCmd
|
||||
}
|
||||
cols := msg.Cols
|
||||
if cols == 0 {
|
||||
cols = ptyDefaultCols
|
||||
}
|
||||
rows := msg.Rows
|
||||
if rows == 0 {
|
||||
rows = ptyDefaultRows
|
||||
}
|
||||
|
||||
tag := newPtyTag()
|
||||
|
||||
stream, err := agent.PtyAttach(ctx, connect.NewRequest(&pb.PtyAttachRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Tag: tag,
|
||||
Cmd: cmd,
|
||||
Args: msg.Args,
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
Envs: msg.Envs,
|
||||
Cwd: msg.Cwd,
|
||||
User: msg.User,
|
||||
}))
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to start pty: " + err.Error(), Fatal: true})
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
// Wait for the started event and forward it.
|
||||
if !stream.Receive() {
|
||||
if err := stream.Err(); err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "pty stream failed: " + err.Error(), Fatal: true})
|
||||
}
|
||||
return
|
||||
}
|
||||
resp := stream.Msg()
|
||||
started, ok := resp.Event.(*pb.PtyAttachResponse_Started)
|
||||
if !ok {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "expected started event from host agent", Fatal: true})
|
||||
return
|
||||
}
|
||||
ws.writeJSON(wsPtyOut{Type: "started", Tag: started.Started.Tag, PID: started.Started.Pid})
|
||||
|
||||
runPtyLoop(ctx, cancel, ws, stream, agent, sandboxIDStr, tag)
|
||||
}
|
||||
|
||||
func (h *ptyHandler) handleConnect(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
ws *wsWriter,
|
||||
agent hostagentv1connect.HostAgentServiceClient,
|
||||
sandboxIDStr string,
|
||||
msg wsPtyIn,
|
||||
) {
|
||||
if msg.Tag == "" {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "connect requires a 'tag' field", Fatal: true})
|
||||
return
|
||||
}
|
||||
|
||||
stream, err := agent.PtyAttach(ctx, connect.NewRequest(&pb.PtyAttachRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Tag: msg.Tag,
|
||||
}))
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to connect to pty: " + err.Error(), Fatal: true})
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
runPtyLoop(ctx, cancel, ws, stream, agent, sandboxIDStr, msg.Tag)
|
||||
}
|
||||
|
||||
// runPtyLoop drives the bidirectional communication between the WebSocket
|
||||
// and the host agent PTY stream.
|
||||
func runPtyLoop(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
ws *wsWriter,
|
||||
stream *connect.ServerStreamForClient[pb.PtyAttachResponse],
|
||||
agent hostagentv1connect.HostAgentServiceClient,
|
||||
sandboxID string,
|
||||
tag string,
|
||||
) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Inactivity timer — reset on input/resize, fires kill after timeout.
|
||||
timer := time.NewTimer(ptyInactivityTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
// Output pump: read from Connect stream, write to WebSocket.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
|
||||
for stream.Receive() {
|
||||
resp := stream.Msg()
|
||||
switch ev := resp.Event.(type) {
|
||||
case *pb.PtyAttachResponse_Started:
|
||||
// Already handled before the loop for "start" mode.
|
||||
// For "connect" mode this won't appear.
|
||||
ws.writeJSON(wsPtyOut{Type: "started", Tag: ev.Started.Tag, PID: ev.Started.Pid})
|
||||
|
||||
case *pb.PtyAttachResponse_Output:
|
||||
ws.writeJSON(wsPtyOut{
|
||||
Type: "output",
|
||||
Data: base64.StdEncoding.EncodeToString(ev.Output.Data),
|
||||
})
|
||||
|
||||
case *pb.PtyAttachResponse_Exited:
|
||||
exitCode := ev.Exited.ExitCode
|
||||
ws.writeJSON(wsPtyOut{Type: "exit", ExitCode: &exitCode})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil && ctx.Err() == nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: err.Error()})
|
||||
}
|
||||
}()
|
||||
|
||||
// Input pump: read from WebSocket, dispatch to host agent.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
_, raw, err := ws.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var msg wsPtyIn
|
||||
if json.Unmarshal(raw, &msg) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use a background context for unary RPCs so they complete
|
||||
// even if the stream context is being cancelled.
|
||||
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
switch msg.Type {
|
||||
case "input":
|
||||
data, err := base64.StdEncoding.DecodeString(msg.Data)
|
||||
if err != nil {
|
||||
rpcCancel()
|
||||
continue
|
||||
}
|
||||
if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{
|
||||
SandboxId: sandboxID,
|
||||
Tag: tag,
|
||||
Data: data,
|
||||
})); err != nil {
|
||||
slog.Debug("pty send input error", "error", err)
|
||||
}
|
||||
resetTimer(timer, ptyInactivityTimeout)
|
||||
|
||||
case "resize":
|
||||
cols := msg.Cols
|
||||
rows := msg.Rows
|
||||
if cols > 0 && rows > 0 {
|
||||
if _, err := agent.PtyResize(rpcCtx, connect.NewRequest(&pb.PtyResizeRequest{
|
||||
SandboxId: sandboxID,
|
||||
Tag: tag,
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
})); err != nil {
|
||||
slog.Debug("pty resize error", "error", err)
|
||||
}
|
||||
resetTimer(timer, ptyInactivityTimeout)
|
||||
}
|
||||
|
||||
case "kill":
|
||||
if _, err := agent.PtyKill(rpcCtx, connect.NewRequest(&pb.PtyKillRequest{
|
||||
SandboxId: sandboxID,
|
||||
Tag: tag,
|
||||
})); err != nil {
|
||||
slog.Debug("pty kill error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
rpcCancel()
|
||||
}
|
||||
}()
|
||||
|
||||
// Keepalive pump: send periodic pings to prevent idle WS closure.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ticker := time.NewTicker(ptyKeepaliveInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ws.writeJSON(wsPtyOut{Type: "ping"})
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Inactivity timeout goroutine.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-timer.C:
|
||||
slog.Info("pty session timed out", "sandbox_id", sandboxID, "tag", tag)
|
||||
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
if _, err := agent.PtyKill(rpcCtx, connect.NewRequest(&pb.PtyKillRequest{
|
||||
SandboxId: sandboxID,
|
||||
Tag: tag,
|
||||
})); err != nil {
|
||||
slog.Debug("pty timeout kill error", "error", err)
|
||||
}
|
||||
rpcCancel()
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// newPtyTag returns a PTY session tag: "pty-" + 8 random hex chars.
|
||||
func newPtyTag() string {
|
||||
return "pty-" + id.NewPtyTag()
|
||||
}
|
||||
|
||||
// resetTimer safely resets a timer by stopping it and draining the channel
|
||||
// before resetting, avoiding the race documented in time.Timer.Reset.
|
||||
func resetTimer(t *time.Timer, d time.Duration) {
|
||||
if !t.Stop() {
|
||||
select {
|
||||
case <-t.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
t.Reset(d)
|
||||
}
|
||||
@ -1206,6 +1206,84 @@ paths:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/sandboxes/{id}/pty:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
get:
|
||||
summary: Interactive PTY session via WebSocket
|
||||
operationId: ptySession
|
||||
tags: [sandboxes]
|
||||
security:
|
||||
- apiKeyAuth: []
|
||||
description: |
|
||||
Opens a WebSocket connection for an interactive PTY (terminal) session.
|
||||
Supports creating new sessions, sending input, resizing, killing, and
|
||||
reconnecting to existing sessions.
|
||||
|
||||
**Client sends** (first message — start a new PTY):
|
||||
```json
|
||||
{
|
||||
"type": "start",
|
||||
"cmd": "/bin/bash",
|
||||
"args": [],
|
||||
"cols": 80,
|
||||
"rows": 24,
|
||||
"envs": {"TERM": "xterm-256color"},
|
||||
"cwd": "/home/user",
|
||||
"user": "user"
|
||||
}
|
||||
```
|
||||
All fields except `type` are optional. Defaults: cmd="/bin/bash", cols=80, rows=24.
|
||||
|
||||
**Client sends** (first message — reconnect to existing PTY):
|
||||
```json
|
||||
{"type": "connect", "tag": "pty-abc123de"}
|
||||
```
|
||||
|
||||
**Client sends** (after session is established):
|
||||
```json
|
||||
{"type": "input", "data": "<base64-encoded bytes>"}
|
||||
{"type": "resize", "cols": 120, "rows": 40}
|
||||
{"type": "kill"}
|
||||
```
|
||||
|
||||
**Server sends**:
|
||||
```json
|
||||
{"type": "started", "tag": "pty-abc123de", "pid": 42}
|
||||
{"type": "output", "data": "<base64-encoded PTY bytes>"}
|
||||
{"type": "exit", "exit_code": 0}
|
||||
{"type": "error", "data": "description", "fatal": true}
|
||||
{"type": "ping"}
|
||||
```
|
||||
|
||||
PTY data (input and output) is base64-encoded because it contains raw
|
||||
terminal bytes (escape sequences, control codes) that are not valid UTF-8.
|
||||
|
||||
Sessions have a 120-second inactivity timeout (reset on input/resize).
|
||||
Sessions persist across WebSocket disconnections — the process keeps
|
||||
running in the sandbox. Use the `tag` from the "started" response to
|
||||
reconnect later.
|
||||
responses:
|
||||
"101":
|
||||
description: WebSocket upgrade
|
||||
"404":
|
||||
description: Sandbox not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"409":
|
||||
description: Sandbox not running
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/sandboxes/{id}/files/stream/write:
|
||||
parameters:
|
||||
- name: id
|
||||
|
||||
@ -73,6 +73,7 @@ func New(
|
||||
metricsH := newSandboxMetricsHandler(queries, pool)
|
||||
buildH := newBuildHandler(buildSvc, queries, pool)
|
||||
channelH := newChannelHandler(channelSvc, al)
|
||||
ptyH := newPtyHandler(queries, pool)
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
@ -138,6 +139,7 @@ func New(
|
||||
r.Post("/files/mkdir", fsH.MakeDir)
|
||||
r.Post("/files/remove", fsH.Remove)
|
||||
r.Get("/metrics", metricsH.GetMetrics)
|
||||
r.Get("/pty", ptyH.PtySession)
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
220
internal/envdclient/pty.go
Normal file
220
internal/envdclient/pty.go
Normal file
@ -0,0 +1,220 @@
|
||||
package envdclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
||||
)
|
||||
|
||||
// PtyEvent represents a single event from a PTY output stream.
|
||||
type PtyEvent struct {
|
||||
Type string // "started", "output", "end"
|
||||
PID uint32
|
||||
Data []byte
|
||||
ExitCode int32
|
||||
Error string
|
||||
}
|
||||
|
||||
// PtyStart starts a new PTY process in the guest and returns a channel of events.
|
||||
// The tag is the stable identifier used to reconnect via PtyConnect.
|
||||
// The channel is closed when the process ends or ctx is cancelled.
|
||||
// NOTE: The user parameter from PtyAttachRequest is not yet supported by envd's
|
||||
// ProcessConfig proto. When envd adds user support, thread it through here.
|
||||
func (c *Client) PtyStart(ctx context.Context, tag, cmd string, args []string, cols, rows uint32, envs map[string]string, cwd string) (<-chan PtyEvent, error) {
|
||||
stdin := true
|
||||
cfg := &envdpb.ProcessConfig{
|
||||
Cmd: cmd,
|
||||
Args: args,
|
||||
Envs: envs,
|
||||
}
|
||||
if cwd != "" {
|
||||
cfg.Cwd = &cwd
|
||||
}
|
||||
|
||||
req := connect.NewRequest(&envdpb.StartRequest{
|
||||
Process: cfg,
|
||||
Pty: &envdpb.PTY{
|
||||
Size: &envdpb.PTY_Size{
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
},
|
||||
},
|
||||
Tag: &tag,
|
||||
Stdin: &stdin,
|
||||
})
|
||||
|
||||
stream, err := c.process.Start(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pty start: %w", err)
|
||||
}
|
||||
|
||||
return drainPtyStream(ctx, &startStream{s: stream}, true), nil
|
||||
}
|
||||
|
||||
// PtyConnect re-attaches to an existing PTY process by tag.
|
||||
// Returns a channel of output events starting from the current point.
|
||||
func (c *Client) PtyConnect(ctx context.Context, tag string) (<-chan PtyEvent, error) {
|
||||
req := connect.NewRequest(&envdpb.ConnectRequest{
|
||||
Process: &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
},
|
||||
})
|
||||
|
||||
stream, err := c.process.Connect(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pty connect: %w", err)
|
||||
}
|
||||
|
||||
return drainPtyStream(ctx, &connectStream{s: stream}, false), nil
|
||||
}
|
||||
|
||||
// PtySendInput sends raw bytes to the PTY process identified by tag.
|
||||
func (c *Client) PtySendInput(ctx context.Context, tag string, data []byte) error {
|
||||
req := connect.NewRequest(&envdpb.SendInputRequest{
|
||||
Process: &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
},
|
||||
Input: &envdpb.ProcessInput{
|
||||
Input: &envdpb.ProcessInput_Pty{Pty: data},
|
||||
},
|
||||
})
|
||||
|
||||
if _, err := c.process.SendInput(ctx, req); err != nil {
|
||||
return fmt.Errorf("pty send input: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PtyResize updates the terminal dimensions for the PTY process identified by tag.
|
||||
func (c *Client) PtyResize(ctx context.Context, tag string, cols, rows uint32) error {
|
||||
req := connect.NewRequest(&envdpb.UpdateRequest{
|
||||
Process: &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
},
|
||||
Pty: &envdpb.PTY{
|
||||
Size: &envdpb.PTY_Size{
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if _, err := c.process.Update(ctx, req); err != nil {
|
||||
return fmt.Errorf("pty resize: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PtyKill sends SIGKILL to the PTY process identified by tag.
|
||||
func (c *Client) PtyKill(ctx context.Context, tag string) error {
|
||||
req := connect.NewRequest(&envdpb.SendSignalRequest{
|
||||
Process: &envdpb.ProcessSelector{
|
||||
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
|
||||
},
|
||||
Signal: envdpb.Signal_SIGNAL_SIGKILL,
|
||||
})
|
||||
|
||||
if _, err := c.process.SendSignal(ctx, req); err != nil {
|
||||
return fmt.Errorf("pty kill: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// eventStream is an interface covering both StartResponse and ConnectResponse streams.
|
||||
type eventStream interface {
|
||||
Receive() bool
|
||||
Err() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type startStream struct {
|
||||
s *connect.ServerStreamForClient[envdpb.StartResponse]
|
||||
}
|
||||
|
||||
func (s *startStream) Receive() bool { return s.s.Receive() }
|
||||
func (s *startStream) Err() error { return s.s.Err() }
|
||||
func (s *startStream) Close() error { return s.s.Close() }
|
||||
func (s *startStream) Event() *envdpb.ProcessEvent {
|
||||
return s.s.Msg().GetEvent()
|
||||
}
|
||||
|
||||
type connectStream struct {
|
||||
s *connect.ServerStreamForClient[envdpb.ConnectResponse]
|
||||
}
|
||||
|
||||
func (s *connectStream) Receive() bool { return s.s.Receive() }
|
||||
func (s *connectStream) Err() error { return s.s.Err() }
|
||||
func (s *connectStream) Close() error { return s.s.Close() }
|
||||
func (s *connectStream) Event() *envdpb.ProcessEvent {
|
||||
return s.s.Msg().GetEvent()
|
||||
}
|
||||
|
||||
type eventProvider interface {
|
||||
eventStream
|
||||
Event() *envdpb.ProcessEvent
|
||||
}
|
||||
|
||||
// drainPtyStream reads events from either a Start or Connect stream and maps
|
||||
// them into PtyEvent values on a channel.
|
||||
func drainPtyStream(ctx context.Context, stream eventProvider, expectStart bool) <-chan PtyEvent {
|
||||
ch := make(chan PtyEvent, 16)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer stream.Close()
|
||||
|
||||
for stream.Receive() {
|
||||
event := stream.Event()
|
||||
if event == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var ev PtyEvent
|
||||
switch e := event.GetEvent().(type) {
|
||||
case *envdpb.ProcessEvent_Start:
|
||||
if expectStart {
|
||||
ev = PtyEvent{Type: "started", PID: e.Start.GetPid()}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_Data:
|
||||
switch o := e.Data.GetOutput().(type) {
|
||||
case *envdpb.ProcessEvent_DataEvent_Pty:
|
||||
ev = PtyEvent{Type: "output", Data: o.Pty}
|
||||
case *envdpb.ProcessEvent_DataEvent_Stdout:
|
||||
ev = PtyEvent{Type: "output", Data: o.Stdout}
|
||||
case *envdpb.ProcessEvent_DataEvent_Stderr:
|
||||
ev = PtyEvent{Type: "output", Data: o.Stderr}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_End:
|
||||
ev = PtyEvent{Type: "end", ExitCode: e.End.GetExitCode()}
|
||||
if e.End.Error != nil {
|
||||
ev.Error = e.End.GetError()
|
||||
}
|
||||
|
||||
case *envdpb.ProcessEvent_Keepalive:
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case ch <- ev:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil && err != io.EOF {
|
||||
slog.Debug("pty stream error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
@ -610,6 +610,83 @@ func metricPointsToPB(pts []sandbox.MetricPoint) []*pb.MetricPoint {
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Server) PtyAttach(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtyAttachRequest],
|
||||
stream *connect.ServerStream[pb.PtyAttachResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
events, err := s.mgr.PtyAttach(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Cols, msg.Rows, msg.Envs, msg.Cwd)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("pty attach: %w", err))
|
||||
}
|
||||
|
||||
for ev := range events {
|
||||
var resp pb.PtyAttachResponse
|
||||
switch ev.Type {
|
||||
case "started":
|
||||
resp.Event = &pb.PtyAttachResponse_Started{
|
||||
Started: &pb.PtyStarted{Pid: ev.PID, Tag: msg.Tag},
|
||||
}
|
||||
case "output":
|
||||
resp.Event = &pb.PtyAttachResponse_Output{
|
||||
Output: &pb.PtyOutput{Data: ev.Data},
|
||||
}
|
||||
case "end":
|
||||
resp.Event = &pb.PtyAttachResponse_Exited{
|
||||
Exited: &pb.PtyExited{ExitCode: ev.ExitCode, Error: ev.Error},
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) PtySendInput(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtySendInputRequest],
|
||||
) (*connect.Response[pb.PtySendInputResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
if err := s.mgr.PtySendInput(ctx, msg.SandboxId, msg.Tag, msg.Data); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty send input: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.PtySendInputResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) PtyResize(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtyResizeRequest],
|
||||
) (*connect.Response[pb.PtyResizeResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
if err := s.mgr.PtyResize(ctx, msg.SandboxId, msg.Tag, msg.Cols, msg.Rows); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty resize: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.PtyResizeResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) PtyKill(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtyKillRequest],
|
||||
) (*connect.Response[pb.PtyKillResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
if err := s.mgr.PtyKill(ctx, msg.SandboxId, msg.Tag); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty kill: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.PtyKillResponse{}), nil
|
||||
}
|
||||
|
||||
// entryInfoToPB maps an envd EntryInfo to a hostagent FileEntry.
|
||||
func entryInfoToPB(e *envdpb.EntryInfo) *pb.FileEntry {
|
||||
if e == nil {
|
||||
|
||||
@ -167,6 +167,11 @@ func UUIDString(id pgtype.UUID) string {
|
||||
return uuid.UUID(id.Bytes).String()
|
||||
}
|
||||
|
||||
// NewPtyTag generates a PTY session tag: 8 random hex characters.
|
||||
func NewPtyTag() string {
|
||||
return hex8()
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func hex8() string {
|
||||
|
||||
@ -1223,6 +1223,70 @@ func (m *Manager) GetClient(sandboxID string) (*envdclient.Client, error) {
|
||||
return sb.client, nil
|
||||
}
|
||||
|
||||
// PtyAttach starts a new PTY process or reconnects to an existing one.
|
||||
// If cmd is non-empty, starts a new process. If empty, reconnects using tag.
|
||||
func (m *Manager) PtyAttach(ctx context.Context, sandboxID, tag, cmd string, args []string, cols, rows uint32, envs map[string]string, cwd string) (<-chan envdclient.PtyEvent, error) {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return nil, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
sb.LastActiveAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
if cmd != "" {
|
||||
return sb.client.PtyStart(ctx, tag, cmd, args, cols, rows, envs, cwd)
|
||||
}
|
||||
return sb.client.PtyConnect(ctx, tag)
|
||||
}
|
||||
|
||||
// PtySendInput sends raw bytes to a PTY process in a sandbox.
|
||||
func (m *Manager) PtySendInput(ctx context.Context, sandboxID, tag string, data []byte) error {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
sb.LastActiveAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
return sb.client.PtySendInput(ctx, tag, data)
|
||||
}
|
||||
|
||||
// PtyResize updates the terminal dimensions for a PTY process in a sandbox.
|
||||
func (m *Manager) PtyResize(ctx context.Context, sandboxID, tag string, cols, rows uint32) error {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
return sb.client.PtyResize(ctx, tag, cols, rows)
|
||||
}
|
||||
|
||||
// PtyKill sends SIGKILL to a PTY process in a sandbox.
|
||||
func (m *Manager) PtyKill(ctx context.Context, sandboxID, tag string) error {
|
||||
sb, err := m.get(sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sb.Status != models.StatusRunning {
|
||||
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
|
||||
}
|
||||
|
||||
return sb.client.PtyKill(ctx, tag)
|
||||
}
|
||||
|
||||
// AcquireProxyConn atomically looks up a sandbox by ID and registers an
|
||||
// in-flight proxy connection. Returns the sandbox's host-reachable IP, the
|
||||
// connection tracker, and true on success. The caller must call
|
||||
|
||||
Reference in New Issue
Block a user