481 lines
12 KiB
Go
481 lines
12 KiB
Go
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package handler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/exec"
|
|
"os/user"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
|
|
"connectrpc.com/connect"
|
|
"github.com/creack/pty"
|
|
"github.com/rs/zerolog"
|
|
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
|
)
|
|
|
|
// Tunables applied to every process spawned by this package.
const (
	// defaultNice is the nice value spawned commands are moved to.
	defaultNice = 0
	// defaultOomScore is the value written to the child's oom_score_adj.
	defaultOomScore = 100
	// outputBufferSize is the capacity of the output event and log channels.
	outputBufferSize = 64
	// stdChunkSize is the read buffer size for stdout/stderr: 32 KiB.
	stdChunkSize = 1 << 15
	// ptyChunkSize is the read buffer size for PTY output: 16 KiB.
	ptyChunkSize = 1 << 14
)
|
|
|
|
// ProcessExit describes how a process terminated.
type ProcessExit struct {
	// Error holds an error message, or nil when there is none.
	Error *string
	// Status is a textual description of the exit state.
	Status string
	// Exited tells whether the process exited.
	Exited bool
	// Code is the process exit code.
	Code int32
}
|
|
|
|
type Handler struct {
|
|
Config *rpc.ProcessConfig
|
|
|
|
logger *zerolog.Logger
|
|
|
|
Tag *string
|
|
cmd *exec.Cmd
|
|
tty *os.File
|
|
|
|
cancel context.CancelFunc
|
|
|
|
outCtx context.Context //nolint:containedctx // todo: refactor so this can be removed
|
|
outCancel context.CancelFunc
|
|
|
|
stdinMu sync.Mutex
|
|
stdin io.WriteCloser
|
|
|
|
DataEvent *MultiplexedChannel[rpc.ProcessEvent_Data]
|
|
EndEvent *MultiplexedChannel[rpc.ProcessEvent_End]
|
|
}
|
|
|
|
// This method must be called only after the process has been started
|
|
func (p *Handler) Pid() uint32 {
|
|
return uint32(p.cmd.Process.Pid)
|
|
}
|
|
|
|
// userCommand returns a human-readable representation of the user's original command,
|
|
// without the internal OOM/nice wrapper that is prepended to the actual exec.
|
|
func (p *Handler) userCommand() string {
|
|
return strings.Join(append([]string{p.Config.GetCmd()}, p.Config.GetArgs()...), " ")
|
|
}
|
|
|
|
// currentNice reports the nice value of the calling process.
// When the priority cannot be queried it falls back to 0.
func currentNice() int {
	raw, err := syscall.Getpriority(syscall.PRIO_PROCESS, 0)
	if err == nil {
		// On Linux the raw syscall result is 20 - nice, so invert it.
		return 20 - raw
	}

	return 0
}
|
|
|
|
func New(
|
|
ctx context.Context,
|
|
user *user.User,
|
|
req *rpc.StartRequest,
|
|
logger *zerolog.Logger,
|
|
defaults *execcontext.Defaults,
|
|
cgroupManager cgroups.Manager,
|
|
cancel context.CancelFunc,
|
|
) (*Handler, error) {
|
|
// User command string for logging (without the internal wrapper details).
|
|
userCmd := strings.Join(append([]string{req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...), " ")
|
|
|
|
// Wrap the command in a shell that sets the OOM score and nice value before exec-ing the actual command.
|
|
// This eliminates the race window where grandchildren could inherit the parent's protected OOM score (-1000)
|
|
// or high CPU priority (nice -20) before the post-start calls had a chance to correct them.
|
|
// nice(1) applies a relative adjustment, so we compute the delta from the current (inherited) nice to the target.
|
|
niceDelta := defaultNice - currentNice()
|
|
oomWrapperScript := fmt.Sprintf(`echo %d > /proc/$$/oom_score_adj && exec /usr/bin/nice -n %d "${@}"`, defaultOomScore, niceDelta)
|
|
wrapperArgs := append([]string{"-c", oomWrapperScript, "--", req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...)
|
|
cmd := exec.CommandContext(ctx, "/bin/sh", wrapperArgs...)
|
|
|
|
uid, gid, err := permissions.GetUserIdUints(user)
|
|
if err != nil {
|
|
return nil, connect.NewError(connect.CodeInternal, err)
|
|
}
|
|
|
|
groups := []uint32{gid}
|
|
if gids, err := user.GroupIds(); err != nil {
|
|
logger.Warn().Err(err).Str("user", user.Username).Msg("failed to get supplementary groups")
|
|
} else {
|
|
for _, g := range gids {
|
|
if parsed, err := strconv.ParseUint(g, 10, 32); err == nil {
|
|
groups = append(groups, uint32(parsed))
|
|
}
|
|
}
|
|
}
|
|
|
|
cgroupFD, ok := cgroupManager.GetFileDescriptor(getProcType(req))
|
|
|
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
|
UseCgroupFD: ok,
|
|
CgroupFD: cgroupFD,
|
|
Credential: &syscall.Credential{
|
|
Uid: uid,
|
|
Gid: gid,
|
|
Groups: groups,
|
|
},
|
|
}
|
|
|
|
resolvedPath, err := permissions.ExpandAndResolve(req.GetProcess().GetCwd(), user, defaults.Workdir)
|
|
if err != nil {
|
|
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
|
}
|
|
|
|
// Check if the cwd resolved path exists
|
|
if _, err := os.Stat(resolvedPath); errors.Is(err, os.ErrNotExist) {
|
|
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cwd '%s' does not exist", resolvedPath))
|
|
}
|
|
|
|
cmd.Dir = resolvedPath
|
|
|
|
var formattedVars []string
|
|
|
|
// Take only 'PATH' variable from the current environment
|
|
// The 'PATH' should ideally be set in the environment
|
|
formattedVars = append(formattedVars, "PATH="+os.Getenv("PATH"))
|
|
formattedVars = append(formattedVars, "HOME="+user.HomeDir)
|
|
formattedVars = append(formattedVars, "USER="+user.Username)
|
|
formattedVars = append(formattedVars, "LOGNAME="+user.Username)
|
|
|
|
// Add the environment variables from the global environment
|
|
if defaults.EnvVars != nil {
|
|
defaults.EnvVars.Range(func(key string, value string) bool {
|
|
formattedVars = append(formattedVars, key+"="+value)
|
|
|
|
return true
|
|
})
|
|
}
|
|
|
|
// Only the last values of the env vars are used - this allows for overwriting defaults
|
|
for key, value := range req.GetProcess().GetEnvs() {
|
|
formattedVars = append(formattedVars, key+"="+value)
|
|
}
|
|
|
|
cmd.Env = formattedVars
|
|
|
|
outMultiplex := NewMultiplexedChannel[rpc.ProcessEvent_Data](outputBufferSize)
|
|
|
|
var outWg sync.WaitGroup
|
|
|
|
// Create a context for waiting for and cancelling output pipes.
|
|
// Cancellation of the process via timeout will propagate and cancel this context too.
|
|
outCtx, outCancel := context.WithCancel(ctx)
|
|
|
|
h := &Handler{
|
|
Config: req.GetProcess(),
|
|
cmd: cmd,
|
|
Tag: req.Tag,
|
|
DataEvent: outMultiplex,
|
|
cancel: cancel,
|
|
outCtx: outCtx,
|
|
outCancel: outCancel,
|
|
EndEvent: NewMultiplexedChannel[rpc.ProcessEvent_End](0),
|
|
logger: logger,
|
|
}
|
|
|
|
if req.GetPty() != nil {
|
|
// The pty should ideally start only in the Start method, but the package does not support that and we would have to code it manually.
|
|
// The output of the pty should correctly be passed though.
|
|
tty, err := pty.StartWithSize(cmd, &pty.Winsize{
|
|
Cols: uint16(req.GetPty().GetSize().GetCols()),
|
|
Rows: uint16(req.GetPty().GetSize().GetRows()),
|
|
})
|
|
if err != nil {
|
|
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error starting pty with command '%s' in dir '%s' with '%d' cols and '%d' rows: %w", userCmd, cmd.Dir, req.GetPty().GetSize().GetCols(), req.GetPty().GetSize().GetRows(), err))
|
|
}
|
|
|
|
outWg.Go(func() {
|
|
for {
|
|
buf := make([]byte, ptyChunkSize)
|
|
|
|
n, readErr := tty.Read(buf)
|
|
|
|
if n > 0 {
|
|
outMultiplex.Source <- rpc.ProcessEvent_Data{
|
|
Data: &rpc.ProcessEvent_DataEvent{
|
|
Output: &rpc.ProcessEvent_DataEvent_Pty{
|
|
Pty: buf[:n],
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
if errors.Is(readErr, io.EOF) {
|
|
break
|
|
}
|
|
|
|
if readErr != nil {
|
|
fmt.Fprintf(os.Stderr, "error reading from pty: %s\n", readErr)
|
|
|
|
break
|
|
}
|
|
}
|
|
})
|
|
|
|
h.tty = tty
|
|
} else {
|
|
stdout, err := cmd.StdoutPipe()
|
|
if err != nil {
|
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdout pipe for command '%s': %w", userCmd, err))
|
|
}
|
|
|
|
outWg.Go(func() {
|
|
stdoutLogs := make(chan []byte, outputBufferSize)
|
|
defer close(stdoutLogs)
|
|
|
|
stdoutLogger := logger.With().Str("event_type", "stdout").Logger()
|
|
|
|
go logs.LogBufferedDataEvents(stdoutLogs, &stdoutLogger, "data")
|
|
|
|
for {
|
|
buf := make([]byte, stdChunkSize)
|
|
|
|
n, readErr := stdout.Read(buf)
|
|
|
|
if n > 0 {
|
|
outMultiplex.Source <- rpc.ProcessEvent_Data{
|
|
Data: &rpc.ProcessEvent_DataEvent{
|
|
Output: &rpc.ProcessEvent_DataEvent_Stdout{
|
|
Stdout: buf[:n],
|
|
},
|
|
},
|
|
}
|
|
|
|
stdoutLogs <- buf[:n]
|
|
}
|
|
|
|
if errors.Is(readErr, io.EOF) {
|
|
break
|
|
}
|
|
|
|
if readErr != nil {
|
|
fmt.Fprintf(os.Stderr, "error reading from stdout: %s\n", readErr)
|
|
|
|
break
|
|
}
|
|
}
|
|
})
|
|
|
|
stderr, err := cmd.StderrPipe()
|
|
if err != nil {
|
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stderr pipe for command '%s': %w", userCmd, err))
|
|
}
|
|
|
|
outWg.Go(func() {
|
|
stderrLogs := make(chan []byte, outputBufferSize)
|
|
defer close(stderrLogs)
|
|
|
|
stderrLogger := logger.With().Str("event_type", "stderr").Logger()
|
|
|
|
go logs.LogBufferedDataEvents(stderrLogs, &stderrLogger, "data")
|
|
|
|
for {
|
|
buf := make([]byte, stdChunkSize)
|
|
|
|
n, readErr := stderr.Read(buf)
|
|
|
|
if n > 0 {
|
|
outMultiplex.Source <- rpc.ProcessEvent_Data{
|
|
Data: &rpc.ProcessEvent_DataEvent{
|
|
Output: &rpc.ProcessEvent_DataEvent_Stderr{
|
|
Stderr: buf[:n],
|
|
},
|
|
},
|
|
}
|
|
|
|
stderrLogs <- buf[:n]
|
|
}
|
|
|
|
if errors.Is(readErr, io.EOF) {
|
|
break
|
|
}
|
|
|
|
if readErr != nil {
|
|
fmt.Fprintf(os.Stderr, "error reading from stderr: %s\n", readErr)
|
|
|
|
break
|
|
}
|
|
}
|
|
})
|
|
|
|
// For backwards compatibility we still set the stdin if not explicitly disabled
|
|
// If stdin is disabled, the process will use /dev/null as stdin
|
|
if req.Stdin == nil || req.GetStdin() == true {
|
|
stdin, err := cmd.StdinPipe()
|
|
if err != nil {
|
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdin pipe for command '%s': %w", userCmd, err))
|
|
}
|
|
|
|
h.stdin = stdin
|
|
}
|
|
}
|
|
|
|
go func() {
|
|
outWg.Wait()
|
|
|
|
close(outMultiplex.Source)
|
|
|
|
outCancel()
|
|
}()
|
|
|
|
return h, nil
|
|
}
|
|
|
|
func getProcType(req *rpc.StartRequest) cgroups.ProcessType {
|
|
if req != nil && req.GetPty() != nil {
|
|
return cgroups.ProcessTypePTY
|
|
}
|
|
|
|
return cgroups.ProcessTypeUser
|
|
}
|
|
|
|
func (p *Handler) SendSignal(signal syscall.Signal) error {
|
|
if p.cmd.Process == nil {
|
|
return fmt.Errorf("process not started")
|
|
}
|
|
|
|
if signal == syscall.SIGKILL || signal == syscall.SIGTERM {
|
|
p.outCancel()
|
|
}
|
|
|
|
return p.cmd.Process.Signal(signal)
|
|
}
|
|
|
|
func (p *Handler) ResizeTty(size *pty.Winsize) error {
|
|
if p.tty == nil {
|
|
return fmt.Errorf("tty not assigned to process")
|
|
}
|
|
|
|
return pty.Setsize(p.tty, size)
|
|
}
|
|
|
|
func (p *Handler) WriteStdin(data []byte) error {
|
|
if p.tty != nil {
|
|
return fmt.Errorf("tty assigned to process — input should be written to the pty, not the stdin")
|
|
}
|
|
|
|
p.stdinMu.Lock()
|
|
defer p.stdinMu.Unlock()
|
|
|
|
if p.stdin == nil {
|
|
return fmt.Errorf("stdin not enabled or closed")
|
|
}
|
|
|
|
_, err := p.stdin.Write(data)
|
|
if err != nil {
|
|
return fmt.Errorf("error writing to stdin of process '%d': %w", p.cmd.Process.Pid, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CloseStdin closes the stdin pipe to signal EOF to the process.
|
|
// Only works for non-PTY processes.
|
|
func (p *Handler) CloseStdin() error {
|
|
if p.tty != nil {
|
|
return fmt.Errorf("cannot close stdin for PTY process — send Ctrl+D (0x04) instead")
|
|
}
|
|
|
|
p.stdinMu.Lock()
|
|
defer p.stdinMu.Unlock()
|
|
|
|
if p.stdin == nil {
|
|
return nil
|
|
}
|
|
|
|
err := p.stdin.Close()
|
|
// We still set the stdin to nil even on error as there are no errors,
|
|
// for which it is really safe to retry close across all distributions.
|
|
p.stdin = nil
|
|
|
|
return err
|
|
}
|
|
|
|
func (p *Handler) WriteTty(data []byte) error {
|
|
if p.tty == nil {
|
|
return fmt.Errorf("tty not assigned to process — input should be written to the stdin, not the tty")
|
|
}
|
|
|
|
_, err := p.tty.Write(data)
|
|
if err != nil {
|
|
return fmt.Errorf("error writing to tty of process '%d': %w", p.cmd.Process.Pid, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Handler) Start() (uint32, error) {
|
|
// Pty is already started in the New method
|
|
if p.tty == nil {
|
|
err := p.cmd.Start()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("error starting process '%s': %w", p.userCommand(), err)
|
|
}
|
|
}
|
|
|
|
p.logger.
|
|
Info().
|
|
Str("event_type", "process_start").
|
|
Int("pid", p.cmd.Process.Pid).
|
|
Str("command", p.userCommand()).
|
|
Msg(fmt.Sprintf("Process with pid %d started", p.cmd.Process.Pid))
|
|
|
|
return uint32(p.cmd.Process.Pid), nil
|
|
}
|
|
|
|
func (p *Handler) Wait() {
|
|
// Wait for the output pipes to be closed or cancelled.
|
|
<-p.outCtx.Done()
|
|
|
|
err := p.cmd.Wait()
|
|
|
|
p.tty.Close()
|
|
|
|
var errMsg *string
|
|
|
|
if err != nil {
|
|
msg := err.Error()
|
|
errMsg = &msg
|
|
}
|
|
|
|
endEvent := &rpc.ProcessEvent_EndEvent{
|
|
Error: errMsg,
|
|
ExitCode: int32(p.cmd.ProcessState.ExitCode()),
|
|
Exited: p.cmd.ProcessState.Exited(),
|
|
Status: p.cmd.ProcessState.String(),
|
|
}
|
|
|
|
event := rpc.ProcessEvent_End{
|
|
End: endEvent,
|
|
}
|
|
|
|
p.EndEvent.Source <- event
|
|
|
|
p.logger.
|
|
Info().
|
|
Str("event_type", "process_end").
|
|
Interface("process_result", endEvent).
|
|
Msg(fmt.Sprintf("Process with pid %d ended", p.cmd.Process.Pid))
|
|
|
|
// Ensure the process cancel is called to cleanup resources.
|
|
// As it is called after end event and Wait, it should not affect command execution or returned events.
|
|
p.cancel()
|
|
}
|