Port envd from e2b with internalized shared packages and Connect RPC
- Copy envd source from e2b-dev/infra, internalize shared dependencies
into envd/internal/shared/ (keys, filesystem, id, smap, utils)
- Switch from gRPC to Connect RPC for all envd services
- Update module paths to git.omukk.dev/wrenn/{sandbox,sandbox/envd}
- Add proto specs (process, filesystem) with buf-based code generation
- Implement full envd: process exec, filesystem ops, port forwarding,
cgroup management, MMDS integration, and HTTP API
- Update main module dependencies (firecracker SDK, pgx, goose, etc.)
- Remove placeholder .gitkeep files replaced by real implementations
This commit is contained in:
126
envd/internal/services/process/connect.go
Normal file
126
envd/internal/services/process/connect.go
Normal file
@ -0,0 +1,126 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
// Connect attaches a server stream to an already-running process: it sends a
// start event (with the pid), then live data events, then the end event.
// Request/response logging is handled by the LogServerStreamWithoutEvents wrapper.
func (s *Service) Connect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
	return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleConnect)
}
|
||||
|
||||
// handleConnect streams events of an existing process to the client.
//
// Protocol: one start event, then data events interleaved with keepalives
// (driven by the request's keepalive ticker), and finally exactly one end
// event. Forked channels only carry events emitted after the fork, so no
// history is replayed. Any send failure cancels ctx with the cause; the
// outer select returns that cause (via ctx.Err) or nil once the forwarding
// goroutine finishes.
func (s *Service) handleConnect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)

	proc, err := s.getProcess(req.Msg.GetProcess())
	if err != nil {
		return err
	}

	// Closed by the forwarding goroutine when it is done (normally or on error).
	exitChan := make(chan struct{})

	data, dataCancel := proc.DataEvent.Fork()
	defer dataCancel()

	end, endCancel := proc.EndEvent.Fork()
	defer endCancel()

	// Send the start event synchronously so the client learns the pid before
	// any data arrives.
	streamErr := stream.Send(&rpc.ConnectResponse{
		Event: &rpc.ProcessEvent{
			Event: &rpc.ProcessEvent_Start{
				Start: &rpc.ProcessEvent_StartEvent{
					Pid: proc.Pid(),
				},
			},
		},
	})
	if streamErr != nil {
		return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr))
	}

	go func() {
		defer close(exitChan)

		keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
		defer keepaliveTicker.Stop()

		// Phase 1: forward data events until the data channel closes
		// (i.e. the process's output pipes are drained).
	dataLoop:
		for {
			select {
			case <-keepaliveTicker.C:
				streamErr := stream.Send(&rpc.ConnectResponse{
					Event: &rpc.ProcessEvent{
						Event: &rpc.ProcessEvent_Keepalive{
							Keepalive: &rpc.ProcessEvent_KeepAlive{},
						},
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))

					return
				}
			case <-ctx.Done():
				cancel(ctx.Err())

				return
			case event, ok := <-data:
				if !ok {
					// Data channel closed — move on to waiting for the end event.
					break dataLoop
				}

				streamErr := stream.Send(&rpc.ConnectResponse{
					Event: &rpc.ProcessEvent{
						Event: &event,
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))

					return
				}

				// Real traffic counts as liveness; push the next keepalive out.
				resetKeepalive()
			}
		}

		// Phase 2: forward the single end event.
		select {
		case <-ctx.Done():
			cancel(ctx.Err())

			return
		case event, ok := <-end:
			if !ok {
				cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))

				return
			}

			streamErr := stream.Send(&rpc.ConnectResponse{
				Event: &rpc.ProcessEvent{
					Event: &event,
				},
			})
			if streamErr != nil {
				cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))

				return
			}
		}
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-exitChan:
		return nil
	}
}
|
||||
478
envd/internal/services/process/handler/handler.go
Normal file
478
envd/internal/services/process/handler/handler.go
Normal file
@ -0,0 +1,478 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/creack/pty"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
const (
	// defaultNice is the target nice value for spawned user processes.
	defaultNice = 0
	// defaultOomScore is written to /proc/<pid>/oom_score_adj in the shell
	// wrapper so spawned processes are preferred OOM victims.
	defaultOomScore = 100
	// outputBufferSize is the buffer of the multiplexed output source channel
	// and of the per-stream log channels.
	outputBufferSize = 64
	// stdChunkSize (32 KiB) is the per-read chunk for stdout/stderr pipes.
	stdChunkSize = 2 << 14
	// ptyChunkSize (16 KiB) is the per-read chunk for pty output.
	ptyChunkSize = 2 << 13
)

