Port envd from e2b with internalized shared packages and Connect RPC

- Copy envd source from e2b-dev/infra, internalize shared dependencies
  into envd/internal/shared/ (keys, filesystem, id, smap, utils)
- Switch from gRPC to Connect RPC for all envd services
- Update module paths to git.omukk.dev/wrenn/{sandbox,sandbox/envd}
- Add proto specs (process, filesystem) with buf-based code generation
- Implement full envd: process exec, filesystem ops, port forwarding,
  cgroup management, MMDS integration, and HTTP API
- Update main module dependencies (firecracker SDK, pgx, goose, etc.)
- Remove placeholder .gitkeep files replaced by real implementations
This commit is contained in:
2026-03-09 21:03:19 +06:00
parent bd78cc068c
commit a3898d68fb
99 changed files with 17185 additions and 24 deletions

View File

@ -0,0 +1,126 @@
package process
import (
"context"
"errors"
"fmt"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) Connect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleConnect)
}
func (s *Service) handleConnect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
proc, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return err
}
exitChan := make(chan struct{})
data, dataCancel := proc.DataEvent.Fork()
defer dataCancel()
end, endCancel := proc.EndEvent.Fork()
defer endCancel()
streamErr := stream.Send(&rpc.ConnectResponse{
Event: &rpc.ProcessEvent{
Event: &rpc.ProcessEvent_Start{
Start: &rpc.ProcessEvent_StartEvent{
Pid: proc.Pid(),
},
},
},
})
if streamErr != nil {
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr))
}
go func() {
defer close(exitChan)
keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
defer keepaliveTicker.Stop()
dataLoop:
for {
select {
case <-keepaliveTicker.C:
streamErr := stream.Send(&rpc.ConnectResponse{
Event: &rpc.ProcessEvent{
Event: &rpc.ProcessEvent_Keepalive{
Keepalive: &rpc.ProcessEvent_KeepAlive{},
},
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))
return
}
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-data:
if !ok {
break dataLoop
}
streamErr := stream.Send(&rpc.ConnectResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))
return
}
resetKeepalive()
}
}
select {
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-end:
if !ok {
cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))
return
}
streamErr := stream.Send(&rpc.ConnectResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))
return
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-exitChan:
return nil
}
}

View File

