forked from wrenn/wrenn
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev> Reviewed-on: wrenn/wrenn#50
537 lines
15 KiB
Go
537 lines
15 KiB
Go
package envdclient
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"connectrpc.com/connect"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
|
"git.omukk.dev/wrenn/wrenn/proto/envd/gen/genconnect"
|
|
)
|
|
|
|
// Client wraps the Connect RPC client for envd's Process and Filesystem services.
|
|
type Client struct {
|
|
hostIP string
|
|
base string
|
|
healthURL string
|
|
httpClient *http.Client
|
|
streamingClient *http.Client
|
|
|
|
process genconnect.ProcessClient
|
|
filesystem genconnect.FilesystemClient
|
|
}
|
|
|
|
// New creates a new envd client that connects to the given host IP.
|
|
func New(hostIP string) *Client {
|
|
base := baseURL(hostIP)
|
|
httpClient := newHTTPClient()
|
|
streamingClient := newStreamingHTTPClient()
|
|
|
|
return &Client{
|
|
hostIP: hostIP,
|
|
base: base,
|
|
healthURL: base + "/health",
|
|
httpClient: httpClient,
|
|
streamingClient: streamingClient,
|
|
process: genconnect.NewProcessClient(streamingClient, base),
|
|
filesystem: genconnect.NewFilesystemClient(httpClient, base),
|
|
}
|
|
}
|
|
|
|
// CloseIdleConnections closes idle connections on both the unary and streaming
|
|
// transports. Call this before taking a VM snapshot to remove stale TCP state
|
|
// from the guest.
|
|
func (c *Client) CloseIdleConnections() {
|
|
c.httpClient.CloseIdleConnections()
|
|
c.streamingClient.CloseIdleConnections()
|
|
}
|
|
|
|
// BaseURL returns the HTTP base URL for reaching envd.
|
|
func (c *Client) BaseURL() string {
|
|
return c.base
|
|
}
|
|
|
|
// HTTPClient returns the http.Client with a 2-minute request timeout.
|
|
// Suitable for short-lived envd calls (health, init, snapshot/prepare).
|
|
func (c *Client) HTTPClient() *http.Client {
|
|
return c.httpClient
|
|
}
|
|
|
|
// StreamingHTTPClient returns the http.Client without a request timeout.
|
|
// Use for streaming file transfers or any request that may run indefinitely.
|
|
func (c *Client) StreamingHTTPClient() *http.Client {
|
|
return c.streamingClient
|
|
}
|
|
|
|
// ExecResult is what Exec hands back once a command has run: everything the
// process wrote to its output streams plus its exit status.
type ExecResult struct {
	// Stdout and Stderr accumulate the respective stream's bytes in arrival order.
	Stdout, Stderr []byte

	// ExitCode is the exit code carried by the process end event.
	ExitCode int32
}
|
|
|
|
// ExecOpts carries the optional settings accepted by Exec. A nil *ExecOpts is
// valid and means "no overrides".
type ExecOpts struct {
	// Envs are extra environment variables for the process; ignored when empty.
	Envs map[string]string
	// Cwd is the working directory; the empty string leaves it unset.
	Cwd string
}
|
|
|
|
// Exec runs a command inside the sandbox and collects all stdout/stderr output.
|
|
// It blocks until the command completes.
|
|
func (c *Client) Exec(ctx context.Context, cmd string, args []string, opts *ExecOpts) (*ExecResult, error) {
|
|
stdin := false
|
|
proc := &envdpb.ProcessConfig{
|
|
Cmd: cmd,
|
|
Args: args,
|
|
}
|
|
if opts != nil {
|
|
if len(opts.Envs) > 0 {
|
|
proc.Envs = opts.Envs
|
|
}
|
|
if opts.Cwd != "" {
|
|
proc.Cwd = &opts.Cwd
|
|
}
|
|
}
|
|
req := connect.NewRequest(&envdpb.StartRequest{
|
|
Process: proc,
|
|
Stdin: &stdin,
|
|
})
|
|
|
|
stream, err := c.process.Start(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("start process: %w", err)
|
|
}
|
|
defer stream.Close()
|
|
|
|
result := &ExecResult{}
|
|
|
|
for stream.Receive() {
|
|
msg := stream.Msg()
|
|
if msg.Event == nil {
|
|
continue
|
|
}
|
|
|
|
event := msg.Event.GetEvent()
|
|
switch e := event.(type) {
|
|
case *envdpb.ProcessEvent_Start:
|
|
slog.Debug("process started", "pid", e.Start.GetPid())
|
|
|
|
case *envdpb.ProcessEvent_Data:
|
|
output := e.Data.GetOutput()
|
|
switch o := output.(type) {
|
|
case *envdpb.ProcessEvent_DataEvent_Stdout:
|
|
result.Stdout = append(result.Stdout, o.Stdout...)
|
|
case *envdpb.ProcessEvent_DataEvent_Stderr:
|
|
result.Stderr = append(result.Stderr, o.Stderr...)
|
|
}
|
|
|
|
case *envdpb.ProcessEvent_End:
|
|
result.ExitCode = e.End.GetExitCode()
|
|
if e.End.Error != nil {
|
|
slog.Debug("process ended with error",
|
|
"exit_code", e.End.GetExitCode(),
|
|
"error", e.End.GetError(),
|
|
)
|
|
}
|
|
|
|
case *envdpb.ProcessEvent_Keepalive:
|
|
// Ignore keepalives.
|
|
}
|
|
}
|
|
|
|
if err := stream.Err(); err != nil && err != io.EOF {
|
|
return result, fmt.Errorf("stream error: %w", err)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// ExecStreamEvent is one item delivered on the channel returned by ExecStream.
// Which of the payload fields is meaningful depends on Type.
type ExecStreamEvent struct {
	// Type is one of "start", "stdout", "stderr", "end".
	Type string
	// PID is filled in for "start" events.
	PID uint32
	// Data holds the output chunk for "stdout" and "stderr" events.
	Data []byte
	// ExitCode is filled in for "end" events.
	ExitCode int32
	// Error is filled in for "end" events when the process reported an error.
	Error string
}
|
|
|
|
// ExecStream runs a command inside the sandbox and returns a channel of output events.
|
|
// The channel is closed when the process ends or the context is cancelled.
|
|
func (c *Client) ExecStream(ctx context.Context, cmd string, args ...string) (<-chan ExecStreamEvent, error) {
|
|
stdin := false
|
|
req := connect.NewRequest(&envdpb.StartRequest{
|
|
Process: &envdpb.ProcessConfig{
|
|
Cmd: cmd,
|
|
Args: args,
|
|
},
|
|
Stdin: &stdin,
|
|
})
|
|
|
|
stream, err := c.process.Start(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("start process: %w", err)
|
|
}
|
|
|
|
ch := make(chan ExecStreamEvent, 256)
|
|
go func() {
|
|
defer close(ch)
|
|
defer stream.Close()
|
|
|
|
for stream.Receive() {
|
|
msg := stream.Msg()
|
|
if msg.Event == nil {
|
|
continue
|
|
}
|
|
|
|
var ev ExecStreamEvent
|
|
event := msg.Event.GetEvent()
|
|
switch e := event.(type) {
|
|
case *envdpb.ProcessEvent_Start:
|
|
ev = ExecStreamEvent{Type: "start", PID: e.Start.GetPid()}
|
|
|
|
case *envdpb.ProcessEvent_Data:
|
|
output := e.Data.GetOutput()
|
|
switch o := output.(type) {
|
|
case *envdpb.ProcessEvent_DataEvent_Stdout:
|
|
ev = ExecStreamEvent{Type: "stdout", Data: o.Stdout}
|
|
case *envdpb.ProcessEvent_DataEvent_Stderr:
|
|
ev = ExecStreamEvent{Type: "stderr", Data: o.Stderr}
|
|
}
|
|
|
|
case *envdpb.ProcessEvent_End:
|
|
ev = ExecStreamEvent{Type: "end", ExitCode: e.End.GetExitCode()}
|
|
if e.End.Error != nil {
|
|
ev.Error = e.End.GetError()
|
|
}
|
|
|
|
case *envdpb.ProcessEvent_Keepalive:
|
|
continue
|
|
}
|
|
|
|
select {
|
|
case ch <- ev:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
|
|
if err := stream.Err(); err != nil && err != io.EOF {
|
|
slog.Debug("exec stream error", "error", err)
|
|
}
|
|
}()
|
|
|
|
return ch, nil
|
|
}
|
|
|
|
// WriteFile writes content to a file inside the sandbox via envd's REST endpoint.
|
|
// envd expects POST /files?path=...&username=root with multipart/form-data (field name "file").
|
|
func (c *Client) WriteFile(ctx context.Context, path string, content []byte) error {
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
|
|
part, err := writer.CreateFormFile("file", "upload")
|
|
if err != nil {
|
|
return fmt.Errorf("create multipart: %w", err)
|
|
}
|
|
if _, err := part.Write(content); err != nil {
|
|
return fmt.Errorf("write multipart: %w", err)
|
|
}
|
|
writer.Close()
|
|
|
|
u := fmt.Sprintf("%s/files?%s", c.base, url.Values{
|
|
"path": {path},
|
|
"username": {"root"},
|
|
}.Encode())
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, &body)
|
|
if err != nil {
|
|
return fmt.Errorf("create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("write file: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("write file %s: status %d: %s", path, resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
slog.Debug("envd write file", "path", path, "status", resp.StatusCode, "response", string(respBody))
|
|
return nil
|
|
}
|
|
|
|
// ReadFile reads a file from inside the sandbox via envd's REST endpoint.
|
|
// envd expects GET /files?path=...&username=root.
|
|
func (c *Client) ReadFile(ctx context.Context, path string) ([]byte, error) {
|
|
u := fmt.Sprintf("%s/files?%s", c.base, url.Values{
|
|
"path": {path},
|
|
"username": {"root"},
|
|
}.Encode())
|
|
|
|
slog.Debug("envd read file", "url", u, "path", path)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read file: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("read file %s: status %d: %s", path, resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read file body: %w", err)
|
|
}
|
|
|
|
return data, nil
|
|
}
|
|
|
|
// PrepareSnapshot calls envd's POST /snapshot/prepare endpoint, which stops
|
|
// the port scanner/forwarder and marks active connections for post-restore
|
|
// cleanup before the VMM freezes vCPUs.
|
|
//
|
|
// Best-effort: the caller should log a warning on error but not abort the pause.
|
|
func (c *Client) PrepareSnapshot(ctx context.Context) error {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/snapshot/prepare", nil)
|
|
if err != nil {
|
|
return fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("prepare snapshot: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNoContent {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("prepare snapshot: status %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MemoryPreloadStatus is the JSON document envd returns from /memory/preload.
type MemoryPreloadStatus struct {
	// State is one of "idle", "running", "done", "failed", "cancelled".
	State string `json:"state"`

	// Regions, Pages and Bytes are counters reported by envd.
	// NOTE(review): presumably loader progress so far; confirm against envd.
	Regions uint64 `json:"regions"`
	Pages   uint64 `json:"pages"`
	Bytes   uint64 `json:"bytes"`

	// ElapsedSec is a duration in seconds as reported by envd.
	ElapsedSec float64 `json:"elapsed_sec"`
	// Source is passed through from envd unchanged.
	// NOTE(review): meaning is not visible from this file.
	Source string `json:"source"`
	// Error is omitted from the JSON encoding when empty.
	Error string `json:"error,omitempty"`
}
|
|
|
|
// StartMemoryPreload posts to envd's /memory/preload to spawn a guest-side
|
|
// loader that reads every physical RAM page. The request returns immediately
|
|
// after the loader is queued — the actual materialisation runs in a detached
|
|
// thread inside envd. Required after a snapshot restore with
|
|
// memory_restore_mode=ondemand so the next ch.snapshot writes a
|
|
// self-contained memory-ranges file.
|
|
//
|
|
// Use WaitMemoryPreload to block on completion or GetMemoryPreloadStatus to
|
|
// query progress.
|
|
func (c *Client) StartMemoryPreload(ctx context.Context) (MemoryPreloadStatus, error) {
|
|
return c.memoryPreloadRequest(ctx, http.MethodPost)
|
|
}
|
|
|
|
// GetMemoryPreloadStatus reads envd's /memory/preload status without
|
|
// starting a new loader.
|
|
func (c *Client) GetMemoryPreloadStatus(ctx context.Context) (MemoryPreloadStatus, error) {
|
|
return c.memoryPreloadRequest(ctx, http.MethodGet)
|
|
}
|
|
|
|
func (c *Client) memoryPreloadRequest(ctx context.Context, method string) (MemoryPreloadStatus, error) {
|
|
var status MemoryPreloadStatus
|
|
req, err := http.NewRequestWithContext(ctx, method, c.base+"/memory/preload", nil)
|
|
if err != nil {
|
|
return status, fmt.Errorf("create request: %w", err)
|
|
}
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return status, fmt.Errorf("memory preload %s: %w", method, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return status, fmt.Errorf("memory preload %s: status %d: %s", method, resp.StatusCode, string(body))
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
|
return status, fmt.Errorf("memory preload %s: decode: %w", method, err)
|
|
}
|
|
return status, nil
|
|
}
|
|
|
|
// WaitMemoryPreload polls envd until the loader is no longer running or ctx
|
|
// is cancelled. Returns the final status. Polling interval is fixed at 1s —
|
|
// the loader runs for many seconds to minutes, so finer polling wastes RPCs.
|
|
func (c *Client) WaitMemoryPreload(ctx context.Context) (MemoryPreloadStatus, error) {
|
|
ticker := time.NewTicker(1 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
status, err := c.GetMemoryPreloadStatus(ctx)
|
|
if err != nil {
|
|
return status, err
|
|
}
|
|
if status.State != "running" {
|
|
return status, nil
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return status, ctx.Err()
|
|
case <-ticker.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
// CancelMemoryPreload signals the in-guest memory preloader to stop early.
|
|
// Used during teardown so a pause/destroy doesn't have to wait for a
|
|
// multi-hundred-MiB read to finish.
|
|
func (c *Client) CancelMemoryPreload(ctx context.Context) error {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/memory/preload/cancel", nil)
|
|
if err != nil {
|
|
return fmt.Errorf("create request: %w", err)
|
|
}
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("preload cancel: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusNoContent {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("preload cancel: status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// PostInit calls envd's POST /init endpoint to trigger post-boot or
|
|
// post-restore initialization. sandbox_id and template_id are passed
|
|
// so envd can set WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID env vars.
|
|
func (c *Client) PostInit(ctx context.Context) error {
|
|
return c.PostInitWithDefaults(ctx, "", nil, "", "")
|
|
}
|
|
|
|
// PostInitWithDefaults calls envd's POST /init endpoint with optional default
|
|
// user, environment variables, and sandbox metadata. These are applied to
|
|
// envd's defaults so all subsequent process executions use them.
|
|
//
|
|
// timestamp and lifecycle_id are always populated: envd uses them to snap
|
|
// the guest clock to the host's wall time and to detect post-resume calls
|
|
// (which trigger port-forwarder restart + NFS remount).
|
|
func (c *Client) PostInitWithDefaults(ctx context.Context, defaultUser string, envVars map[string]string, sandboxID, templateID string) error {
|
|
payload := map[string]any{
|
|
"timestamp": time.Now().UTC().Format(time.RFC3339Nano),
|
|
"lifecycle_id": uuid.NewString(),
|
|
}
|
|
if defaultUser != "" {
|
|
payload["defaultUser"] = defaultUser
|
|
}
|
|
if len(envVars) > 0 {
|
|
payload["envVars"] = envVars
|
|
}
|
|
if sandboxID != "" {
|
|
payload["sandbox_id"] = sandboxID
|
|
}
|
|
if templateID != "" {
|
|
payload["template_id"] = templateID
|
|
}
|
|
|
|
var body io.Reader
|
|
if len(payload) > 0 {
|
|
data, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal init body: %w", err)
|
|
}
|
|
body = bytes.NewReader(data)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", body)
|
|
if err != nil {
|
|
return fmt.Errorf("create request: %w", err)
|
|
}
|
|
if body != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("post init: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNoContent {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("post init: status %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListDir lists directory contents inside the sandbox.
|
|
func (c *Client) ListDir(ctx context.Context, path string, depth uint32) (*envdpb.ListDirResponse, error) {
|
|
req := connect.NewRequest(&envdpb.ListDirRequest{
|
|
Path: path,
|
|
Depth: depth,
|
|
})
|
|
|
|
resp, err := c.filesystem.ListDir(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list dir: %w", err)
|
|
}
|
|
|
|
return resp.Msg, nil
|
|
}
|
|
|
|
// MakeDir creates a directory inside the sandbox.
|
|
func (c *Client) MakeDir(ctx context.Context, path string) (*envdpb.MakeDirResponse, error) {
|
|
req := connect.NewRequest(&envdpb.MakeDirRequest{
|
|
Path: path,
|
|
})
|
|
|
|
resp, err := c.filesystem.MakeDir(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("make dir: %w", err)
|
|
}
|
|
|
|
return resp.Msg, nil
|
|
}
|
|
|
|
// Remove removes a file or directory inside the sandbox.
|
|
func (c *Client) Remove(ctx context.Context, path string) error {
|
|
req := connect.NewRequest(&envdpb.RemoveRequest{
|
|
Path: path,
|
|
})
|
|
|
|
if _, err := c.filesystem.Remove(ctx, req); err != nil {
|
|
return fmt.Errorf("remove: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|