Skip to content

Commit 167d74b

Browse files
committed
Include task runner
1 parent af9fc37 commit 167d74b

File tree

3 files changed

+182
-2
lines changed

3 files changed

+182
-2
lines changed

async_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
package async_test
22

33
import (
4+
"context"
45
"errors"
56
"sync"
67
"testing"
78
"time"
89

9-
"context"
10-
1110
"github.com/StudioSol/async"
1211
. "github.com/smartystreets/goconvey/convey"
1312
)

runner.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package async
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
)
8+
9+
type Runner struct {
10+
sync.Mutex
11+
tasks []Task
12+
errs []error
13+
limit int
14+
timeout time.Duration
15+
waitErrors bool
16+
}
17+
18+
// NewRunner creates a new task manager to control async functions
19+
func NewRunner(tasks ...Task) *Runner {
20+
return &Runner{
21+
tasks: tasks,
22+
limit: len(tasks),
23+
}
24+
}
25+
26+
// WaitErrors if, if active, will wait for the error response from all functions
27+
func (r *Runner) WaitErrors() *Runner {
28+
r.waitErrors = true
29+
return r
30+
}
31+
32+
// WithLimit defines an limit for concurrent tasks execution
33+
func (r *Runner) WithLimit(limit int) *Runner {
34+
r.limit = limit
35+
return r
36+
}
37+
38+
// AllErrors returns all errors reported by functions
39+
func (r *Runner) AllErrors() []error {
40+
return r.errs
41+
}
42+
43+
// AllErrors returns all errors reported by functions
44+
func (r *Runner) registerErr(err error) {
45+
r.Lock()
46+
defer r.Unlock()
47+
if err != nil {
48+
r.errs = append(r.errs, err)
49+
}
50+
}
51+
52+
func wrapperChannel(ctx context.Context, task Task) chan error {
53+
cerr := make(chan error, 1)
54+
go func() {
55+
cerr <- task(ctx)
56+
close(cerr)
57+
}()
58+
return cerr
59+
}
60+
61+
// Run starts the task manager and return the first error or nil if succeed
62+
func (r *Runner) Run(parentCtx context.Context) error {
63+
ctx, cancel := context.WithCancel(parentCtx)
64+
cerr := make(chan error, len(r.tasks))
65+
queue := make(chan struct{}, r.limit)
66+
var wg sync.WaitGroup
67+
wg.Add(len(r.tasks))
68+
for _, task := range r.tasks {
69+
queue <- struct{}{}
70+
go func(fn func(context.Context) error) {
71+
defer func() {
72+
<-queue
73+
wg.Done()
74+
safePanic(cerr)
75+
}()
76+
77+
select {
78+
case <-parentCtx.Done():
79+
case err := <-wrapperChannel(ctx, fn):
80+
cerr <- err
81+
r.registerErr(err)
82+
}
83+
}(task)
84+
}
85+
86+
go func() {
87+
wg.Wait()
88+
cancel()
89+
close(cerr)
90+
}()
91+
92+
var firstErr error
93+
for err := range cerr {
94+
if err != nil && firstErr == nil {
95+
firstErr = err
96+
if r.waitErrors {
97+
continue
98+
}
99+
cancel()
100+
return firstErr
101+
}
102+
}
103+
104+
return firstErr
105+
}

runner_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package async
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestRunner_AllErrors(t *testing.T) {
13+
expectErr := errors.New("fail")
14+
runner := NewRunner(func(context.Context) error {
15+
return expectErr
16+
}).WaitErrors()
17+
err := runner.Run(context.Background())
18+
require.Equal(t, expectErr, err)
19+
require.Len(t, runner.AllErrors(), 1)
20+
require.Equal(t, expectErr, runner.AllErrors()[0])
21+
}
22+
23+
func TestRunner_WaitErrors(t *testing.T) {
24+
expectErrOne := errors.New("fail")
25+
expectErrTwo := errors.New("fail")
26+
runner := NewRunner(func(context.Context) error {
27+
return expectErrOne
28+
}, func(context.Context) error {
29+
return expectErrTwo
30+
}).WaitErrors()
31+
err := runner.Run(context.Background())
32+
require.False(t, err != expectErrOne && err != expectErrTwo)
33+
require.Len(t, runner.AllErrors(), 2)
34+
}
35+
36+
func TestRunner_Run(t *testing.T) {
37+
calledFist := false
38+
calledSecond := false
39+
runner := NewRunner(func(context.Context) error {
40+
calledFist = true
41+
return nil
42+
}, func(context.Context) error {
43+
calledSecond = true
44+
return nil
45+
})
46+
err := runner.Run(context.Background())
47+
require.Nil(t, err)
48+
require.True(t, calledFist)
49+
require.True(t, calledSecond)
50+
}
51+
52+
func TestRunner_WithLimit(t *testing.T) {
53+
order := 1
54+
runner := NewRunner(func(context.Context) error {
55+
require.Equal(t, 1, order)
56+
order++
57+
return nil
58+
}, func(context.Context) error {
59+
require.Equal(t, 2, order)
60+
order++
61+
return nil
62+
}).WithLimit(1)
63+
err := runner.Run(context.Background())
64+
require.Nil(t, err)
65+
}
66+
67+
func TestRunner_ContextCancelled(t *testing.T) {
68+
ctx, cancel := context.WithCancel(context.Background())
69+
runner := NewRunner(func(context.Context) error {
70+
cancel()
71+
time.Sleep(time.Minute)
72+
return nil
73+
})
74+
err := runner.Run(ctx)
75+
require.Nil(t, err)
76+
}

0 commit comments

Comments
 (0)