@ -0,0 +1,478 @@
package handler
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"strconv"
"strings"
"sync"
"syscall"
"connectrpc.com/connect"
"github.com/creack/pty"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
const (
defaultNice = 0
defaultOomScore = 100
outputBufferSize = 64
stdChunkSize = 2 << 14
ptyChunkSize = 2 << 13
)
type ProcessExit struct {
Error *string
Status string
Exited bool
Code int32
}
type Handler struct {
Config *rpc.ProcessConfig
logger *zerolog.Logger
Tag *string
cmd *exec.Cmd
tty *os.File
cancel context.CancelFunc
outCtx context.Context //nolint:containedctx // todo: refactor so this can be removed
outCancel context.CancelFunc
stdinMu sync.Mutex
stdin io.WriteCloser
DataEvent *MultiplexedChannel[rpc.ProcessEvent_Data]
EndEvent *MultiplexedChannel[rpc.ProcessEvent_End]
}
// This method must be called only after the process has been started
func (p *Handler) Pid() uint32 {
return uint32(p.cmd.Process.Pid)
}
// userCommand returns a human-readable representation of the user's original command,
// without the internal OOM/nice wrapper that is prepended to the actual exec.
func (p *Handler) userCommand() string {
return strings.Join(append([]string{p.Config.GetCmd()}, p.Config.GetArgs()...), " ")
}
// currentNice returns the nice value of the current process.
func currentNice() int {
prio, err := syscall.Getpriority(syscall.PRIO_PROCESS, 0)
if err != nil {
return 0
}
// Getpriority returns 20 - nice on Linux.
return 20 - prio
}
func New(
ctx context.Context,
user *user.User,
req *rpc.StartRequest,
logger *zerolog.Logger,
defaults *execcontext.Defaults,
cgroupManager cgroups.Manager,
cancel context.CancelFunc,
) (*Handler, error) {
// User command string for logging (without the internal wrapper details).
userCmd := strings.Join(append([]string{req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...), " ")
// Wrap the command in a shell that sets the OOM score and nice value before exec-ing the actual command.
// This eliminates the race window where grandchildren could inherit the parent's protected OOM score (-1000)
// or high CPU priority (nice -20) before the post-start calls had a chance to correct them.
// nice(1) applies a relative adjustment, so we compute the delta from the current (inherited) nice to the target.
niceDelta := defaultNice - currentNice()
oomWrapperScript := fmt.Sprintf(`echo %d > /proc/$$/oom_score_adj && exec /usr/bin/nice -n %d "${@}"`, defaultOomScore, niceDelta)
wrapperArgs := append([]string{"-c", oomWrapperScript, "--", req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...)
cmd := exec.CommandContext(ctx, "/bin/sh", wrapperArgs...)
uid, gid, err := permissions.GetUserIdUints(user)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}
groups := []uint32{gid}
if gids, err := user.GroupIds(); err != nil {
logger.Warn().Err(err).Str("user", user.Username).Msg("failed to get supplementary groups")
} else {
for _, g := range gids {
if parsed, err := strconv.ParseUint(g, 10, 32); err == nil {
groups = append(groups, uint32(parsed))
}
}
}
cgroupFD, ok := cgroupManager.GetFileDescriptor(getProcType(req))
cmd.SysProcAttr = &syscall.SysProcAttr{
UseCgroupFD: ok,
CgroupFD: cgroupFD,
Credential: &syscall.Credential{
Uid: uid,
Gid: gid,
Groups: groups,
},
}
resolvedPath, err := permissions.ExpandAndResolve(req.GetProcess().GetCwd(), user, defaults.Workdir)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
// Check if the cwd resolved path exists
if _, err := os.Stat(resolvedPath); errors.Is(err, os.ErrNotExist) {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cwd '%s' does not exist", resolvedPath))
}
cmd.Dir = resolvedPath
var formattedVars []string
// Take only 'PATH' variable from the current environment
// The 'PATH' should ideally be set in the environment
formattedVars = append(formattedVars, "PATH="+os.Getenv("PATH"))
formattedVars = append(formattedVars, "HOME="+user.HomeDir)
formattedVars = append(formattedVars, "USER="+user.Username)
formattedVars = append(formattedVars, "LOGNAME="+user.Username)
// Add the environment variables from the global environment
if defaults.EnvVars != nil {
defaults.EnvVars.Range(func(key string, value string) bool {
formattedVars = append(formattedVars, key+"="+value)
return true
})
}
// Only the last values of the env vars are used - this allows for overwriting defaults
for key, value := range req.GetProcess().GetEnvs() {
formattedVars = append(formattedVars, key+"="+value)
}
cmd.Env = formattedVars
outMultiplex := NewMultiplexedChannel[rpc.ProcessEvent_Data](outputBufferSize)
var outWg sync.WaitGroup
// Create a context for waiting for and cancelling output pipes.
// Cancellation of the process via timeout will propagate and cancel this context too.
outCtx, outCancel := context.WithCancel(ctx)
h := &Handler{
Config: req.GetProcess(),
cmd: cmd,
Tag: req.Tag,
DataEvent: outMultiplex,
cancel: cancel,
outCtx: outCtx,
outCancel: outCancel,
EndEvent: NewMultiplexedChannel[rpc.ProcessEvent_End](0),
logger: logger,
}
if req.GetPty() != nil {
// The pty should ideally start only in the Start method, but the package does not support that and we would have to code it manually.
// The output of the pty should correctly be passed though.
tty, err := pty.StartWithSize(cmd, &pty.Winsize{
Cols: uint16(req.GetPty().GetSize().GetCols()),
Rows: uint16(req.GetPty().GetSize().GetRows()),
})
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error starting pty with command '%s' in dir '%s' with '%d' cols and '%d' rows: %w", userCmd, cmd.Dir, req.GetPty().GetSize().GetCols(), req.GetPty().GetSize().GetRows(), err))
}
outWg.Go(func() {
for {
buf := make([]byte, ptyChunkSize)
n, readErr := tty.Read(buf)
if n > 0 {
outMultiplex.Source <- rpc.ProcessEvent_Data{
Data: &rpc.ProcessEvent_DataEvent{
Output: &rpc.ProcessEvent_DataEvent_Pty{
Pty: buf[:n],
},
},
}
}
if errors.Is(readErr, io.EOF) {
break
}
if readErr != nil {
fmt.Fprintf(os.Stderr, "error reading from pty: %s\n", readErr)
break
}
}
})
h.tty = tty
} else {
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdout pipe for command '%s': %w", userCmd, err))
}
outWg.Go(func() {
stdoutLogs := make(chan []byte, outputBufferSize)
defer close(stdoutLogs)
stdoutLogger := logger.With().Str("event_type", "stdout").Logger()
go logs.LogBufferedDataEvents(stdoutLogs, &stdoutLogger, "data")
for {
buf := make([]byte, stdChunkSize)
n, readErr := stdout.Read(buf)
if n > 0 {
outMultiplex.Source <- rpc.ProcessEvent_Data{
Data: &rpc.ProcessEvent_DataEvent{
Output: &rpc.ProcessEvent_DataEvent_Stdout{
Stdout: buf[:n],
},
},
}
stdoutLogs <- buf[:n]
}
if errors.Is(readErr, io.EOF) {
break
}
if readErr != nil {
fmt.Fprintf(os.Stderr, "error reading from stdout: %s\n", readErr)
break
}
}
})
stderr, err := cmd.StderrPipe()
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stderr pipe for command '%s': %w", userCmd, err))
}
outWg.Go(func() {
stderrLogs := make(chan []byte, outputBufferSize)
defer close(stderrLogs)
stderrLogger := logger.With().Str("event_type", "stderr").Logger()
go logs.LogBufferedDataEvents(stderrLogs, &stderrLogger, "data")
for {
buf := make([]byte, stdChunkSize)
n, readErr := stderr.Read(buf)
if n > 0 {
outMultiplex.Source <- rpc.ProcessEvent_Data{
Data: &rpc.ProcessEvent_DataEvent{
Output: &rpc.ProcessEvent_DataEvent_Stderr{
Stderr: buf[:n],
},
},
}
stderrLogs <- buf[:n]
}
if errors.Is(readErr, io.EOF) {
break
}
if readErr != nil {
fmt.Fprintf(os.Stderr, "error reading from stderr: %s\n", readErr)
break
}
}
})
// For backwards compatibility we still set the stdin if not explicitly disabled
// If stdin is disabled, the process will use /dev/null as stdin
if req.Stdin == nil || req.GetStdin() == true {
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdin pipe for command '%s': %w", userCmd, err))
}
h.stdin = stdin
}
}
go func() {
outWg.Wait()
close(outMultiplex.Source)
outCancel()
}()
return h, nil
}
func getProcType(req *rpc.StartRequest) cgroups.ProcessType {
if req != nil && req.GetPty() != nil {
return cgroups.ProcessTypePTY
}
return cgroups.ProcessTypeUser
}
func (p *Handler) SendSignal(signal syscall.Signal) error {
if p.cmd.Process == nil {
return fmt.Errorf("process not started")
}
if signal == syscall.SIGKILL || signal == syscall.SIGTERM {
p.outCancel()
}
return p.cmd.Process.Signal(signal)
}
func (p *Handler) ResizeTty(size *pty.Winsize) error {
if p.tty == nil {
return fmt.Errorf("tty not assigned to process")
}
return pty.Setsize(p.tty, size)
}
func (p *Handler) WriteStdin(data []byte) error {
if p.tty != nil {
return fmt.Errorf("tty assigned to process — input should be written to the pty, not the stdin")
}
p.stdinMu.Lock()
defer p.stdinMu.Unlock()
if p.stdin == nil {
return fmt.Errorf("stdin not enabled or closed")
}
_, err := p.stdin.Write(data)
if err != nil {
return fmt.Errorf("error writing to stdin of process '%d': %w", p.cmd.Process.Pid, err)
}
return nil
}
// CloseStdin closes the stdin pipe to signal EOF to the process.
// Only works for non-PTY processes.
func (p *Handler) CloseStdin() error {
if p.tty != nil {
return fmt.Errorf("cannot close stdin for PTY process — send Ctrl+D (0x04) instead")
}
p.stdinMu.Lock()
defer p.stdinMu.Unlock()
if p.stdin == nil {
return nil
}
err := p.stdin.Close()
// We still set the stdin to nil even on error as there are no errors,
// for which it is really safe to retry close across all distributions.
p.stdin = nil
return err
}
func (p *Handler) WriteTty(data []byte) error {
if p.tty == nil {
return fmt.Errorf("tty not assigned to process — input should be written to the stdin, not the tty")
}
_, err := p.tty.Write(data)
if err != nil {
return fmt.Errorf("error writing to tty of process '%d': %w", p.cmd.Process.Pid, err)
}
return nil
}
func (p *Handler) Start() (uint32, error) {
// Pty is already started in the New method
if p.tty == nil {
err := p.cmd.Start()
if err != nil {
return 0, fmt.Errorf("error starting process '%s': %w", p.userCommand(), err)
}
}
p.logger.
Info().
Str("event_type", "process_start").
Int("pid", p.cmd.Process.Pid).
Str("command", p.userCommand()).
Msg(fmt.Sprintf("Process with pid %d started", p.cmd.Process.Pid))
return uint32(p.cmd.Process.Pid), nil
}
func (p *Handler) Wait() {
// Wait for the output pipes to be closed or cancelled.
<-p.outCtx.Done()
err := p.cmd.Wait()
p.tty.Close()
var errMsg *string
if err != nil {
msg := err.Error()
errMsg = &msg
}
endEvent := &rpc.ProcessEvent_EndEvent{
Error: errMsg,
ExitCode: int32(p.cmd.ProcessState.ExitCode()),
Exited: p.cmd.ProcessState.Exited(),
Status: p.cmd.ProcessState.String(),
}
event := rpc.ProcessEvent_End{
End: endEvent,
}
p.EndEvent.Source <- event
p.logger.
Info().
Str("event_type", "process_end").
Interface("process_result", endEvent).
Msg(fmt.Sprintf("Process with pid %d ended", p.cmd.Process.Pid))
// Ensure the process cancel is called to cleanup resources.
// As it is called after end event and Wait, it should not affect command execution or returned events.
p.cancel()
}

