Skip to content

Commit

Permalink
Run commands natively in service when interactive=false
Browse files Browse the repository at this point in the history
  • Loading branch information
adambabik committed Feb 2, 2024
1 parent 5782210 commit b6a3222
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 44 deletions.
1 change: 1 addition & 0 deletions internal/api/runme/runner/v2alpha1/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ message ProgramConfig {
}

// interactive, if true, uses a pseudo-tty to execute the program.
// Otherwise, the program is executed using in-memory buffers for I/O.
bool interactive = 7;

// TODO(adamb): understand motivation for this. In theory, source
Expand Down
23 changes: 19 additions & 4 deletions internal/command/command_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"go.uber.org/zap"
)

// signalToProcessGroup is used in tests to disable sending signals to a process group.
var signalToProcessGroup = true
// SignalToProcessGroup is used in tests to disable sending signals to a process group.
var SignalToProcessGroup = true

type NativeCommand struct {
cfg *Config
Expand All @@ -33,6 +33,21 @@ func newNativeCommand(cfg *Config, opts *NativeCommandOptions) *NativeCommand {
}
}

func (c *NativeCommand) Running() bool {
return c.cmd != nil && c.cmd.ProcessState == nil
}

func (c *NativeCommand) Pid() int {
if c.cmd == nil || c.cmd.Process == nil {
return 0
}
return c.cmd.Process.Pid
}

func (c *NativeCommand) SetWinsize(rows, cols, x, y uint16) error {
return errors.New("unsupported")
}

