Skip to content

Commit 72917a1

Browse files
authored
add sync/errgroup like functionality (#28)
add sync/errgroup like functionality
2 parents bdca7bb + 5e6cd52 commit 72917a1

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

group.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package multierror
2+
3+
import "sync"
4+
5+
// Group is a collection of goroutines which return errors that need to be
6+
// coalesced.
7+
type Group struct {
8+
mutex sync.Mutex
9+
err *Error
10+
wg sync.WaitGroup
11+
}
12+
13+
// Go calls the given function in a new goroutine.
14+
//
15+
// If the function returns an error it is added to the group multierror which
16+
// is returned by Wait.
17+
func (g *Group) Go(f func() error) {
18+
g.wg.Add(1)
19+
20+
go func() {
21+
defer g.wg.Done()
22+
23+
if err := f(); err != nil {
24+
g.mutex.Lock()
25+
g.err = Append(g.err, err)
26+
g.mutex.Unlock()
27+
}
28+
}()
29+
}
30+
31+
// Wait blocks until all function calls from the Go method have returned, then
32+
// returns the multierror.
33+
func (g *Group) Wait() *Error {
34+
g.wg.Wait()
35+
g.mutex.Lock()
36+
defer g.mutex.Unlock()
37+
return g.err
38+
}

group_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package multierror
2+
3+
import (
4+
"errors"
5+
"strings"
6+
"testing"
7+
)
8+
9+
func TestGroup(t *testing.T) {
10+
err1 := errors.New("group_test: 1")
11+
err2 := errors.New("group_test: 2")
12+
13+
cases := []struct {
14+
errs []error
15+
nilResult bool
16+
}{
17+
{errs: []error{}, nilResult: true},
18+
{errs: []error{nil}, nilResult: true},
19+
{errs: []error{err1}},
20+
{errs: []error{err1, nil}},
21+
{errs: []error{err1, nil, err2}},
22+
}
23+
24+
for _, tc := range cases {
25+
var g Group
26+
27+
for _, err := range tc.errs {
28+
err := err
29+
g.Go(func() error { return err })
30+
31+
}
32+
33+
gErr := g.Wait()
34+
if gErr != nil {
35+
for i := range tc.errs {
36+
if tc.errs[i] != nil && !strings.Contains(gErr.Error(), tc.errs[i].Error()) {
37+
t.Fatalf("expected error to contain %q, actual: %v", tc.errs[i].Error(), gErr)
38+
}
39+
}
40+
} else if !tc.nilResult {
41+
t.Fatalf("Group.Wait() should not have returned nil for errs: %v", tc.errs)
42+
}
43+
}
44+
}

0 commit comments

Comments
 (0)