View File

@ -0,0 +1,73 @@
package handler
import (
"sync"
"sync/atomic"
)
type MultiplexedChannel[T any] struct {
Source chan T
channels []chan T
mu sync.RWMutex
exited atomic.Bool
}
func NewMultiplexedChannel[T any](buffer int) *MultiplexedChannel[T] {
c := &MultiplexedChannel[T]{
channels: nil,
Source: make(chan T, buffer),
}
go func() {
for v := range c.Source {
c.mu.RLock()
for _, cons := range c.channels {
cons <- v
}
c.mu.RUnlock()
}
c.exited.Store(true)
for _, cons := range c.channels {
close(cons)
}
}()
return c
}
func (m *MultiplexedChannel[T]) Fork() (chan T, func()) {
if m.exited.Load() {
ch := make(chan T)
close(ch)
return ch, func() {}
}
m.mu.Lock()
defer m.mu.Unlock()
consumer := make(chan T)
m.channels = append(m.channels, consumer)
return consumer, func() {
m.remove(consumer)
}
}
func (m *MultiplexedChannel[T]) remove(consumer chan T) {
m.mu.Lock()
defer m.mu.Unlock()
for i, ch := range m.channels {
if ch == consumer {
m.channels = append(m.channels[:i], m.channels[i+1:]...)
return
}
}
}