func (c *NativeCommand) Start(ctx context.Context) (err error) {
argsNormalizer := &argsNormalizer{
session: c.opts.Session,
Expand All @@ -54,7 +69,7 @@ func (c *NativeCommand) Start(ctx context.Context) (err error) {

if f, ok := stdin.(*os.File); ok && f != nil {
// Duplicate /dev/stdin.
newStdinFd, err := syscall.Dup(int(f.Fd()))
newStdinFd, err := dup(int(f.Fd()))
if err != nil {
return errors.Wrap(err, "failed to dup stdin")
}
Expand Down Expand Up @@ -102,7 +117,7 @@ func (c *NativeCommand) Start(ctx context.Context) (err error) {
}

func (c *NativeCommand) StopWithSignal(sig os.Signal) error {
if signalToProcessGroup {
if SignalToProcessGroup {
// Try to terminate the whole process group. If it fails, fall back to stdlib methods.
err := signalPgid(c.cmd.Process.Pid, sig)
if err == nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/command/command_native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
func init() {
// Set to false to disable sending signals to process groups in tests.
// This can be turned on if setSysProcAttrPgid() is called in Start().
signalToProcessGroup = false
SignalToProcessGroup = false
}

func TestNativeCommand(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions internal/command/command_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ func disableEcho(fd uintptr) error {
return errors.Wrap(err, "failed to set tty attr")
}

func dup(fd int) (int, error) {
return syscall.Dup(fd)
}

// func setSysProcAttrPgid(cmd *exec.Cmd) {
// if cmd.SysProcAttr == nil {
// cmd.SysProcAttr = &syscall.SysProcAttr{}
Expand Down
4 changes: 2 additions & 2 deletions internal/command/command_virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ func newVirtualCommand(cfg *Config, opts *VirtualCommandOptions) *VirtualCommand
}
}

func (c *VirtualCommand) IsRunning() bool {
func (c *VirtualCommand) Running() bool {
return c.cmd != nil && c.cmd.ProcessState == nil
}

func (c *VirtualCommand) PID() int {
func (c *VirtualCommand) Pid() int {
if c.cmd == nil || c.cmd.Process == nil {
return 0
}
Expand Down
5 changes: 3 additions & 2 deletions internal/command/command_virtual_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ func TestVirtualCommand(t *testing.T) {
}, nil)
require.NoError(t, err)
require.NoError(t, cmd.Start(context.Background()))
require.True(t, cmd.IsRunning())
require.Greater(t, cmd.PID(), 1)

require.True(t, cmd.Running())
require.Greater(t, cmd.Pid(), 1)
require.NoError(t, cmd.Wait())
})

Expand Down
27 changes: 27 additions & 0 deletions internal/command/command_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//go:build windows

package runner

import (
"os"
"os/exec"

"github.com/pkg/errors"
)

func setSysProcAttrCtty(cmd *exec.Cmd) {}

func setSysProcAttrPgid(cmd *exec.Cmd) {}

func dup(fd int) (int, error) {
return fd, nil
}

func disableEcho(fd uintptr) error {
return errors.New("Error: Environment not supported! " +
"Runme currently doesn't support PowerShell. " +
"Please go to https://github.com/stateful/runme/issues/173 to follow progress on this " +
"and join our Discord server at https://discord.gg/runme if you have further questions!")
}

func signalPgid(pid int, sig os.Signal) error { return errors.New("unsupported") }
103 changes: 76 additions & 27 deletions internal/runnerv2service/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,59 @@ const (
msgBufferSize = 2048 << 10 // 2 MiB
)

type commandIface interface {
Pid() int
Running() bool
SetWinsize(uint16, uint16, uint16, uint16) error
Start(context.Context) error
StopWithSignal(os.Signal) error
Wait() error
}

type execution struct {
ID string

Cmd *command.VirtualCommand
Cmd commandIface

stdin io.Reader
stdinWriter io.WriteCloser
stdout *rbuffer.RingBuffer
stderr *rbuffer.RingBuffer

logger *zap.Logger
}

func newExecution(id string, cfg *command.Config, logger *zap.Logger) (*execution, error) {
stdin, stdinWriter := io.Pipe()
stdout := rbuffer.NewRingBuffer(ringBufferSize)
stderr := rbuffer.NewRingBuffer(ringBufferSize)

var (
stdin io.Reader
stdinWriter io.WriteCloser
cmd commandIface
err error
)

if cfg.Interactive {
stdin, stdinWriter = io.Pipe()
cmd, err = command.NewVirtual(
cfg,
&command.VirtualCommandOptions{
Stdin: stdin,
Stdout: stdout,
Logger: logger,
},
)
} else {
cmd, err = command.NewNative(
cfg,
&command.NativeCommandOptions{
Stdin: stdin,
Stdout: stdout,
Stderr: stderr,
Logger: logger,
},
)
}

stdout := rbuffer.NewRingBuffer(ringBufferSize)

cmd, err := command.NewVirtual(
cfg,
&command.VirtualCommandOptions{
Stdin: stdin,
Stdout: stdout,
Logger: logger,
},
)
if err != nil {
return nil, err
}
Expand All @@ -72,6 +93,7 @@ func newExecution(id string, cfg *command.Config, logger *zap.Logger) (*executio
stdin: stdin,
stdinWriter: stdinWriter,
stdout: stdout,
stderr: stderr,

logger: logger,
}
Expand All @@ -84,9 +106,13 @@ func (e *execution) Start(ctx context.Context) error {
}

func (e *execution) Wait(ctx context.Context, sender sender) (int, error) {
errc := make(chan error, 1)
errc := make(chan error, 2)

go func() {
errc <- readSendLoop(e.stdout, sender)
errc <- readSendLoop(e.stdout, sender, func(b []byte) *runnerv2alpha1.ExecuteResponse { return &runnerv2alpha1.ExecuteResponse{StdoutData: b} })
}()
go func() {
errc <- readSendLoop(e.stderr, sender, func(b []byte) *runnerv2alpha1.ExecuteResponse { return &runnerv2alpha1.ExecuteResponse{StderrData: b} })
}()

waitErr := e.Cmd.Wait()
Expand All @@ -96,9 +122,17 @@ func (e *execution) Wait(ctx context.Context, sender sender) (int, error) {

// If waitErr is not nil, only log the errors but return waitErr.
if waitErr != nil {
handlerErrors := 0

readSendHandlerForWaitErr:
select {
case err := <-errc:
handlerErrors++
e.logger.Info("readSendLoop finished; ignoring any errors because there was a wait error", zap.Error(err))
// Wait for both errors, or nils.
if handlerErrors < 2 {
goto readSendHandlerForWaitErr
}
case <-ctx.Done():
e.logger.Info("context canceled while waiting for the readSendLoop finish; ignoring any errors because there was a wait error")
}
Expand All @@ -108,8 +142,14 @@ func (e *execution) Wait(ctx context.Context, sender sender) (int, error) {
// If waitErr is nil, wait for the readSendLoop to finish,
// or the context being canceled.
select {
case err := <-errc:
return exitCode, err
case err1 := <-errc:
// Wait for both errors, or nils.
select {
case err2 := <-errc:
e.logger.Info("another error from readSendLoop; won't be returned", zap.Error(err2))
case <-ctx.Done():
}
return exitCode, err1
case <-ctx.Done():
return exitCode, ctx.Err()
}
Expand All @@ -124,23 +164,32 @@ func (e *execution) SetWinsize(size *runnerv2alpha1.Winsize) error {
return e.Cmd.SetWinsize(uint16(size.Cols), uint16(size.Rows), uint16(size.X), uint16(size.Y))
}

func (e *execution) closeIO() {
var err error

if e.stdinWriter != nil {
err = e.stdinWriter.Close()
e.logger.Debug("closed stdin writer", zap.Error(err))
func (e *execution) PostInitialRequest() {
// Close stdin writer for native commands after handling the initial request.
// Native commands do not support sending data continously.
if _, ok := e.Cmd.(*command.NativeCommand); ok {
if err := e.stdinWriter.Close(); err != nil {
e.logger.Info("failed to close stdin writer", zap.Error(err))
}
}
}

func (e *execution) closeIO() {
err := e.stdinWriter.Close()
e.logger.Info("closed stdin writer", zap.Error(err))

err = e.stdout.Close()
e.logger.Debug("closed stdout writer", zap.Error(err))
e.logger.Info("closed stdout writer", zap.Error(err))

err = e.stderr.Close()
e.logger.Info("closed stderr writer", zap.Error(err))
}

type sender interface {
Send(*runnerv2alpha1.ExecuteResponse) error
}

func readSendLoop(reader io.Reader, sender sender) error {
func readSendLoop(reader io.Reader, sender sender, fn func([]byte) *runnerv2alpha1.ExecuteResponse) error {
limitedReader := io.LimitReader(reader, msgBufferSize)

for {
Expand All @@ -156,7 +205,7 @@ func readSendLoop(reader io.Reader, sender sender) error {
continue
}

err = sender.Send(&runnerv2alpha1.ExecuteResponse{StdoutData: buf[:n]})
err = sender.Send(fn(buf[:n]))
if err != nil {
return errors.WithStack(err)
}
Expand Down
10 changes: 6 additions & 4 deletions internal/runnerv2service/service_execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (r *runnerService) Execute(srv runnerv2alpha1.RunnerService_ExecuteServer)
}
if err := srv.Send(&runnerv2alpha1.ExecuteResponse{
Pid: &runnerv2alpha1.ProcessPID{
Pid: int64(exec.Cmd.PID()),
Pid: int64(exec.Cmd.Pid()),
},
}); err != nil {
return err
Expand All @@ -57,8 +57,8 @@ func (r *runnerService) Execute(srv runnerv2alpha1.RunnerService_ExecuteServer)
for {
var err error

if l := len(req.InputData); l > 0 {
logger.Info("received input data", zap.Int("len", l))
if req.InputData != nil {
logger.Info("received input data", zap.Int("len", len(req.InputData)))
_, err = exec.Write(req.InputData)
if err != nil {
logger.Info("failed to write to stdin; ignoring", zap.Error(err))
Expand Down Expand Up @@ -87,6 +87,8 @@ func (r *runnerService) Execute(srv runnerv2alpha1.RunnerService_ExecuteServer)
return
}

exec.PostInitialRequest()

req, err = srv.Recv()
logger.Info("received request", zap.Any("req", req), zap.Error(err))
switch {
Expand All @@ -99,7 +101,7 @@ func (r *runnerService) Execute(srv runnerv2alpha1.RunnerService_ExecuteServer)
}
return
case status.Convert(err).Code() == codes.Canceled || status.Convert(err).Code() == codes.DeadlineExceeded:
if !exec.Cmd.IsRunning() {
if !exec.Cmd.Running() {
logger.Info("stream canceled after the process finished; ignoring")
} else {
logger.Info("stream canceled while the process is still running; program will be stopped if non-background")
Expand Down
Loading

0 comments on commit b6a3222

Please sign in to comment.