Files
sandbox/internal/hostagent/server.go
pptx704 477d4f8cf6 Add auto-pause TTL and ping endpoint for sandbox inactivity management
Replace the existing auto-destroy TTL behavior with auto-pause: when a
sandbox exceeds its timeout_sec of inactivity, the TTL reaper now pauses
it (snapshot + teardown) instead of destroying it, preserving the ability
to resume later.

Key changes:
- TTL reaper calls Pause instead of Destroy, with fallback to Destroy if
  pause fails (e.g. Firecracker process already gone)
- New PingSandbox RPC resets the in-memory LastActiveAt timer
- New POST /v1/sandboxes/{id}/ping REST endpoint resets both agent memory
  and DB last_active_at
- ListSandboxes RPC now includes auto_paused_sandbox_ids so the reconciler
  can distinguish auto-paused sandboxes from crashed ones in a single call
- Reconciler polls every 5s (was 30s) and marks auto-paused as "paused"
  vs orphaned as "stopped"
- Resume RPC accepts timeout_sec from DB so TTL survives pause/resume cycles
- Reaper checks every 2s (was 10s) and uses a detached context to avoid
  incomplete pauses on app shutdown
- Default timeout_sec changed from 300 to 0 (no auto-pause unless requested)
2026-03-15 05:15:18 +06:00

415 lines
11 KiB
Go