View File

@ -0,0 +1,107 @@
package process
import (
"context"
"fmt"
"connectrpc.com/connect"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func handleInput(ctx context.Context, process *handler.Handler, in *rpc.ProcessInput, logger *zerolog.Logger) error {
switch in.GetInput().(type) {
case *rpc.ProcessInput_Pty:
err := process.WriteTty(in.GetPty())
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to tty: %w", err))
}
case *rpc.ProcessInput_Stdin:
err := process.WriteStdin(in.GetStdin())
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to stdin: %w", err))
}
logger.Debug().
Str("event_type", "stdin").
Interface("stdin", in.GetStdin()).
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
Msg("Streaming input to process")
default:
return connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", in.GetInput()))
}
return nil
}
func (s *Service) SendInput(ctx context.Context, req *connect.Request[rpc.SendInputRequest]) (*connect.Response[rpc.SendInputResponse], error) {
proc, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return nil, err
}
err = handleInput(ctx, proc, req.Msg.GetInput(), s.logger)
if err != nil {
return nil, err
}
return connect.NewResponse(&rpc.SendInputResponse{}), nil
}
func (s *Service) StreamInput(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
return logs.LogClientStreamWithoutEvents(ctx, s.logger, stream, s.streamInputHandler)
}
func (s *Service) streamInputHandler(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
var proc *handler.Handler
for stream.Receive() {
req := stream.Msg()
switch req.GetEvent().(type) {
case *rpc.StreamInputRequest_Start:
p, err := s.getProcess(req.GetStart().GetProcess())
if err != nil {
return nil, err
}
proc = p
case *rpc.StreamInputRequest_Data:
err := handleInput(ctx, proc, req.GetData().GetInput(), s.logger)
if err != nil {
return nil, err
}
case *rpc.StreamInputRequest_Keepalive:
default:
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid event type %T", req.GetEvent()))
}
}
err := stream.Err()
if err != nil {
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error streaming input: %w", err))
}
return connect.NewResponse(&rpc.StreamInputResponse{}), nil
}
func (s *Service) CloseStdin(
_ context.Context,
req *connect.Request[rpc.CloseStdinRequest],
) (*connect.Response[rpc.CloseStdinResponse], error) {
handler, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return nil, err
}
if err := handler.CloseStdin(); err != nil {
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error closing stdin: %w", err))
}
return connect.NewResponse(&rpc.CloseStdinResponse{}), nil
}

