diff --git a/internal/command/command_exec_normalizer.go b/internal/command/command_exec_normalizer.go index 8cbe10c12..5f8fc301a 100644 --- a/internal/command/command_exec_normalizer.go +++ b/internal/command/command_exec_normalizer.go @@ -74,7 +74,7 @@ func (n *argsNormalizer) Normalize(cfg *Config) (*Config, error) { if isShellLanguage(filepath.Base(cfg.ProgramName)) { if n.session != nil { - _, _ = buf.WriteString(fmt.Sprintf("%s > %s\n", EnvDumpCommand, filepath.Join(n.tempDir, envEndFileName))) + _, _ = buf.WriteString(fmt.Sprintf("trap \"%s > %s\" EXIT\n", EnvDumpCommand, filepath.Join(n.tempDir, envEndFileName))) n.isEnvCollectable = true } @@ -107,57 +107,59 @@ func (n *argsNormalizer) Normalize(cfg *Config) (*Config, error) { return result, nil } -func (n *argsNormalizer) Cleanup() { +func (n *argsNormalizer) Cleanup() error { if n.tempDir == "" { - return + return nil } n.logger.Info("cleaning up the temporary dir") if err := os.RemoveAll(n.tempDir); err != nil { - n.logger.Info("failed to remove temporary dir", zap.Error(err)) + return errors.WithMessage(err, "failed to remove the temporary dir") } + + return nil } -func (n *argsNormalizer) CollectEnv() { +func (n *argsNormalizer) CollectEnv() error { if n.session == nil || !n.isEnvCollectable { - return + return nil } n.logger.Info("collecting env") startEnv, err := n.readEnvFromFile(envStartFileName) if err != nil { - n.logger.Info("failed to read the start env file", zap.Error(err)) - return + return err } endEnv, err := n.readEnvFromFile(envEndFileName) if err != nil { - n.logger.Info("failed to read the end env file", zap.Error(err)) - return + return err } + // Below, we diff the env collected before and after the script execution. + // Then, update the session with the new or updated env and delete the deleted env. + startEnvStore := newEnvStore() if _, err := startEnvStore.Merge(startEnv...); err != nil { - n.logger.Info("failed to create the start env store", zap.Error(err)) - return + return errors.WithMessage(err, "failed to create the start env store") } endEnvStore := newEnvStore() if _, err := endEnvStore.Merge(endEnv...); err != nil { - n.logger.Info("failed to create the end env store", zap.Error(err)) - return + return errors.WithMessage(err, "failed to create the end env store") } newOrUpdated, _, deleted := diffEnvStores(startEnvStore, endEnvStore) if err := n.session.SetEnv(newOrUpdated...); err != nil { - n.logger.Info("failed to set the new or updated env", zap.Error(err)) - return + return errors.WithMessage(err, "failed to set the new or updated env") } n.session.DeleteEnv(deleted...) + + return nil } func (n *argsNormalizer) createTempDir() (err error) { @@ -182,7 +184,7 @@ func (n *argsNormalizer) writeScript(script []byte) error { func (n *argsNormalizer) readEnvFromFile(name string) (result []string, _ error) { f, err := os.Open(filepath.Join(n.tempDir, name)) if err != nil { - return nil, errors.WithStack(err) + return nil, errors.WithMessagef(err, "failed to open the env file %q", name) } defer func() { _ = f.Close() }() @@ -194,10 +196,10 @@ func (n *argsNormalizer) readEnvFromFile(name string) (result []string, _ error) } if err := scanner.Err(); err != nil { - return nil, errors.WithStack(err) + return nil, errors.WithMessagef(err, "failed to scan the env file %q", name) } - return result, errors.WithStack(scanner.Err()) + return result, nil } func splitNull(data []byte, atEOF bool) (advance int, token []byte, err error) { diff --git a/internal/command/command_native.go b/internal/command/command_native.go index 086237fa2..1d6cb0abd 100644 --- a/internal/command/command_native.go +++ b/internal/command/command_native.go @@ -6,6 +6,7 @@ import ( "os/exec" "github.com/pkg/errors" + "go.uber.org/multierr" "go.uber.org/zap" ) @@ -19,7 +20,7 @@ type NativeCommand struct { // cmd is populated when the command is started. cmd *exec.Cmd - cleanFuncs []func() + cleanFuncs []func() error logger *zap.Logger } @@ -104,18 +105,20 @@ func (c *NativeCommand) Start(ctx context.Context) (err error) { // like "python", hence, it's commented out. // setSysProcAttrPgid(c.cmd) - c.logger.Info("starting a local command", zap.Any("config", redactConfig(cfg))) + c.logger.Info("starting a native command", zap.Any("config", redactConfig(cfg))) if err := c.cmd.Start(); err != nil { return errors.WithStack(err) } - c.logger.Info("a local command started") + c.logger.Info("a native command started") return nil } func (c *NativeCommand) StopWithSignal(sig os.Signal) error { + c.logger.Info("stopping the native command with a signal", zap.Stringer("signal", sig)) + if SignalToProcessGroup { // Try to terminate the whole process group. If it fails, fall back to stdlib methods. err := signalPgid(c.cmd.Process.Pid, sig) @@ -133,14 +136,20 @@ func (c *NativeCommand) StopWithSignal(sig os.Signal) error { return nil } -func (c *NativeCommand) Wait() error { - c.logger.Info("waiting for the local command to finish") +func (c *NativeCommand) Wait() (err error) { + c.logger.Info("waiting for the native command to finish") - defer c.cleanup() + defer func() { + errC := errors.WithStack(c.cleanup()) + c.logger.Info("cleaned up the native command", zap.Error(errC)) + if err == nil && errC != nil { + err = errC + } + }() var stderr []byte - err := c.cmd.Wait() + err = errors.WithStack(c.cmd.Wait()) if err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { @@ -148,13 +157,16 @@ func (c *NativeCommand) Wait() error { } } - c.logger.Info("the local command finished", zap.Error(err), zap.ByteString("stderr", stderr)) + c.logger.Info("the native command finished", zap.Error(err), zap.ByteString("stderr", stderr)) - return errors.WithStack(err) + return } -func (c *NativeCommand) cleanup() { +func (c *NativeCommand) cleanup() (err error) { for _, fn := range c.cleanFuncs { - fn() + if errFn := fn(); errFn != nil { + err = multierr.Append(err, errFn) + } } + return } diff --git a/internal/command/command_virtual.go b/internal/command/command_virtual.go index cdd61d096..dc55d6f86 100644 --- a/internal/command/command_virtual.go +++ b/internal/command/command_virtual.go @@ -11,6 +11,7 @@ import ( "github.com/creack/pty" "github.com/pkg/errors" + "go.uber.org/multierr" "go.uber.org/zap" ) @@ -24,14 +25,14 @@ type VirtualCommand struct { // stdin is Opts.Stdin wrapped in readCloser. stdin io.ReadCloser - cleanFuncs []func() + cleanFuncs []func() error pty *os.File tty *os.File wg sync.WaitGroup - mx sync.Mutex + mu sync.Mutex err error logger *zap.Logger @@ -207,64 +208,73 @@ func (c *VirtualCommand) StopWithSignal(sig os.Signal) error { return nil } -func (c *VirtualCommand) Wait() error { +func (c *VirtualCommand) Wait() (err error) { c.logger.Info("waiting for the virtual command to finish") - defer c.cleanup() + defer func() { + errC := errors.WithStack(c.cleanup()) + c.logger.Info("cleaned up the virtual command", zap.Error(errC)) + if err == nil && errC != nil { + err = errC + } + }() - waitErr := c.cmd.Wait() - c.logger.Info("the virtual command finished", zap.Error(waitErr)) + err = errors.WithStack(c.cmd.Wait()) + c.logger.Info("the virtual command finished", zap.Error(err)) - if err := c.closeIO(); err != nil { - return err + errIO := c.closeIO() + c.logger.Info("closed IO of the virtual command", zap.Error(errIO)) + if err == nil && errIO != nil { + err = errIO } c.wg.Wait() - if waitErr != nil { - return errors.WithStack(waitErr) + c.mu.Lock() + if err == nil && c.err != nil { + err = c.err } + c.mu.Unlock() - c.mx.Lock() - err := c.err - c.mx.Unlock() - - return err + return } func (c *VirtualCommand) setErr(err error) { if err == nil { return } - c.mx.Lock() + c.mu.Lock() if c.err == nil { c.err = err } - c.mx.Unlock() + c.mu.Unlock() } -func (c *VirtualCommand) closeIO() error { +func (c *VirtualCommand) closeIO() (err error) { if !isNil(c.stdin) { - if err := c.stdin.Close(); err != nil { - return errors.WithMessage(err, "failed to close stdin") + if errClose := c.stdin.Close(); errClose != nil { + err = multierr.Append(err, errors.WithMessage(errClose, "failed to close stdin")) } } - if err := c.tty.Close(); err != nil { - return errors.WithMessage(err, "failed to close tty") + if errClose := c.tty.Close(); errClose != nil { + err = multierr.Append(err, errors.WithMessage(errClose, "failed to close tty")) } // if err := c.pty.Close(); err != nil { // return errors.WithMessage(err, "failed to close pty") // } - return nil + return } -func (c *VirtualCommand) cleanup() { +func (c *VirtualCommand) cleanup() (err error) { for _, fn := range c.cleanFuncs { - fn() + if errFn := fn(); errFn != nil { + err = multierr.Append(err, errFn) + } } + return } type Winsize pty.Winsize diff --git a/internal/runnerv2service/execution.go b/internal/runnerv2service/execution.go index 8a3bbf1aa..da4178e2a 100644 --- a/internal/runnerv2service/execution.go +++ b/internal/runnerv2service/execution.go @@ -166,7 +166,9 @@ func (e *execution) Wait(ctx context.Context, sender sender) (int, error) { // Wait for both errors, or nils. select { case err2 := <-errc: - e.logger.Info("another error from readSendLoop; won't be returned", zap.Error(err2)) + if err2 != nil { + e.logger.Info("another error from readSendLoop; won't be returned", zap.Error(err2)) + } case <-ctx.Done(): } return exitCode, err1 diff --git a/internal/runnerv2service/service_execute_test.go b/internal/runnerv2service/service_execute_test.go index 5e2737e28..19293fc58 100644 --- a/internal/runnerv2service/service_execute_test.go +++ b/internal/runnerv2service/service_execute_test.go @@ -3,6 +3,7 @@ package runnerv2service import ( + "bytes" "context" "io" "net" @@ -58,19 +59,24 @@ func TestRunnerServiceServerExecute(t *testing.T) { assert.Greater(t, resp.Pid.Value, uint32(1)) assert.Nil(t, resp.ExitCode) - // Assert second response. + // Assert second and third responses. + var out bytes.Buffer + resp, err = stream.Recv() assert.NoError(t, err) - assert.Equal(t, "test\n", string(resp.StdoutData)) assert.Nil(t, resp.ExitCode) assert.Nil(t, resp.Pid) + _, err = out.Write(resp.StdoutData) + assert.NoError(t, err) - // Assert third response. resp, err = stream.Recv() assert.NoError(t, err) - assert.Equal(t, "test\n", string(resp.StderrData)) assert.Nil(t, resp.ExitCode) assert.Nil(t, resp.Pid) + _, err = out.Write(resp.StderrData) + assert.NoError(t, err) + + assert.Equal(t, "test\ntest\n", out.String()) // Assert fourth response. resp, err = stream.Recv() @@ -222,7 +228,7 @@ func TestRunnerServiceServerExecuteConfigs(t *testing.T) { } } -func TestRunnerServiceServerExecute_Input(t *testing.T) { +func TestRunnerServiceServerExecuteWithInput(t *testing.T) { lis, stop := testStartRunnerServiceServer(t) t.Cleanup(stop) _, client := testCreateRunnerServiceClient(t, lis)