// ProcessExit describes how a process terminated.
type ProcessExit struct {
	Error  *string // error message; nil on clean exit
	Status string  // human-readable wait status
	Exited bool    // whether the process exited (vs. was terminated by a signal)
	Code   int32   // exit code
}

// Handler owns a single spawned process: its exec.Cmd, optional pty, stdin
// pipe, and the multiplexed data/end event streams that consumers fork.
type Handler struct {
	Config *rpc.ProcessConfig

	logger *zerolog.Logger

	Tag *string
	cmd *exec.Cmd
	tty *os.File // non-nil only for PTY processes

	// cancel releases the process context; invoked at the end of Wait.
	cancel context.CancelFunc

	// outCtx is done once the output pipe readers finish (or the process
	// context is cancelled); Wait blocks on it before reaping the process.
	outCtx    context.Context //nolint:containedctx // todo: refactor so this can be removed
	outCancel context.CancelFunc

	// stdinMu guards stdin against concurrent Write/Close.
	stdinMu sync.Mutex
	stdin   io.WriteCloser

	DataEvent *MultiplexedChannel[rpc.ProcessEvent_Data]
	EndEvent  *MultiplexedChannel[rpc.ProcessEvent_End]
}
|
||||
|
||||
// Pid returns the OS process id of the spawned command.
// This method must be called only after the process has been started;
// before that cmd.Process is nil and the access would panic.
func (p *Handler) Pid() uint32 {
	return uint32(p.cmd.Process.Pid)
}
|
||||
|
||||
// userCommand returns a human-readable representation of the user's original command,
|
||||
// without the internal OOM/nice wrapper that is prepended to the actual exec.
|
||||
func (p *Handler) userCommand() string {
|
||||
return strings.Join(append([]string{p.Config.GetCmd()}, p.Config.GetArgs()...), " ")
|
||||
}
|
||||
|
||||
// currentNice returns the nice value of the current process, falling back
// to 0 when the priority cannot be queried.
func currentNice() int {
	// On Linux, Getpriority reports 20 - nice, so invert the transform.
	if prio, err := syscall.Getpriority(syscall.PRIO_PROCESS, 0); err == nil {
		return 20 - prio
	}

	return 0
}
|
||||
|
||||
// New prepares (and, for PTY requests, starts) a process handler for the
// given start request. It builds the OOM/nice shell wrapper, sets
// credentials and cgroup placement, resolves the working directory, builds
// the environment, and wires stdout/stderr/pty readers into the multiplexed
// data event channel. The returned Handler is not yet started unless a pty
// was requested (pty.StartWithSize starts the command itself).
func New(
	ctx context.Context,
	user *user.User,
	req *rpc.StartRequest,
	logger *zerolog.Logger,
	defaults *execcontext.Defaults,
	cgroupManager cgroups.Manager,
	cancel context.CancelFunc,
) (*Handler, error) {
	// User command string for logging (without the internal wrapper details).
	userCmd := strings.Join(append([]string{req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...), " ")

	// Wrap the command in a shell that sets the OOM score and nice value before exec-ing the actual command.
	// This eliminates the race window where grandchildren could inherit the parent's protected OOM score (-1000)
	// or high CPU priority (nice -20) before the post-start calls had a chance to correct them.
	// nice(1) applies a relative adjustment, so we compute the delta from the current (inherited) nice to the target.
	niceDelta := defaultNice - currentNice()
	oomWrapperScript := fmt.Sprintf(`echo %d > /proc/$$/oom_score_adj && exec /usr/bin/nice -n %d "${@}"`, defaultOomScore, niceDelta)
	// sh -c script -- cmd args... : the user's cmd becomes $0/$@ inside the wrapper.
	wrapperArgs := append([]string{"-c", oomWrapperScript, "--", req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...)
	cmd := exec.CommandContext(ctx, "/bin/sh", wrapperArgs...)

	uid, gid, err := permissions.GetUserIdUints(user)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, err)
	}

	// Primary group plus any parseable supplementary groups; failures here
	// are non-fatal (best effort).
	groups := []uint32{gid}
	if gids, err := user.GroupIds(); err != nil {
		logger.Warn().Err(err).Str("user", user.Username).Msg("failed to get supplementary groups")
	} else {
		for _, g := range gids {
			if parsed, err := strconv.ParseUint(g, 10, 32); err == nil {
				groups = append(groups, uint32(parsed))
			}
		}
	}

	// Place the child directly into the right cgroup at clone time when the
	// manager can hand us a cgroup fd (UseCgroupFD is false otherwise).
	cgroupFD, ok := cgroupManager.GetFileDescriptor(getProcType(req))

	cmd.SysProcAttr = &syscall.SysProcAttr{
		UseCgroupFD: ok,
		CgroupFD:    cgroupFD,
		Credential: &syscall.Credential{
			Uid:    uid,
			Gid:    gid,
			Groups: groups,
		},
	}

	resolvedPath, err := permissions.ExpandAndResolve(req.GetProcess().GetCwd(), user, defaults.Workdir)
	if err != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}

	// Check if the cwd resolved path exists
	if _, err := os.Stat(resolvedPath); errors.Is(err, os.ErrNotExist) {
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cwd '%s' does not exist", resolvedPath))
	}

	cmd.Dir = resolvedPath

	var formattedVars []string

	// Take only 'PATH' variable from the current environment
	// The 'PATH' should ideally be set in the environment
	formattedVars = append(formattedVars, "PATH="+os.Getenv("PATH"))
	formattedVars = append(formattedVars, "HOME="+user.HomeDir)
	formattedVars = append(formattedVars, "USER="+user.Username)
	formattedVars = append(formattedVars, "LOGNAME="+user.Username)

	// Add the environment variables from the global environment
	if defaults.EnvVars != nil {
		defaults.EnvVars.Range(func(key string, value string) bool {
			formattedVars = append(formattedVars, key+"="+value)

			return true
		})
	}

	// Only the last values of the env vars are used - this allows for overwriting defaults
	for key, value := range req.GetProcess().GetEnvs() {
		formattedVars = append(formattedVars, key+"="+value)
	}

	cmd.Env = formattedVars

	outMultiplex := NewMultiplexedChannel[rpc.ProcessEvent_Data](outputBufferSize)

	var outWg sync.WaitGroup

	// Create a context for waiting for and cancelling output pipes.
	// Cancellation of the process via timeout will propagate and cancel this context too.
	outCtx, outCancel := context.WithCancel(ctx)

	h := &Handler{
		Config:    req.GetProcess(),
		cmd:       cmd,
		Tag:       req.Tag,
		DataEvent: outMultiplex,
		cancel:    cancel,
		outCtx:    outCtx,
		outCancel: outCancel,
		EndEvent:  NewMultiplexedChannel[rpc.ProcessEvent_End](0),
		logger:    logger,
	}

	if req.GetPty() != nil {
		// The pty should ideally start only in the Start method, but the package does not support that and we would have to code it manually.
		// The output of the pty should correctly be passed though.
		tty, err := pty.StartWithSize(cmd, &pty.Winsize{
			Cols: uint16(req.GetPty().GetSize().GetCols()),
			Rows: uint16(req.GetPty().GetSize().GetRows()),
		})
		if err != nil {
			return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error starting pty with command '%s' in dir '%s' with '%d' cols and '%d' rows: %w", userCmd, cmd.Dir, req.GetPty().GetSize().GetCols(), req.GetPty().GetSize().GetRows(), err))
		}

		// Pump pty output into the data multiplexer until EOF or a read error.
		outWg.Go(func() {
			for {
				// Fresh buffer per read: the slice is handed off through the channel.
				buf := make([]byte, ptyChunkSize)

				n, readErr := tty.Read(buf)

				if n > 0 {
					outMultiplex.Source <- rpc.ProcessEvent_Data{
						Data: &rpc.ProcessEvent_DataEvent{
							Output: &rpc.ProcessEvent_DataEvent_Pty{
								Pty: buf[:n],
							},
						},
					}
				}

				if errors.Is(readErr, io.EOF) {
					break
				}

				if readErr != nil {
					fmt.Fprintf(os.Stderr, "error reading from pty: %s\n", readErr)

					break
				}
			}
		})

		h.tty = tty
	} else {
		stdout, err := cmd.StdoutPipe()
		if err != nil {
			return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdout pipe for command '%s': %w", userCmd, err))
		}

		// Pump stdout into the data multiplexer and into the buffered log stream.
		outWg.Go(func() {
			stdoutLogs := make(chan []byte, outputBufferSize)
			defer close(stdoutLogs)

			stdoutLogger := logger.With().Str("event_type", "stdout").Logger()

			go logs.LogBufferedDataEvents(stdoutLogs, &stdoutLogger, "data")

			for {
				buf := make([]byte, stdChunkSize)

				n, readErr := stdout.Read(buf)

				if n > 0 {
					outMultiplex.Source <- rpc.ProcessEvent_Data{
						Data: &rpc.ProcessEvent_DataEvent{
							Output: &rpc.ProcessEvent_DataEvent_Stdout{
								Stdout: buf[:n],
							},
						},
					}

					stdoutLogs <- buf[:n]
				}

				if errors.Is(readErr, io.EOF) {
					break
				}

				if readErr != nil {
					fmt.Fprintf(os.Stderr, "error reading from stdout: %s\n", readErr)

					break
				}
			}
		})

		stderr, err := cmd.StderrPipe()
		if err != nil {
			return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stderr pipe for command '%s': %w", userCmd, err))
		}

		// Pump stderr into the data multiplexer and into the buffered log stream.
		outWg.Go(func() {
			stderrLogs := make(chan []byte, outputBufferSize)
			defer close(stderrLogs)

			stderrLogger := logger.With().Str("event_type", "stderr").Logger()

			go logs.LogBufferedDataEvents(stderrLogs, &stderrLogger, "data")

			for {
				buf := make([]byte, stdChunkSize)

				n, readErr := stderr.Read(buf)

				if n > 0 {
					outMultiplex.Source <- rpc.ProcessEvent_Data{
						Data: &rpc.ProcessEvent_DataEvent{
							Output: &rpc.ProcessEvent_DataEvent_Stderr{
								Stderr: buf[:n],
							},
						},
					}

					stderrLogs <- buf[:n]
				}

				if errors.Is(readErr, io.EOF) {
					break
				}

				if readErr != nil {
					fmt.Fprintf(os.Stderr, "error reading from stderr: %s\n", readErr)

					break
				}
			}
		})

		// For backwards compatibility we still set the stdin if not explicitly disabled
		// If stdin is disabled, the process will use /dev/null as stdin
		if req.Stdin == nil || req.GetStdin() == true {
			stdin, err := cmd.StdinPipe()
			if err != nil {
				return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdin pipe for command '%s': %w", userCmd, err))
			}

			h.stdin = stdin
		}
	}

	// Once all pipe readers finish: close the multiplexer source (which
	// closes every fork) and signal Wait via outCtx.
	go func() {
		outWg.Wait()

		close(outMultiplex.Source)

		outCancel()
	}()

	return h, nil
}
|
||||
|
||||
func getProcType(req *rpc.StartRequest) cgroups.ProcessType {
|
||||
if req != nil && req.GetPty() != nil {
|
||||
return cgroups.ProcessTypePTY
|
||||
}
|
||||
|
||||
return cgroups.ProcessTypeUser
|
||||
}
|
||||
|
||||
func (p *Handler) SendSignal(signal syscall.Signal) error {
|
||||
if p.cmd.Process == nil {
|
||||
return fmt.Errorf("process not started")
|
||||
}
|
||||
|
||||
if signal == syscall.SIGKILL || signal == syscall.SIGTERM {
|
||||
p.outCancel()
|
||||
}
|
||||
|
||||
return p.cmd.Process.Signal(signal)
|
||||
}
|
||||
|
||||
func (p *Handler) ResizeTty(size *pty.Winsize) error {
|
||||
if p.tty == nil {
|
||||
return fmt.Errorf("tty not assigned to process")
|
||||
}
|
||||
|
||||
return pty.Setsize(p.tty, size)
|
||||
}
|
||||
|
||||
func (p *Handler) WriteStdin(data []byte) error {
|
||||
if p.tty != nil {
|
||||
return fmt.Errorf("tty assigned to process — input should be written to the pty, not the stdin")
|
||||
}
|
||||
|
||||
p.stdinMu.Lock()
|
||||
defer p.stdinMu.Unlock()
|
||||
|
||||
if p.stdin == nil {
|
||||
return fmt.Errorf("stdin not enabled or closed")
|
||||
}
|
||||
|
||||
_, err := p.stdin.Write(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing to stdin of process '%d': %w", p.cmd.Process.Pid, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseStdin closes the stdin pipe to signal EOF to the process.
|
||||
// Only works for non-PTY processes.
|
||||
func (p *Handler) CloseStdin() error {
|
||||
if p.tty != nil {
|
||||
return fmt.Errorf("cannot close stdin for PTY process — send Ctrl+D (0x04) instead")
|
||||
}
|
||||
|
||||
p.stdinMu.Lock()
|
||||
defer p.stdinMu.Unlock()
|
||||
|
||||
if p.stdin == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.stdin.Close()
|
||||
// We still set the stdin to nil even on error as there are no errors,
|
||||
// for which it is really safe to retry close across all distributions.
|
||||
p.stdin = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Handler) WriteTty(data []byte) error {
|
||||
if p.tty == nil {
|
||||
return fmt.Errorf("tty not assigned to process — input should be written to the stdin, not the tty")
|
||||
}
|
||||
|
||||
_, err := p.tty.Write(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing to tty of process '%d': %w", p.cmd.Process.Pid, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start launches the prepared command and returns its pid.
// PTY-backed processes were already started in New (pty.StartWithSize
// starts the command), so only non-PTY processes call cmd.Start here.
func (p *Handler) Start() (uint32, error) {
	// Pty is already started in the New method
	if p.tty == nil {
		err := p.cmd.Start()
		if err != nil {
			return 0, fmt.Errorf("error starting process '%s': %w", p.userCommand(), err)
		}
	}

	p.logger.
		Info().
		Str("event_type", "process_start").
		Int("pid", p.cmd.Process.Pid).
		Str("command", p.userCommand()).
		Msg(fmt.Sprintf("Process with pid %d started", p.cmd.Process.Pid))

	return uint32(p.cmd.Process.Pid), nil
}
|
||||
|
||||
func (p *Handler) Wait() {
|
||||
// Wait for the output pipes to be closed or cancelled.
|
||||
<-p.outCtx.Done()
|
||||
|
||||
err := p.cmd.Wait()
|
||||
|
||||
p.tty.Close()
|
||||
|
||||
var errMsg *string
|
||||
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
errMsg = &msg
|
||||
}
|
||||
|
||||
endEvent := &rpc.ProcessEvent_EndEvent{
|
||||
Error: errMsg,
|
||||
ExitCode: int32(p.cmd.ProcessState.ExitCode()),
|
||||
Exited: p.cmd.ProcessState.Exited(),
|
||||
Status: p.cmd.ProcessState.String(),
|
||||
}
|
||||
|
||||
event := rpc.ProcessEvent_End{
|
||||
End: endEvent,
|
||||
}
|
||||
|
||||
p.EndEvent.Source <- event
|
||||
|
||||
p.logger.
|
||||
Info().
|
||||
Str("event_type", "process_end").
|
||||
Interface("process_result", endEvent).
|
||||
Msg(fmt.Sprintf("Process with pid %d ended", p.cmd.Process.Pid))
|
||||
|
||||
// Ensure the process cancel is called to cleanup resources.
|
||||
// As it is called after end event and Wait, it should not affect command execution or returned events.
|
||||
p.cancel()
|
||||
}
|
||||
73
envd/internal/services/process/handler/multiplex.go
Normal file
73
envd/internal/services/process/handler/multiplex.go
Normal file
@ -0,0 +1,73 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// MultiplexedChannel broadcasts every value sent on Source to all forked
// consumer channels. Closing Source shuts the multiplexer down and closes
// every consumer. Sends to consumers are synchronous: a slow consumer
// blocks the broadcast for everyone.
type MultiplexedChannel[T any] struct {
	Source   chan T
	channels []chan T
	mu       sync.RWMutex
	exited   atomic.Bool
}

// NewMultiplexedChannel creates a multiplexer whose Source has the given
// buffer and starts its broadcasting goroutine. The goroutine exits when
// Source is closed.
func NewMultiplexedChannel[T any](buffer int) *MultiplexedChannel[T] {
	c := &MultiplexedChannel[T]{
		channels: nil,
		Source:   make(chan T, buffer),
	}

	go func() {
		for v := range c.Source {
			c.mu.RLock()

			for _, cons := range c.channels {
				cons <- v
			}

			c.mu.RUnlock()
		}

		// Shutdown. Take the write lock so no Fork can register a consumer
		// between marking exited and closing the existing consumers — the
		// unlocked version of this loop raced with Fork and could leave a
		// late consumer open forever (its receiver would block indefinitely).
		c.mu.Lock()

		c.exited.Store(true)

		for _, cons := range c.channels {
			close(cons)
		}

		c.channels = nil

		c.mu.Unlock()
	}()

	return c
}

// Fork registers a new unbuffered consumer channel and returns it with a
// cancel func that unregisters it. After the multiplexer has shut down,
// Fork returns an already-closed channel and a no-op cancel.
func (m *MultiplexedChannel[T]) Fork() (chan T, func()) {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Check under the lock: shutdown also runs under the lock, so this
	// cannot race with the final close loop.
	if m.exited.Load() {
		ch := make(chan T)
		close(ch)

		return ch, func() {}
	}

	consumer := make(chan T)

	m.channels = append(m.channels, consumer)

	return consumer, func() {
		m.remove(consumer)
	}
}

// remove unregisters a consumer channel previously returned by Fork.
// Removing an unknown (or already-removed) channel is a no-op.
func (m *MultiplexedChannel[T]) remove(consumer chan T) {
	m.mu.Lock()
	defer m.mu.Unlock()

	for i, ch := range m.channels {
		if ch == consumer {
			m.channels = append(m.channels[:i], m.channels[i+1:]...)

			return
		}
	}
}
|
||||
107
envd/internal/services/process/input.go
Normal file
107
envd/internal/services/process/input.go
Normal file
@ -0,0 +1,107 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
// handleInput routes one input chunk to the right sink on the process:
// pty bytes go to the terminal, stdin bytes to the stdin pipe. Unknown
// input types are rejected as unimplemented. Stdin writes are logged at
// debug level with the request's operation id.
func handleInput(ctx context.Context, process *handler.Handler, in *rpc.ProcessInput, logger *zerolog.Logger) error {
	switch in.GetInput().(type) {
	case *rpc.ProcessInput_Pty:
		err := process.WriteTty(in.GetPty())
		if err != nil {
			return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to tty: %w", err))
		}

	case *rpc.ProcessInput_Stdin:
		err := process.WriteStdin(in.GetStdin())
		if err != nil {
			return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to stdin: %w", err))
		}

		logger.Debug().
			Str("event_type", "stdin").
			Interface("stdin", in.GetStdin()).
			Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
			Msg("Streaming input to process")

	default:
		return connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", in.GetInput()))
	}

	return nil
}
|
||||
|
||||
func (s *Service) SendInput(ctx context.Context, req *connect.Request[rpc.SendInputRequest]) (*connect.Response[rpc.SendInputResponse], error) {
|
||||
proc, err := s.getProcess(req.Msg.GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = handleInput(ctx, proc, req.Msg.GetInput(), s.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.SendInputResponse{}), nil
|
||||
}
|
||||
|
||||
// StreamInput handles a client stream of input chunks for a single process.
// Request logging is handled by the LogClientStreamWithoutEvents wrapper.
func (s *Service) StreamInput(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
	return logs.LogClientStreamWithoutEvents(ctx, s.logger, stream, s.streamInputHandler)
}
|
||||
|
||||
func (s *Service) streamInputHandler(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
|
||||
var proc *handler.Handler
|
||||
|
||||
for stream.Receive() {
|
||||
req := stream.Msg()
|
||||
|
||||
switch req.GetEvent().(type) {
|
||||
case *rpc.StreamInputRequest_Start:
|
||||
p, err := s.getProcess(req.GetStart().GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proc = p
|
||||
case *rpc.StreamInputRequest_Data:
|
||||
err := handleInput(ctx, proc, req.GetData().GetInput(), s.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case *rpc.StreamInputRequest_Keepalive:
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid event type %T", req.GetEvent()))
|
||||
}
|
||||
}
|
||||
|
||||
err := stream.Err()
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error streaming input: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.StreamInputResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Service) CloseStdin(
|
||||
_ context.Context,
|
||||
req *connect.Request[rpc.CloseStdinRequest],
|
||||
) (*connect.Response[rpc.CloseStdinResponse], error) {
|
||||
handler, err := s.getProcess(req.Msg.GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := handler.CloseStdin(); err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error closing stdin: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.CloseStdinResponse{}), nil
|
||||
}
|
||||
28
envd/internal/services/process/list.go
Normal file
28
envd/internal/services/process/list.go
Normal file
@ -0,0 +1,28 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func (s *Service) List(context.Context, *connect.Request[rpc.ListRequest]) (*connect.Response[rpc.ListResponse], error) {
|
||||
processes := make([]*rpc.ProcessInfo, 0)
|
||||
|
||||
s.processes.Range(func(pid uint32, value *handler.Handler) bool {
|
||||
processes = append(processes, &rpc.ProcessInfo{
|
||||
Pid: pid,
|
||||
Tag: value.Tag,
|
||||
Config: value.Config,
|
||||
})
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return connect.NewResponse(&rpc.ListResponse{
|
||||
Processes: processes,
|
||||
}), nil
|
||||
}
|
||||
84
envd/internal/services/process/service.go
Normal file
84
envd/internal/services/process/service.go
Normal file
@ -0,0 +1,84 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process/processconnect"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
// Service implements the Connect process RPC service: it tracks running
// processes by pid and serves the start/connect/list/signal/input endpoints.
type Service struct {
	processes     *utils.Map[uint32, *handler.Handler] // pid -> live process handler
	logger        *zerolog.Logger
	defaults      *execcontext.Defaults // default user/workdir/env for spawned processes
	cgroupManager cgroups.Manager
}
|
||||
|
||||
func newService(l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
|
||||
return &Service{
|
||||
logger: l,
|
||||
processes: utils.NewMap[uint32, *handler.Handler](),
|
||||
defaults: defaults,
|
||||
cgroupManager: cgroupManager,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle mounts the Connect process service on the given chi router (with
// the unary logging interceptor) and returns the service so other
// components can use it directly.
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
	service := newService(l, defaults, cgroupManager)

	interceptors := connect.WithInterceptors(logs.NewUnaryLogInterceptor(l))

	path, h := spec.NewProcessHandler(service, interceptors)

	server.Mount(path, h)

	return service
}
|
||||
|
||||
func (s *Service) getProcess(selector *rpc.ProcessSelector) (*handler.Handler, error) {
|
||||
var proc *handler.Handler
|
||||
|
||||
switch selector.GetSelector().(type) {
|
||||
case *rpc.ProcessSelector_Pid:
|
||||
p, ok := s.processes.Load(selector.GetPid())
|
||||
if !ok {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with pid %d not found", selector.GetPid()))
|
||||
}
|
||||
|
||||
proc = p
|
||||
case *rpc.ProcessSelector_Tag:
|
||||
tag := selector.GetTag()
|
||||
|
||||
s.processes.Range(func(_ uint32, value *handler.Handler) bool {
|
||||
if value.Tag == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if *value.Tag == tag {
|
||||
proc = value
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
if proc == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with tag %s not found", tag))
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", selector))
|
||||
}
|
||||
|
||||
return proc, nil
|
||||
}
|
||||
38
envd/internal/services/process/signal.go
Normal file
38
envd/internal/services/process/signal.go
Normal file
@ -0,0 +1,38 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func (s *Service) SendSignal(
|
||||
_ context.Context,
|
||||
req *connect.Request[rpc.SendSignalRequest],
|
||||
) (*connect.Response[rpc.SendSignalResponse], error) {
|
||||
handler, err := s.getProcess(req.Msg.GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var signal syscall.Signal
|
||||
switch req.Msg.GetSignal() {
|
||||
case rpc.Signal_SIGNAL_SIGKILL:
|
||||
signal = syscall.SIGKILL
|
||||
case rpc.Signal_SIGNAL_SIGTERM:
|
||||
signal = syscall.SIGTERM
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid signal: %s", req.Msg.GetSignal()))
|
||||
}
|
||||
|
||||
err = handler.SendSignal(signal)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error sending signal: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.SendSignalResponse{}), nil
|
||||
}
|
||||
247
envd/internal/services/process/start.go
Normal file
247
envd/internal/services/process/start.go
Normal file
@ -0,0 +1,247 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func (s *Service) InitializeStartProcess(ctx context.Context, user *user.User, req *rpc.StartRequest) error {
|
||||
var err error
|
||||
|
||||
ctx = logs.AddRequestIDToContext(ctx)
|
||||
|
||||
defer s.logger.
|
||||
Err(err).
|
||||
Interface("request", req).
|
||||
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
|
||||
Msg("Initialized startCmd")
|
||||
|
||||
handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
|
||||
|
||||
startProcCtx, startProcCancel := context.WithCancel(ctx)
|
||||
proc, err := handler.New(startProcCtx, user, req, &handlerL, s.defaults, s.cgroupManager, startProcCancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pid, err := proc.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.processes.Store(pid, proc)
|
||||
|
||||
go func() {
|
||||
defer s.processes.Delete(pid)
|
||||
|
||||
proc.Wait()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start spawns a new process and streams its start/data/end events back to
// the caller. Request logging is handled by the LogServerStreamWithoutEvents wrapper.
func (s *Service) Start(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
	return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleStart)
}
|
||||
|
||||
// handleStart starts the requested process and streams its events to the
// client: exactly one start event (with the pid), then data events
// interleaved with keepalives, then exactly one end event.
//
// Two contexts are in play: the (cancelable-with-cause) request context ctx
// bounds the stream, while procCtx bounds the process itself so that the
// process survives request cancellation. The streaming work runs in a
// goroutine; the function returns when that goroutine finishes (exitChan)
// or the request context is done.
func (s *Service) handleStart(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
	// Cancel-with-cause lets the streaming goroutine record *why* the
	// stream ended; cancel(nil) on return releases the context either way.
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)

	handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()

	u, err := permissions.GetAuthUser(ctx, s.defaults.User)
	if err != nil {
		return err
	}

	timeout, err := determineTimeoutFromHeader(stream.Conn().RequestHeader())
	if err != nil {
		return connect.NewError(connect.CodeInvalidArgument, err)
	}

	// Create a new context with a timeout if provided.
	// We do not want the command to be killed if the request context is cancelled
	procCtx, cancelProc := context.Background(), func() {}
	if timeout > 0 { // zero timeout means no timeout
		procCtx, cancelProc = context.WithTimeout(procCtx, timeout)
	}

	proc, err := handler.New( //nolint:contextcheck // TODO: fix this later
		procCtx,
		u,
		req.Msg,
		&handlerL,
		s.defaults,
		s.cgroupManager,
		cancelProc,
	)
	if err != nil {
		// Ensure the process cancel is called to cleanup resources.
		cancelProc()

		return err
	}

	// Closed by the streaming goroutine when it is fully done.
	exitChan := make(chan struct{})

	// Multiplexer for the single start event; its Source is closed on
	// return so forks observe the shutdown.
	startMultiplexer := handler.NewMultiplexedChannel[rpc.ProcessEvent_Start](0)
	defer close(startMultiplexer.Source)

	start, startCancel := startMultiplexer.Fork()
	defer startCancel()

	// Fork the process's data and end event streams for this subscriber;
	// the cancels detach the forks on return.
	data, dataCancel := proc.DataEvent.Fork()
	defer dataCancel()

	end, endCancel := proc.EndEvent.Fork()
	defer endCancel()

	// Streaming goroutine: forwards start → data/keepalive → end to the
	// client in that order, recording any failure as the cancel cause.
	go func() {
		defer close(exitChan)

		// Phase 1: wait for and forward the start event (carries the pid).
		select {
		case <-ctx.Done():
			cancel(ctx.Err())

			return
		case event, ok := <-start:
			if !ok {
				cancel(connect.NewError(connect.CodeUnknown, errors.New("start event channel closed before sending start event")))

				return
			}

			streamErr := stream.Send(&rpc.StartResponse{
				Event: &rpc.ProcessEvent{
					Event: &event,
				},
			})
			if streamErr != nil {
				cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr)))

				return
			}
		}

		// Phase 2: forward data events, emitting keepalives when the
		// stream is otherwise idle. Ticker details come from the request
		// headers (presumably keepalive interval — see GetKeepAliveTicker).
		keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
		defer keepaliveTicker.Stop()

	dataLoop:
		for {
			select {
			case <-keepaliveTicker.C:
				streamErr := stream.Send(&rpc.StartResponse{
					Event: &rpc.ProcessEvent{
						Event: &rpc.ProcessEvent_Keepalive{
							Keepalive: &rpc.ProcessEvent_KeepAlive{},
						},
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))

					return
				}
			case <-ctx.Done():
				cancel(ctx.Err())

				return
			case event, ok := <-data:
				if !ok {
					// Data channel closed ⇒ process output finished;
					// move on to the end event.
					break dataLoop
				}

				streamErr := stream.Send(&rpc.StartResponse{
					Event: &rpc.ProcessEvent{
						Event: &event,
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))

					return
				}

				// Real traffic went out — push back the next keepalive.
				resetKeepalive()
			}
		}

		// Phase 3: forward the end event (exit status).
		select {
		case <-ctx.Done():
			cancel(ctx.Err())

			return
		case event, ok := <-end:
			if !ok {
				cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))

				return
			}

			streamErr := stream.Send(&rpc.StartResponse{
				Event: &rpc.ProcessEvent{
					Event: &event,
				},
			})
			if streamErr != nil {
				cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))

				return
			}
		}
	}()

	pid, err := proc.Start()
	if err != nil {
		return connect.NewError(connect.CodeInvalidArgument, err)
	}

	s.processes.Store(pid, proc)

	// Publish the start event to the streaming goroutine.
	// NOTE(review): this sends on the forked channel rather than
	// startMultiplexer.Source — confirm Fork returns a bidirectional
	// channel wired back into the multiplexer.
	start <- rpc.ProcessEvent_Start{
		Start: &rpc.ProcessEvent_StartEvent{
			Pid: pid,
		},
	}

	// Reap the process entry from the map once it exits.
	go func() {
		defer s.processes.Delete(pid)

		proc.Wait()
	}()

	// Block until the stream finishes or the request context is done.
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-exitChan:
		return nil
	}
}
|
||||
|
||||
func determineTimeoutFromHeader(header http.Header) (time.Duration, error) {
|
||||
timeoutHeader := header.Get("Connect-Timeout-Ms")
|
||||
|
||||
if timeoutHeader == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
timeout, err := strconv.Atoi(timeoutHeader)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return time.Duration(timeout) * time.Millisecond, nil
|
||||
}
|
||||
30
envd/internal/services/process/update.go
Normal file
30
envd/internal/services/process/update.go
Normal file
@ -0,0 +1,30 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/creack/pty"
|
||||
|
||||
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||
)
|
||||
|
||||
func (s *Service) Update(_ context.Context, req *connect.Request[rpc.UpdateRequest]) (*connect.Response[rpc.UpdateResponse], error) {
|
||||
proc, err := s.getProcess(req.Msg.GetProcess())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if req.Msg.GetPty() != nil {
|
||||
err := proc.ResizeTty(&pty.Winsize{
|
||||
Rows: uint16(req.Msg.GetPty().GetSize().GetRows()),
|
||||
Cols: uint16(req.Msg.GetPty().GetSize().GetCols()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error resizing tty: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return connect.NewResponse(&rpc.UpdateResponse{}), nil
|
||||
}
|
||||
Reference in New Issue
Block a user