View File

@ -0,0 +1,28 @@
package process
import (
"context"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) List(context.Context, *connect.Request[rpc.ListRequest]) (*connect.Response[rpc.ListResponse], error) {
processes := make([]*rpc.ProcessInfo, 0)
s.processes.Range(func(pid uint32, value *handler.Handler) bool {
processes = append(processes, &rpc.ProcessInfo{
Pid: pid,
Tag: value.Tag,
Config: value.Config,
})
return true
})
return connect.NewResponse(&rpc.ListResponse{
Processes: processes,
}), nil
}

View File

@ -0,0 +1,84 @@
package process
import (
"fmt"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process/processconnect"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
type Service struct {
processes *utils.Map[uint32, *handler.Handler]
logger *zerolog.Logger
defaults *execcontext.Defaults
cgroupManager cgroups.Manager
}
func newService(l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
return &Service{
logger: l,
processes: utils.NewMap[uint32, *handler.Handler](),
defaults: defaults,
cgroupManager: cgroupManager,
}
}
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
service := newService(l, defaults, cgroupManager)
interceptors := connect.WithInterceptors(logs.NewUnaryLogInterceptor(l))
path, h := spec.NewProcessHandler(service, interceptors)
server.Mount(path, h)
return service
}
func (s *Service) getProcess(selector *rpc.ProcessSelector) (*handler.Handler, error) {
var proc *handler.Handler
switch selector.GetSelector().(type) {
case *rpc.ProcessSelector_Pid:
p, ok := s.processes.Load(selector.GetPid())
if !ok {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with pid %d not found", selector.GetPid()))
}
proc = p
case *rpc.ProcessSelector_Tag:
tag := selector.GetTag()
s.processes.Range(func(_ uint32, value *handler.Handler) bool {
if value.Tag == nil {
return true
}
if *value.Tag == tag {
proc = value
return true
}
return false
})
if proc == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with tag %s not found", tag))
}
default:
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", selector))
}
return proc, nil
}

View File

@ -0,0 +1,38 @@
package process
import (
"context"
"fmt"
"syscall"
"connectrpc.com/connect"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) SendSignal(
_ context.Context,
req *connect.Request[rpc.SendSignalRequest],
) (*connect.Response[rpc.SendSignalResponse], error) {
handler, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return nil, err
}
var signal syscall.Signal
switch req.Msg.GetSignal() {
case rpc.Signal_SIGNAL_SIGKILL:
signal = syscall.SIGKILL
case rpc.Signal_SIGNAL_SIGTERM:
signal = syscall.SIGTERM
default:
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid signal: %s", req.Msg.GetSignal()))
}
err = handler.SendSignal(signal)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error sending signal: %w", err))
}
return connect.NewResponse(&rpc.SendSignalResponse{}), nil
}

View File

