Skip to content

Commit

Permalink
fix data race on p.stderr (#209)
Browse files Browse the repository at this point in the history
Co-authored-by: John Arundel <john@bitfieldconsulting.com>
  • Loading branch information
mahadzaryab1 and bitfield authored Sep 2, 2024
1 parent 0daf4b2 commit 0edd895
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
with:
go-version: ${{ matrix.go-version }}
- uses: actions/checkout@v3
- run: go test ./...
- run: go test -race ./...

gocritic:
runs-on: ubuntu-latest
Expand Down
42 changes: 29 additions & 13 deletions script.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ import (
// Pipe represents a pipe object with an associated [ReadAutoCloser].
type Pipe struct {
// Reader is the underlying reader.
Reader ReadAutoCloser
stdout, stderr io.Writer
httpClient *http.Client
Reader ReadAutoCloser
stdout io.Writer
httpClient *http.Client

// because pipe stages are concurrent, protect 'err'
mu *sync.Mutex
err error
// because pipe stages are concurrent, protect 'err' and 'stderr'
mu *sync.Mutex
err error
stderr io.Writer
}

// Args creates a pipe containing the program's command-line arguments from
Expand Down Expand Up @@ -414,8 +415,9 @@ func (p *Pipe) Exec(cmdLine string) *Pipe {
cmd.Stdin = r
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
pipeStderr := p.stdErr()
if pipeStderr != nil {
cmd.Stderr = pipeStderr
}
err = cmd.Start()
if err != nil {
Expand Down Expand Up @@ -454,8 +456,9 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe {
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
pipeStderr := p.stdErr()
if pipeStderr != nil {
cmd.Stderr = pipeStderr
}
err = cmd.Start()
if err != nil {
Expand Down Expand Up @@ -839,6 +842,18 @@ func (p *Pipe) Slice() ([]string, error) {
return result, p.Error()
}

// stdErr returns the pipe's configured standard error writer for commands run
// via [Pipe.Exec] and [Pipe.ExecForEach]. The default is nil, which means that
// error output will go to the pipe.
func (p *Pipe) stdErr() io.Writer {
if p.mu == nil { // uninitialised pipe
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
return p.stderr
}

// Stdout copies the pipe's contents to its configured standard output (using
// [Pipe.WithStdout]), or to [os.Stdout] otherwise, and returns the number of
// bytes successfully written, together with any error.
Expand Down Expand Up @@ -913,10 +928,11 @@ func (p *Pipe) WithReader(r io.Reader) *Pipe {
return p
}

// WithStderr redirects the standard error output for commands run via
// [Pipe.Exec] or [Pipe.ExecForEach] to the writer w, instead of going to the
// pipe as it normally would.
// WithStderr sets the standard error output for [Pipe.Exec] or
// [Pipe.ExecForEach] commands to w, instead of the pipe.
func (p *Pipe) WithStderr(w io.Writer) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()
p.stderr = w
return p
}
Expand Down
8 changes: 8 additions & 0 deletions script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1971,6 +1971,14 @@ func TestEncodeBase64_CorrectlyEncodesInputBytes(t *testing.T) {
}
}

func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) {
t.Parallel()
err := script.Exec("echo").WithStderr(nil).Wait()
if err != nil {
t.Fatal(err)
}
}

func ExampleArgs() {
script.Args().Stdout()
// prints command-line arguments
Expand Down

0 comments on commit 0edd895

Please sign in to comment.