package hostagent
import (
"context"
"fmt"
"io"
"log/slog"
"mime/multipart"
"net/http"
"net/url"
"strings"
"time"
"connectrpc.com/connect"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
"git.omukk.dev/wrenn/sandbox/internal/sandbox"
)
// Server is the Connect RPC handler for the HostAgentService. Every RPC it
// implements is a thin adapter over the sandbox manager it wraps.
type Server struct {
	// Embedded so that any RPC without an explicit method on Server falls
	// back to the generated "unimplemented" handler.
	hostagentv1connect.UnimplementedHostAgentServiceHandler

	// mgr is the sandbox manager that all handlers delegate to.
	mgr *sandbox.Manager
}
// NewServer returns a host agent RPC server whose handlers operate on mgr.
func NewServer(mgr *sandbox.Manager) *Server {
	srv := new(Server)
	srv.mgr = mgr
	return srv
}
// CreateSandbox handles the CreateSandbox RPC. It passes the requested sandbox
// ID, template, vCPU count, memory size and timeout to the sandbox manager and
// replies with the resulting sandbox's ID, status and host IP. A manager
// failure is reported as CodeInternal.
func (s *Server) CreateSandbox(
	ctx context.Context,
	req *connect.Request[pb.CreateSandboxRequest],
) (*connect.Response[pb.CreateSandboxResponse], error) {
	in := req.Msg

	created, err := s.mgr.Create(
		ctx,
		in.SandboxId,
		in.Template,
		int(in.Vcpus),
		int(in.MemoryMb),
		int(in.TimeoutSec),
	)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
	}

	out := &pb.CreateSandboxResponse{
		SandboxId: created.ID,
		Status:    string(created.Status),
		HostIp:    created.HostIP.String(),
	}
	return connect.NewResponse(out), nil
}
// DestroySandbox handles the DestroySandbox RPC by asking the sandbox manager
// to destroy the sandbox named in the request. Any error from the manager is
// reported to the caller as CodeNotFound.
func (s *Server) DestroySandbox(
	ctx context.Context,
	req *connect.Request[pb.DestroySandboxRequest],
) (*connect.Response[pb.DestroySandboxResponse], error) {
	err := s.mgr.Destroy(ctx, req.Msg.SandboxId)
	if err != nil {
		return nil, connect.NewError(connect.CodeNotFound, err)
	}

	return connect.NewResponse(new(pb.DestroySandboxResponse)), nil
}
// PauseSandbox handles the PauseSandbox RPC by asking the sandbox manager to
// pause the sandbox named in the request. Any error from the manager is
// reported as CodeInternal.
func (s *Server) PauseSandbox(
	ctx context.Context,
	req *connect.Request[pb.PauseSandboxRequest],
) (*connect.Response[pb.PauseSandboxResponse], error) {
	err := s.mgr.Pause(ctx, req.Msg.SandboxId)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, err)
	}

	return connect.NewResponse(new(pb.PauseSandboxResponse)), nil
}
// ResumeSandbox handles the ResumeSandbox RPC. The sandbox ID and the
// caller-supplied timeout (in seconds) are handed to the sandbox manager, and
// the resumed sandbox's ID, status and host IP are returned. Any error from
// the manager is reported as CodeInternal.
func (s *Server) ResumeSandbox(
	ctx context.Context,
	req *connect.Request[pb.ResumeSandboxRequest],
) (*connect.Response[pb.ResumeSandboxResponse], error) {
	in := req.Msg

	resumed, err := s.mgr.Resume(ctx, in.SandboxId, int(in.TimeoutSec))
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, err)
	}

	out := &pb.ResumeSandboxResponse{
		SandboxId: resumed.ID,
		Status:    string(resumed.Status),
		HostIp:    resumed.HostIP.String(),
	}
	return connect.NewResponse(out), nil
}
// CreateSnapshot handles the CreateSnapshot RPC. It asks the sandbox manager
// to snapshot the given sandbox under the requested name and replies with that
// name together with the size in bytes reported by the manager. A manager
// failure is reported as CodeInternal.
func (s *Server) CreateSnapshot(
	ctx context.Context,
	req *connect.Request[pb.CreateSnapshotRequest],
) (*connect.Response[pb.CreateSnapshotResponse], error) {
	in := req.Msg

	size, err := s.mgr.CreateSnapshot(ctx, in.SandboxId, in.Name)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err))
	}

	out := &pb.CreateSnapshotResponse{
		Name:      in.Name,
		SizeBytes: size,
	}
	return connect.NewResponse(out), nil
}
// DeleteSnapshot handles the DeleteSnapshot RPC by asking the sandbox manager
// to delete the snapshot with the requested name. A manager failure is
// reported as CodeInternal. The request context is not used: the manager's
// DeleteSnapshot takes only the name.
func (s *Server) DeleteSnapshot(
	ctx context.Context,
	req *connect.Request[pb.DeleteSnapshotRequest],
) (*connect.Response[pb.DeleteSnapshotResponse], error) {
	err := s.mgr.DeleteSnapshot(req.Msg.Name)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err))
	}

	return connect.NewResponse(new(pb.DeleteSnapshotResponse)), nil
}
// PingSandbox handles the PingSandbox RPC by forwarding the sandbox ID to the
// manager's Ping.
//
// Error mapping: an error whose text contains "not found" becomes
// CodeNotFound; every other error becomes CodeFailedPrecondition.
func (s *Server) PingSandbox(
	ctx context.Context,
	req *connect.Request[pb.PingSandboxRequest],
) (*connect.Response[pb.PingSandboxResponse], error) {
	err := s.mgr.Ping(req.Msg.SandboxId)
	if err == nil {
		return connect.NewResponse(new(pb.PingSandboxResponse)), nil
	}

	// The manager's error is classified by its message text.
	code := connect.CodeFailedPrecondition
	if strings.Contains(err.Error(), "not found") {
		code = connect.CodeNotFound
	}
	return nil, connect.NewError(code, err)
}
// Exec handles the unary Exec RPC. The command and its arguments are run in
// the requested sandbox through the sandbox manager, under a deadline of
// TimeoutSec seconds when that value is positive and 30 seconds otherwise.
// The reply carries the command's stdout, stderr and exit code. A manager
// failure is reported as CodeInternal.
func (s *Server) Exec(
	ctx context.Context,
	req *connect.Request[pb.ExecRequest],
) (*connect.Response[pb.ExecResponse], error) {
	in := req.Msg

	// Fall back to the 30 second default unless a positive timeout was given.
	var limit time.Duration
	if in.TimeoutSec > 0 {
		limit = time.Duration(in.TimeoutSec) * time.Second
	} else {
		limit = 30 * time.Second
	}

	runCtx, stop := context.WithTimeout(ctx, limit)
	defer stop()

	res, err := s.mgr.Exec(runCtx, in.SandboxId, in.Cmd, in.Args...)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("exec: %w", err))
	}

	out := &pb.ExecResponse{
		Stdout:   res.Stdout,
		Stderr:   res.Stderr,
		ExitCode: res.ExitCode,
	}
	return connect.NewResponse(out), nil
}
// WriteFile handles the unary WriteFile RPC. It obtains the sandbox's client
// from the manager and uses it to write the request's content to the request's
// path. A failed client lookup is reported as CodeNotFound; a failed write is
// reported as CodeInternal.
func (s *Server) WriteFile(
	ctx context.Context,
	req *connect.Request[pb.WriteFileRequest],
) (*connect.Response[pb.WriteFileResponse], error) {
	in := req.Msg

	sbClient, err := s.mgr.GetClient(in.SandboxId)
	if err != nil {
		return nil, connect.NewError(connect.CodeNotFound, err)
	}

	err = sbClient.WriteFile(ctx, in.Path, in.Content)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write file: %w", err))
	}

	return connect.NewResponse(new(pb.WriteFileResponse)), nil
}
// ReadFile handles the unary ReadFile RPC. It obtains the sandbox's client
// from the manager, reads the file at the requested path through it, and
// returns the content in a single response. A failed client lookup is reported
// as CodeNotFound; a failed read is reported as CodeInternal.
func (s *Server) ReadFile(
	ctx context.Context,
	req *connect.Request[pb.ReadFileRequest],
) (*connect.Response[pb.ReadFileResponse], error) {
	in := req.Msg

	sbClient, err := s.mgr.GetClient(in.SandboxId)
	if err != nil {
		return nil, connect.NewError(connect.CodeNotFound, err)
	}

	data, err := sbClient.ReadFile(ctx, in.Path)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("read file: %w", err))
	}

	out := &pb.ReadFileResponse{Content: data}
	return connect.NewResponse(out), nil
}
// ExecStream handles the server-streaming ExecStream RPC. It starts the
// command in the requested sandbox through the sandbox manager and forwards
// each event from the manager's event channel to the client:
//
//   - "start"  -> ExecStreamStart carrying the PID
//   - "stdout" -> ExecStreamData with the stdout payload
//   - "stderr" -> ExecStreamData with the stderr payload
//   - "end"    -> ExecStreamEnd carrying the exit code and error text
//
// Events of any other type are logged and skipped rather than forwarded as a
// response with no Event set. The handler returns nil once the event channel
// is closed, a CodeInternal error if the command cannot be started, or the
// stream's own error if a Send fails.
func (s *Server) ExecStream(
	ctx context.Context,
	req *connect.Request[pb.ExecStreamRequest],
	stream *connect.ServerStream[pb.ExecStreamResponse],
) error {
	msg := req.Msg

	// Only apply a timeout if explicitly requested; streaming execs may be long-running.
	execCtx := ctx
	if msg.TimeoutSec > 0 {
		var cancel context.CancelFunc
		execCtx, cancel = context.WithTimeout(ctx, time.Duration(msg.TimeoutSec)*time.Second)
		defer cancel()
	}

	events, err := s.mgr.ExecStream(execCtx, msg.SandboxId, msg.Cmd, msg.Args...)
	if err != nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("exec stream: %w", err))
	}

	for ev := range events {
		var resp pb.ExecStreamResponse
		switch ev.Type {
		case "start":
			resp.Event = &pb.ExecStreamResponse_Start{
				Start: &pb.ExecStreamStart{Pid: ev.PID},
			}
		case "stdout":
			resp.Event = &pb.ExecStreamResponse_Data{
				Data: &pb.ExecStreamData{
					Output: &pb.ExecStreamData_Stdout{Stdout: ev.Data},
				},
			}
		case "stderr":
			resp.Event = &pb.ExecStreamResponse_Data{
				Data: &pb.ExecStreamData{
					Output: &pb.ExecStreamData_Stderr{Stderr: ev.Data},
				},
			}
		case "end":
			resp.Event = &pb.ExecStreamResponse_End{
				End: &pb.ExecStreamEnd{
					ExitCode: ev.ExitCode,
					Error:    ev.Error,
				},
			}
		default:
			// Without this case an unrecognized event would reach the client
			// as a message whose Event oneof is unset, which carries no
			// information. Skip it and keep draining the channel.
			slog.Warn("exec stream: skipping unknown event type", "sandbox_id", msg.SandboxId, "type", ev.Type)
			continue
		}

		if err := stream.Send(&resp); err != nil {
			return err
		}
	}
	return nil
}
// WriteFileStream handles the client-streaming WriteFileStream RPC.
//
// The first message on the stream must carry metadata (sandbox ID and target
// path); every following message carries a chunk of file content. Chunks are
// piped, without buffering the whole file, into a multipart/form-data POST to
// "<client base URL>/files" with the query parameters path=<meta.Path> and
// username=root.
//
// Errors:
//   - CodeInvalidArgument if the stream is empty or the first message has no
//     metadata.
//   - CodeNotFound if the manager cannot provide a client for the sandbox.
//   - CodeInternal for stream receive errors, request construction/transport
//     failures, multipart write failures, and any upstream status other than
//     200 or 204.
func (s *Server) WriteFileStream(
	ctx context.Context,
	stream *connect.ClientStream[pb.WriteFileStreamRequest],
) (*connect.Response[pb.WriteFileStreamResponse], error) {
	// First message must contain metadata.
	if !stream.Receive() {
		if err := stream.Err(); err != nil {
			return nil, connect.NewError(connect.CodeInternal, err)
		}
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("empty stream"))
	}
	first := stream.Msg()
	meta := first.GetMeta()
	if meta == nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("first message must contain metadata"))
	}

	client, err := s.mgr.GetClient(meta.SandboxId)
	if err != nil {
		return nil, connect.NewError(connect.CodeNotFound, err)
	}

	// Use io.Pipe to stream chunks into a multipart body for envd's REST endpoint.
	pr, pw := io.Pipe()
	mpWriter := multipart.NewWriter(pw)

	// Write multipart data in a goroutine. The channel is buffered so the
	// goroutine can always deliver its result and exit, even if the handler
	// has already given up on it.
	errCh := make(chan error, 1)
	go func() {
		var writeErr error
		defer func() {
			// Propagate a failure to the reading side of the pipe so the
			// outgoing HTTP request is aborted instead of seeing a clean EOF
			// on a truncated body. CloseWithError(nil) behaves like Close.
			pw.CloseWithError(writeErr)
			errCh <- writeErr
		}()

		part, err := mpWriter.CreateFormFile("file", "upload")
		if err != nil {
			writeErr = fmt.Errorf("create multipart: %w", err)
			return
		}

		for stream.Receive() {
			chunk := stream.Msg().GetChunk()
			if len(chunk) == 0 {
				continue
			}
			if _, err := part.Write(chunk); err != nil {
				writeErr = fmt.Errorf("write chunk: %w", err)
				return
			}
		}
		if err := stream.Err(); err != nil {
			writeErr = err
			return
		}

		// Close writes the terminating multipart boundary; if that fails the
		// body is incomplete and the upload must not be reported as a success.
		if err := mpWriter.Close(); err != nil {
			writeErr = fmt.Errorf("close multipart: %w", err)
		}
	}()

	// Send the streaming multipart body to envd.
	base := client.BaseURL()
	u := fmt.Sprintf("%s/files?%s", base, url.Values{
		"path":     {meta.Path},
		"username": {"root"},
	}.Encode())

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, u, pr)
	if err != nil {
		// Unblock the writer goroutine and wait for it before returning.
		pw.CloseWithError(err)
		<-errCh
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err))
	}
	httpReq.Header.Set("Content-Type", mpWriter.FormDataContentType())

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		// Unblock the writer goroutine and wait for it before returning.
		pw.CloseWithError(err)
		<-errCh
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write file stream: %w", err))
	}
	defer resp.Body.Close()

	// Wait for the writer goroutine.
	if writerErr := <-errCh; writerErr != nil {
		return nil, connect.NewError(connect.CodeInternal, writerErr)
	}

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(resp.Body)
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("envd write: status %d: %s", resp.StatusCode, string(body)))
	}

	slog.Debug("streaming file write complete", "sandbox_id", meta.SandboxId, "path", meta.Path)
	return connect.NewResponse(&pb.WriteFileStreamResponse{}), nil
}
// ReadFileStream handles the server-streaming ReadFileStream RPC. It issues a
// GET to "<client base URL>/files" with the query parameters path=<msg.Path>
// and username=root, then relays the response body to the caller in chunks of
// at most 64KB.
//
// A failed client lookup is reported as CodeNotFound. Request construction or
// transport failures, a non-200 upstream status, and body read errors are
// reported as CodeInternal. An error from stream.Send is returned unchanged.
func (s *Server) ReadFileStream(
	ctx context.Context,
	req *connect.Request[pb.ReadFileStreamRequest],
	stream *connect.ServerStream[pb.ReadFileStreamResponse],
) error {
	in := req.Msg

	sbClient, err := s.mgr.GetClient(in.SandboxId)
	if err != nil {
		return connect.NewError(connect.CodeNotFound, err)
	}

	query := url.Values{}
	query.Set("path", in.Path)
	query.Set("username", "root")
	target := fmt.Sprintf("%s/files?%s", sbClient.BaseURL(), query.Encode())

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err))
	}

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("read file stream: %w", err))
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		detail, _ := io.ReadAll(resp.Body)
		return connect.NewError(connect.CodeInternal, fmt.Errorf("envd read: status %d: %s", resp.StatusCode, string(detail)))
	}

	// Stream file content in 64KB chunks.
	scratch := make([]byte, 64*1024)
	for {
		n, readErr := resp.Body.Read(scratch)

		// A Read may return data together with an error, so forward the data
		// before looking at readErr.
		if n > 0 {
			// Each message gets its own copy; scratch is reused on the next Read.
			payload := append([]byte(nil), scratch[:n]...)
			if err := stream.Send(&pb.ReadFileStreamResponse{Chunk: payload}); err != nil {
				return err
			}
		}

		switch {
		case readErr == io.EOF:
			return nil
		case readErr != nil:
			return connect.NewError(connect.CodeInternal, fmt.Errorf("read body: %w", readErr))
		}
	}
}
// ListSandboxes handles the ListSandboxes RPC. It returns one SandboxInfo per
// sandbox reported by the manager's List, plus the IDs returned by the
// manager's DrainAutoPausedIDs. Neither ctx nor the request message is used.
func (s *Server) ListSandboxes(
	ctx context.Context,
	req *connect.Request[pb.ListSandboxesRequest],
) (*connect.Response[pb.ListSandboxesResponse], error) {
	all := s.mgr.List()

	infos := make([]*pb.SandboxInfo, 0, len(all))
	for _, sb := range all {
		info := &pb.SandboxInfo{
			SandboxId:        sb.ID,
			Status:           string(sb.Status),
			Template:         sb.Template,
			Vcpus:            int32(sb.VCPUs),
			MemoryMb:         int32(sb.MemoryMB),
			HostIp:           sb.HostIP.String(),
			CreatedAtUnix:    sb.CreatedAt.Unix(),
			LastActiveAtUnix: sb.LastActiveAt.Unix(),
			TimeoutSec:       int32(sb.TimeoutSec),
		}
		infos = append(infos, info)
	}

	// The auto-paused IDs are fetched only after the list has been taken.
	// NOTE(review): the name suggests DrainAutoPausedIDs clears the pending
	// set as it reads it — confirm against the manager before reordering.
	out := &pb.ListSandboxesResponse{Sandboxes: infos}
	out.AutoPausedSandboxIds = s.mgr.DrainAutoPausedIDs()

	return connect.NewResponse(out), nil
}