forked from wrenn/wrenn
v0.1.0 (#17)
This commit is contained in:
@ -15,6 +15,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
|
||||
@ -68,10 +69,18 @@ func (s *Server) CreateSandbox(
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
|
||||
}
|
||||
|
||||
// Apply template defaults (user, env vars) if provided.
|
||||
if msg.DefaultUser != "" || len(msg.DefaultEnv) > 0 {
|
||||
if err := s.mgr.SetDefaults(ctx, sb.ID, msg.DefaultUser, msg.DefaultEnv); err != nil {
|
||||
slog.Warn("failed to set sandbox defaults", "sandbox", sb.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.CreateSandboxResponse{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
HostIp: sb.HostIP.String(),
|
||||
Metadata: sb.Metadata,
|
||||
}), nil
|
||||
}
|
||||
|
||||
@ -99,14 +108,24 @@ func (s *Server) ResumeSandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ResumeSandboxRequest],
|
||||
) (*connect.Response[pb.ResumeSandboxResponse], error) {
|
||||
sb, err := s.mgr.Resume(ctx, req.Msg.SandboxId, int(req.Msg.TimeoutSec))
|
||||
msg := req.Msg
|
||||
sb, err := s.mgr.Resume(ctx, msg.SandboxId, int(msg.TimeoutSec), msg.KernelVersion)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
|
||||
// Apply template defaults (user, env vars) if provided.
|
||||
if msg.DefaultUser != "" || len(msg.DefaultEnv) > 0 {
|
||||
if err := s.mgr.SetDefaults(ctx, sb.ID, msg.DefaultUser, msg.DefaultEnv); err != nil {
|
||||
slog.Warn("failed to set sandbox defaults on resume", "sandbox", sb.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ResumeSandboxResponse{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
HostIp: sb.HostIP.String(),
|
||||
Metadata: sb.Metadata,
|
||||
}), nil
|
||||
}
|
||||
|
||||
@ -252,6 +271,69 @@ func (s *Server) ReadFile(
|
||||
return connect.NewResponse(&pb.ReadFileResponse{Content: content}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ListDir(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ListDirRequest],
|
||||
) (*connect.Response[pb.ListDirResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
resp, err := client.ListDir(ctx, msg.Path, msg.Depth)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list dir: %w", err))
|
||||
}
|
||||
|
||||
entries := make([]*pb.FileEntry, 0, len(resp.Entries))
|
||||
for _, e := range resp.Entries {
|
||||
entries = append(entries, entryInfoToPB(e))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ListDirResponse{Entries: entries}), nil
|
||||
}
|
||||
|
||||
func (s *Server) MakeDir(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.MakeDirRequest],
|
||||
) (*connect.Response[pb.MakeDirResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
resp, err := client.MakeDir(ctx, msg.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("make dir: %w", err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.MakeDirResponse{
|
||||
Entry: entryInfoToPB(resp.Entry),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) RemovePath(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.RemovePathRequest],
|
||||
) (*connect.Response[pb.RemovePathResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
if err := client.Remove(ctx, msg.Path); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("remove: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.RemovePathResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ExecStream(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ExecStreamRequest],
|
||||
@ -436,6 +518,16 @@ func (s *Server) ReadFileStream(
|
||||
// Stream file content in 64KB chunks.
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
// Bail out early if the client disconnected or the context was cancelled.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return connect.NewError(connect.CodeDeadlineExceeded, ctx.Err())
|
||||
}
|
||||
return connect.NewError(connect.CodeCanceled, ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
@ -474,6 +566,7 @@ func (s *Server) ListSandboxes(
|
||||
CreatedAtUnix: sb.CreatedAt.Unix(),
|
||||
LastActiveAtUnix: sb.LastActiveAt.Unix(),
|
||||
TimeoutSec: int32(sb.TimeoutSec),
|
||||
Metadata: sb.Metadata,
|
||||
}
|
||||
}
|
||||
|
||||
@ -545,3 +638,269 @@ func metricPointsToPB(pts []sandbox.MetricPoint) []*pb.MetricPoint {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Server) PtyAttach(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtyAttachRequest],
|
||||
stream *connect.ServerStream[pb.PtyAttachResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
events, err := s.mgr.PtyAttach(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Cols, msg.Rows, msg.Envs, msg.Cwd)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("pty attach: %w", err))
|
||||
}
|
||||
|
||||
for ev := range events {
|
||||
var resp pb.PtyAttachResponse
|
||||
switch ev.Type {
|
||||
case "started":
|
||||
resp.Event = &pb.PtyAttachResponse_Started{
|
||||
Started: &pb.PtyStarted{Pid: ev.PID, Tag: msg.Tag},
|
||||
}
|
||||
case "output":
|
||||
resp.Event = &pb.PtyAttachResponse_Output{
|
||||
Output: &pb.PtyOutput{Data: ev.Data},
|
||||
}
|
||||
case "end":
|
||||
resp.Event = &pb.PtyAttachResponse_Exited{
|
||||
Exited: &pb.PtyExited{ExitCode: ev.ExitCode, Error: ev.Error},
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) PtySendInput(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtySendInputRequest],
|
||||
) (*connect.Response[pb.PtySendInputResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
if err := s.mgr.PtySendInput(ctx, msg.SandboxId, msg.Tag, msg.Data); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty send input: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.PtySendInputResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) PtyResize(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtyResizeRequest],
|
||||
) (*connect.Response[pb.PtyResizeResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
if err := s.mgr.PtyResize(ctx, msg.SandboxId, msg.Tag, msg.Cols, msg.Rows); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty resize: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.PtyResizeResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) PtyKill(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PtyKillRequest],
|
||||
) (*connect.Response[pb.PtyKillResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
if err := s.mgr.PtyKill(ctx, msg.SandboxId, msg.Tag); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty kill: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.PtyKillResponse{}), nil
|
||||
}
|
||||
|
||||
// entryInfoToPB maps an envd EntryInfo to a hostagent FileEntry.
|
||||
func entryInfoToPB(e *envdpb.EntryInfo) *pb.FileEntry {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var fileType string
|
||||
switch e.Type {
|
||||
case envdpb.FileType_FILE_TYPE_FILE:
|
||||
fileType = "file"
|
||||
case envdpb.FileType_FILE_TYPE_DIRECTORY:
|
||||
fileType = "directory"
|
||||
case envdpb.FileType_FILE_TYPE_SYMLINK:
|
||||
fileType = "symlink"
|
||||
default:
|
||||
fileType = "unknown"
|
||||
}
|
||||
|
||||
entry := &pb.FileEntry{
|
||||
Name: e.Name,
|
||||
Path: e.Path,
|
||||
Type: fileType,
|
||||
Size: e.Size,
|
||||
Mode: e.Mode,
|
||||
Permissions: e.Permissions,
|
||||
Owner: e.Owner,
|
||||
Group: e.Group,
|
||||
}
|
||||
|
||||
if e.ModifiedTime != nil {
|
||||
entry.ModifiedAt = e.ModifiedTime.GetSeconds()
|
||||
}
|
||||
|
||||
if e.SymlinkTarget != nil {
|
||||
entry.SymlinkTarget = e.SymlinkTarget
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
// ── Background Processes ────────────────────────────────────────────
|
||||
|
||||
func (s *Server) StartBackground(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.StartBackgroundRequest],
|
||||
) (*connect.Response[pb.StartBackgroundResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
pid, err := s.mgr.StartBackground(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Envs, msg.Cwd)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("start background: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.StartBackgroundResponse{
|
||||
Pid: pid,
|
||||
Tag: msg.Tag,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ListProcesses(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ListProcessesRequest],
|
||||
) (*connect.Response[pb.ListProcessesResponse], error) {
|
||||
procs, err := s.mgr.ListProcesses(ctx, req.Msg.SandboxId)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list processes: %w", err))
|
||||
}
|
||||
|
||||
entries := make([]*pb.ProcessEntry, 0, len(procs))
|
||||
for _, p := range procs {
|
||||
entries = append(entries, &pb.ProcessEntry{
|
||||
Pid: p.PID,
|
||||
Tag: p.Tag,
|
||||
Cmd: p.Cmd,
|
||||
Args: p.Args,
|
||||
})
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ListProcessesResponse{
|
||||
Processes: entries,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) KillProcess(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.KillProcessRequest],
|
||||
) (*connect.Response[pb.KillProcessResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
// Resolve PID/tag selector.
|
||||
var pid uint32
|
||||
var tag string
|
||||
switch sel := msg.Selector.(type) {
|
||||
case *pb.KillProcessRequest_Pid:
|
||||
pid = sel.Pid
|
||||
case *pb.KillProcessRequest_Tag:
|
||||
tag = sel.Tag
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("pid or tag is required"))
|
||||
}
|
||||
|
||||
// Map signal string to envd enum.
|
||||
var signal envdpb.Signal
|
||||
switch msg.Signal {
|
||||
case "", "SIGKILL":
|
||||
signal = envdpb.Signal_SIGNAL_SIGKILL
|
||||
case "SIGTERM":
|
||||
signal = envdpb.Signal_SIGNAL_SIGTERM
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("unsupported signal: %s (use SIGKILL or SIGTERM)", msg.Signal))
|
||||
}
|
||||
|
||||
if err := s.mgr.KillProcess(ctx, msg.SandboxId, pid, tag, signal); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("kill process: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.KillProcessResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ConnectProcess(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ConnectProcessRequest],
|
||||
stream *connect.ServerStream[pb.ConnectProcessResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
var pid uint32
|
||||
var tag string
|
||||
switch sel := msg.Selector.(type) {
|
||||
case *pb.ConnectProcessRequest_Pid:
|
||||
pid = sel.Pid
|
||||
case *pb.ConnectProcessRequest_Tag:
|
||||
tag = sel.Tag
|
||||
default:
|
||||
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("pid or tag is required"))
|
||||
}
|
||||
|
||||
events, err := s.mgr.ConnectProcess(ctx, msg.SandboxId, pid, tag)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("connect process: %w", err))
|
||||
}
|
||||
|
||||
for ev := range events {
|
||||
var resp pb.ConnectProcessResponse
|
||||
switch ev.Type {
|
||||
case "start":
|
||||
resp.Event = &pb.ConnectProcessResponse_Start{
|
||||
Start: &pb.ExecStreamStart{Pid: ev.PID},
|
||||
}
|
||||
case "stdout":
|
||||
resp.Event = &pb.ConnectProcessResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stdout{Stdout: ev.Data},
|
||||
},
|
||||
}
|
||||
case "stderr":
|
||||
resp.Event = &pb.ConnectProcessResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stderr{Stderr: ev.Data},
|
||||
},
|
||||
}
|
||||
case "end":
|
||||
resp.Event = &pb.ConnectProcessResponse_End{
|
||||
End: &pb.ExecStreamEnd{
|
||||
ExitCode: ev.ExitCode,
|
||||
Error: ev.Error,
|
||||
},
|
||||
}
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user