Skip to content

Commit

Permalink
Use trap to collect envs; better error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
adambabik committed Feb 2, 2024
1 parent 0e89d1f commit b3883b4
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 61 deletions.
40 changes: 21 additions & 19 deletions internal/command/command_exec_normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (n *argsNormalizer) Normalize(cfg *Config) (*Config, error) {

if isShellLanguage(filepath.Base(cfg.ProgramName)) {
if n.session != nil {
_, _ = buf.WriteString(fmt.Sprintf("%s > %s\n", EnvDumpCommand, filepath.Join(n.tempDir, envEndFileName)))
_, _ = buf.WriteString(fmt.Sprintf("trap \"%s > %s\" EXIT\n", EnvDumpCommand, filepath.Join(n.tempDir, envEndFileName)))

n.isEnvCollectable = true
}
Expand Down Expand Up @@ -107,57 +107,59 @@ func (n *argsNormalizer) Normalize(cfg *Config) (*Config, error) {
return result, nil
}

func (n *argsNormalizer) Cleanup() {
func (n *argsNormalizer) Cleanup() error {
if n.tempDir == "" {
return
return nil
}

n.logger.Info("cleaning up the temporary dir")

if err := os.RemoveAll(n.tempDir); err != nil {
n.logger.Info("failed to remove temporary dir", zap.Error(err))
return errors.WithMessage(err, "failed to remove the temporary dir")
}

return nil
}

func (n *argsNormalizer) CollectEnv() {
func (n *argsNormalizer) CollectEnv() error {
if n.session == nil || !n.isEnvCollectable {
return
return nil
}

n.logger.Info("collecting env")

startEnv, err := n.readEnvFromFile(envStartFileName)
if err != nil {
n.logger.Info("failed to read the start env file", zap.Error(err))
return
return err
}

endEnv, err := n.readEnvFromFile(envEndFileName)
if err != nil {
n.logger.Info("failed to read the end env file", zap.Error(err))
return
return err
}

// Below, we diff the env collected before and after the script execution.
// Then, update the session with the new or updated env and delete the deleted env.

startEnvStore := newEnvStore()
if _, err := startEnvStore.Merge(startEnv...); err != nil {
n.logger.Info("failed to create the start env store", zap.Error(err))
return
return errors.WithMessage(err, "failed to create the start env store")
}

endEnvStore := newEnvStore()
if _, err := endEnvStore.Merge(endEnv...); err != nil {
n.logger.Info("failed to create the end env store", zap.Error(err))
return
return errors.WithMessage(err, "failed to create the end env store")
}

newOrUpdated, _, deleted := diffEnvStores(startEnvStore, endEnvStore)

if err := n.session.SetEnv(newOrUpdated...); err != nil {
n.logger.Info("failed to set the new or updated env", zap.Error(err))
return
return errors.WithMessage(err, "failed to set the new or updated env")
}

n.session.DeleteEnv(deleted...)

return nil
}

func (n *argsNormalizer) createTempDir() (err error) {
Expand All @@ -182,7 +184,7 @@ func (n *argsNormalizer) writeScript(script []byte) error {
func (n *argsNormalizer) readEnvFromFile(name string) (result []string, _ error) {
f, err := os.Open(filepath.Join(n.tempDir, name))
if err != nil {
return nil, errors.WithStack(err)
return nil, errors.WithMessagef(err, "failed to open the env file %q", name)
}
defer func() { _ = f.Close() }()

Expand All @@ -194,10 +196,10 @@ func (n *argsNormalizer) readEnvFromFile(name string) (result []string, _ error)
}

if err := scanner.Err(); err != nil {
return nil, errors.WithStack(err)
return nil, errors.WithMessagef(err, "failed to scan the env file %q", name)
}

return result, errors.WithStack(scanner.Err())
return result, nil
}

func splitNull(data []byte, atEOF bool) (advance int, token []byte, err error) {
Expand Down
34 changes: 23 additions & 11 deletions internal/command/command_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os/exec"

"github.com/pkg/errors"
"go.uber.org/multierr"
"go.uber.org/zap"
)

Expand All @@ -19,7 +20,7 @@ type NativeCommand struct {
// cmd is populated when the command is started.
cmd *exec.Cmd

cleanFuncs []func()
cleanFuncs []func() error

logger *zap.Logger
}
Expand Down Expand Up @@ -104,18 +105,20 @@ func (c *NativeCommand) Start(ctx context.Context) (err error) {
// like "python", hence, it's commented out.
// setSysProcAttrPgid(c.cmd)

c.logger.Info("starting a local command", zap.Any("config", redactConfig(cfg)))
c.logger.Info("starting a native command", zap.Any("config", redactConfig(cfg)))

if err := c.cmd.Start(); err != nil {
return errors.WithStack(err)
}

c.logger.Info("a local command started")
c.logger.Info("a native command started")

return nil
}

func (c *NativeCommand) StopWithSignal(sig os.Signal) error {
c.logger.Info("stopping the native command with a signal", zap.Stringer("signal", sig))

if SignalToProcessGroup {
// Try to terminate the whole process group. If it fails, fall back to stdlib methods.
err := signalPgid(c.cmd.Process.Pid, sig)
Expand All @@ -133,28 +136,37 @@ func (c *NativeCommand) StopWithSignal(sig os.Signal) error {
return nil
}

func (c *NativeCommand) Wait() error {
c.logger.Info("waiting for the local command to finish")
func (c *NativeCommand) Wait() (err error) {
c.logger.Info("waiting for the native command to finish")

defer c.cleanup()
defer func() {
errC := errors.WithStack(c.cleanup())
c.logger.Info("cleaned up the native command", zap.Error(errC))
if err == nil && errC != nil {
err = errC
}
}()

var stderr []byte

err := c.cmd.Wait()
err = errors.WithStack(c.cmd.Wait())
if err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
stderr = exitErr.Stderr
}
}

c.logger.Info("the local command finished", zap.Error(err), zap.ByteString("stderr", stderr))
c.logger.Info("the native command finished", zap.Error(err), zap.ByteString("stderr", stderr))

return errors.WithStack(err)
return
}

func (c *NativeCommand) cleanup() {
func (c *NativeCommand) cleanup() (err error) {
for _, fn := range c.cleanFuncs {
fn()
if errFn := fn(); errFn != nil {
err = multierr.Append(err, errFn)
}
}
return
}
60 changes: 35 additions & 25 deletions internal/command/command_virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/creack/pty"
"github.com/pkg/errors"
"go.uber.org/multierr"
"go.uber.org/zap"
)

Expand All @@ -24,14 +25,14 @@ type VirtualCommand struct {
// stdin is Opts.Stdin wrapped in readCloser.
stdin io.ReadCloser

cleanFuncs []func()
cleanFuncs []func() error

pty *os.File
tty *os.File

wg sync.WaitGroup

mx sync.Mutex
mu sync.Mutex
err error

logger *zap.Logger
Expand Down Expand Up @@ -207,64 +208,73 @@ func (c *VirtualCommand) StopWithSignal(sig os.Signal) error {
return nil
}

func (c *VirtualCommand) Wait() error {
func (c *VirtualCommand) Wait() (err error) {
c.logger.Info("waiting for the virtual command to finish")

defer c.cleanup()
defer func() {
errC := errors.WithStack(c.cleanup())
c.logger.Info("cleaned up the virtual command", zap.Error(errC))
if err == nil && errC != nil {
err = errC
}
}()

waitErr := c.cmd.Wait()
c.logger.Info("the virtual command finished", zap.Error(waitErr))
err = errors.WithStack(c.cmd.Wait())
c.logger.Info("the virtual command finished", zap.Error(err))

if err := c.closeIO(); err != nil {
return err
errIO := c.closeIO()
c.logger.Info("closed IO of the virtual command", zap.Error(errIO))
if err == nil && errIO != nil {
err = errIO
}

c.wg.Wait()

if waitErr != nil {
return errors.WithStack(waitErr)
c.mu.Lock()
if err == nil && c.err != nil {
err = c.err
}
c.mu.Unlock()

c.mx.Lock()
err := c.err
c.mx.Unlock()

return err
return
}

func (c *VirtualCommand) setErr(err error) {
if err == nil {
return
}
c.mx.Lock()
c.mu.Lock()
if c.err == nil {
c.err = err
}
c.mx.Unlock()
c.mu.Unlock()
}

func (c *VirtualCommand) closeIO() error {
func (c *VirtualCommand) closeIO() (err error) {
if !isNil(c.stdin) {
if err := c.stdin.Close(); err != nil {
return errors.WithMessage(err, "failed to close stdin")
if errClose := c.stdin.Close(); errClose != nil {
err = multierr.Append(err, errors.WithMessage(errClose, "failed to close stdin"))
}
}

if err := c.tty.Close(); err != nil {
return errors.WithMessage(err, "failed to close tty")
if errClose := c.tty.Close(); errClose != nil {
err = multierr.Append(err, errors.WithMessage(errClose, "failed to close tty"))
}

// if err := c.pty.Close(); err != nil {
// return errors.WithMessage(err, "failed to close pty")
// }

return nil
return
}

func (c *VirtualCommand) cleanup() {
func (c *VirtualCommand) cleanup() (err error) {
for _, fn := range c.cleanFuncs {
fn()
if errFn := fn(); errFn != nil {
err = multierr.Append(err, errFn)
}
}
return
}

type Winsize pty.Winsize
Expand Down
4 changes: 3 additions & 1 deletion internal/runnerv2service/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ func (e *execution) Wait(ctx context.Context, sender sender) (int, error) {
// Wait for both errors, or nils.
select {
case err2 := <-errc:
e.logger.Info("another error from readSendLoop; won't be returned", zap.Error(err2))
if err2 != nil {
e.logger.Info("another error from readSendLoop; won't be returned", zap.Error(err2))
}
case <-ctx.Done():
}
return exitCode, err1
Expand Down
16 changes: 11 additions & 5 deletions internal/runnerv2service/service_execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package runnerv2service

import (
"bytes"
"context"
"io"
"net"
Expand Down Expand Up @@ -58,19 +59,24 @@ func TestRunnerServiceServerExecute(t *testing.T) {
assert.Greater(t, resp.Pid.Value, uint32(1))
assert.Nil(t, resp.ExitCode)

// Assert second response.
// Assert second and third responses.
var out bytes.Buffer

resp, err = stream.Recv()
assert.NoError(t, err)
assert.Equal(t, "test\n", string(resp.StdoutData))
assert.Nil(t, resp.ExitCode)
assert.Nil(t, resp.Pid)
_, err = out.Write(resp.StdoutData)
assert.NoError(t, err)

// Assert third response.
resp, err = stream.Recv()
assert.NoError(t, err)
assert.Equal(t, "test\n", string(resp.StderrData))
assert.Nil(t, resp.ExitCode)
assert.Nil(t, resp.Pid)
_, err = out.Write(resp.StderrData)
assert.NoError(t, err)

assert.Equal(t, "test\ntest\n", out.String())

// Assert fourth response.
resp, err = stream.Recv()
Expand Down Expand Up @@ -222,7 +228,7 @@ func TestRunnerServiceServerExecuteConfigs(t *testing.T) {
}
}

func TestRunnerServiceServerExecute_Input(t *testing.T) {
func TestRunnerServiceServerExecuteWithInput(t *testing.T) {
lis, stop := testStartRunnerServiceServer(t)
t.Cleanup(stop)
_, client := testCreateRunnerServiceClient(t, lis)
Expand Down

0 comments on commit b3883b4

Please sign in to comment.