@ -0,0 +1,247 @@
package process
import (
"context"
"errors"
"fmt"
"net/http"
"os/user"
"strconv"
"time"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) InitializeStartProcess(ctx context.Context, user *user.User, req *rpc.StartRequest) error {
var err error
ctx = logs.AddRequestIDToContext(ctx)
defer s.logger.
Err(err).
Interface("request", req).
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
Msg("Initialized startCmd")
handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
startProcCtx, startProcCancel := context.WithCancel(ctx)
proc, err := handler.New(startProcCtx, user, req, &handlerL, s.defaults, s.cgroupManager, startProcCancel)
if err != nil {
return err
}
pid, err := proc.Start()
if err != nil {
return err
}
s.processes.Store(pid, proc)
go func() {
defer s.processes.Delete(pid)
proc.Wait()
}()
return nil
}
func (s *Service) Start(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleStart)
}
func (s *Service) handleStart(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
if err != nil {
return err
}
timeout, err := determineTimeoutFromHeader(stream.Conn().RequestHeader())
if err != nil {
return connect.NewError(connect.CodeInvalidArgument, err)
}
// Create a new context with a timeout if provided.
// We do not want the command to be killed if the request context is cancelled
procCtx, cancelProc := context.Background(), func() {}
if timeout > 0 { // zero timeout means no timeout
procCtx, cancelProc = context.WithTimeout(procCtx, timeout)
}
proc, err := handler.New( //nolint:contextcheck // TODO: fix this later
procCtx,
u,
req.Msg,
&handlerL,
s.defaults,
s.cgroupManager,
cancelProc,
)
if err != nil {
// Ensure the process cancel is called to cleanup resources.
cancelProc()
return err
}
exitChan := make(chan struct{})
startMultiplexer := handler.NewMultiplexedChannel[rpc.ProcessEvent_Start](0)
defer close(startMultiplexer.Source)
start, startCancel := startMultiplexer.Fork()
defer startCancel()
data, dataCancel := proc.DataEvent.Fork()
defer dataCancel()
end, endCancel := proc.EndEvent.Fork()
defer endCancel()
go func() {
defer close(exitChan)
select {
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-start:
if !ok {
cancel(connect.NewError(connect.CodeUnknown, errors.New("start event channel closed before sending start event")))
return
}
streamErr := stream.Send(&rpc.StartResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr)))
return
}
}
keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
defer keepaliveTicker.Stop()
dataLoop:
for {
select {
case <-keepaliveTicker.C:
streamErr := stream.Send(&rpc.StartResponse{
Event: &rpc.ProcessEvent{
Event: &rpc.ProcessEvent_Keepalive{
Keepalive: &rpc.ProcessEvent_KeepAlive{},
},
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))
return
}
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-data:
if !ok {
break dataLoop
}
streamErr := stream.Send(&rpc.StartResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))
return
}
resetKeepalive()
}
}
select {
case <-ctx.Done():
cancel(ctx.Err())
return
case event, ok := <-end:
if !ok {
cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))
return
}
streamErr := stream.Send(&rpc.StartResponse{
Event: &rpc.ProcessEvent{
Event: &event,
},
})
if streamErr != nil {
cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))
return
}
}
}()
pid, err := proc.Start()
if err != nil {
return connect.NewError(connect.CodeInvalidArgument, err)
}
s.processes.Store(pid, proc)
start <- rpc.ProcessEvent_Start{
Start: &rpc.ProcessEvent_StartEvent{
Pid: pid,
},
}
go func() {
defer s.processes.Delete(pid)
proc.Wait()
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-exitChan:
return nil
}
}
func determineTimeoutFromHeader(header http.Header) (time.Duration, error) {
timeoutHeader := header.Get("Connect-Timeout-Ms")
if timeoutHeader == "" {
return 0, nil
}
timeout, err := strconv.Atoi(timeoutHeader)
if err != nil {
return 0, err
}
return time.Duration(timeout) * time.Millisecond, nil
}

View File

@ -0,0 +1,30 @@
package process
import (
"context"
"fmt"
"connectrpc.com/connect"
"github.com/creack/pty"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
func (s *Service) Update(_ context.Context, req *connect.Request[rpc.UpdateRequest]) (*connect.Response[rpc.UpdateResponse], error) {
proc, err := s.getProcess(req.Msg.GetProcess())
if err != nil {
return nil, err
}
if req.Msg.GetPty() != nil {
err := proc.ResizeTty(&pty.Winsize{
Rows: uint16(req.Msg.GetPty().GetSize().GetRows()),
Cols: uint16(req.Msg.GetPty().GetSize().GetCols()),
})
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error resizing tty: %w", err))
}
}
return connect.NewResponse(&rpc.UpdateResponse{}), nil
}