diff --git a/README.md b/README.md index 1503be1..892f115 100644 --- a/README.md +++ b/README.md @@ -5,49 +5,41 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/StudioSol/async)](https://goreportcard.com/report/github.com/StudioSol/async) [![GoDoc](https://godoc.org/github.com/StudioSol/async?status.svg)](https://godoc.org/github.com/StudioSol/async) -Provides a safe way to execute `fns`'s functions asynchronously, recovering them in case of panic. It also provides an error stack aiming to facilitate fail causes discovery. +Provides a safe way to execute functions asynchronously, recovering them in case of panic. It also provides an error stack aiming to facilitate fail causes discovery, and a simple way to control execution flow without `WaitGroup`. ### Usage ```go -func InsertAsynchronously(ctx context.Context) error { - transaction := db.Transaction().Begin() - - err := async.Run(ctx, - func(_ context.Context) error { - _, err := transaction.Exec(` - INSERT INTO foo (bar) - VALUES ('Hello') - `) - - return err - }, - - func(_ context.Context) error { - _, err := transaction.Exec(` - INSERT INTO foo (bar) - VALUES ('world') - `) - - return err - }, - - func(_ context.Context) error { - _, err := transaction.Exec(` - INSERT INTO foo (bar) - VALUES ('asynchronously!') - `) - - return err - }, - ) +var ( + user User + songs []Songs + photos []Photos +) + +err := async.Run(ctx, + func(ctx context.Context) error { + user, err = user.Get(ctx, id) + return err + }, + func(ctx context.Context) error { + songs, err = song.GetByUserID(ctx, id) + return err + }, + func(ctx context.Context) error { + photos, err = photo.GetByUserID(ctx, id) + return err + }, +) + +if err != nil { + log.Error(err) +} +``` - if err != nil { - e := transaction.Rollback() - log.IfError(e) - return err - } +You can also limit the number of asynchronous tasks - return transaction.Commit() +```go +runner := async.NewRunner(tasks...).WithLimit(3) +if err := runner.Run(ctx); err != nil { + log.Error(e) } - ``` diff --git a/async_test.go b/async_test.go index 94261df..0f965f2 100644 --- a/async_test.go +++ b/async_test.go @@ -1,13 +1,12 @@ package async_test import ( + "context" "errors" "sync" "testing" "time" - "context" - "github.com/StudioSol/async" . "github.com/smartystreets/goconvey/convey" ) diff --git a/runner.go b/runner.go new file mode 100644 index 0000000..6e16975 --- /dev/null +++ b/runner.go @@ -0,0 +1,106 @@ +package async + +import ( + "context" + "sync" +) + +type Runner struct { + sync.Mutex + tasks []Task + errs []error + limit int + waitErrors bool +} + +// NewRunner creates a new task manager to control async functions. +func NewRunner(tasks ...Task) *Runner { + return &Runner{ + tasks: tasks, + limit: len(tasks), + } +} + +// WaitErrors tells the runner to wait for the response from all functions instead of cancelling them all when the first error occurs. +func (r *Runner) WaitErrors() *Runner { + r.waitErrors = true + return r +} + +// WithLimit defines a limit for concurrent tasks execution +func (r *Runner) WithLimit(limit int) *Runner { + r.limit = limit + return r +} + +// AllErrors returns all errors reported by functions +func (r *Runner) AllErrors() []error { + return r.errs +} + +// registerErr store an error to final report +func (r *Runner) registerErr(err error) { + r.Lock() + defer r.Unlock() + if err != nil { + r.errs = append(r.errs, err) + } +} + +// wrapperChannel converts a given Task to a channel of errors +func wrapperChannel(ctx context.Context, task Task) chan error { + cerr := make(chan error, 1) + go func() { + cerr <- task(ctx) + close(cerr) + }() + return cerr +} + +// Run starts the task manager and returns the first error or nil if succeed +func (r *Runner) Run(parentCtx context.Context) error { + ctx, cancel := context.WithCancel(parentCtx) + cerr := make(chan error, len(r.tasks)) + queue := make(chan struct{}, r.limit) + var wg sync.WaitGroup + wg.Add(len(r.tasks)) + for _, task := range r.tasks { + queue <- struct{}{} + go func(fn func(context.Context) error) { + defer func() { + <-queue + wg.Done() + safePanic(cerr) + }() + + select { + case <-parentCtx.Done(): + cerr <- parentCtx.Err() + r.registerErr(parentCtx.Err()) + case err := <-wrapperChannel(ctx, fn): + cerr <- err + r.registerErr(err) + } + }(task) + } + + go func() { + wg.Wait() + cancel() + close(cerr) + }() + + var firstErr error + for err := range cerr { + if err != nil && firstErr == nil { + firstErr = err + if r.waitErrors { + continue + } + cancel() + return firstErr + } + } + + return firstErr +} diff --git a/runner_test.go b/runner_test.go new file mode 100644 index 0000000..676d0e1 --- /dev/null +++ b/runner_test.go @@ -0,0 +1,93 @@ +package async + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRunner_AllErrors(t *testing.T) { + expectErr := errors.New("fail") + runner := NewRunner(func(context.Context) error { + return expectErr + }).WaitErrors() + err := runner.Run(context.Background()) + require.Equal(t, expectErr, err) + require.Len(t, runner.AllErrors(), 1) + require.Equal(t, expectErr, runner.AllErrors()[0]) +} + +func TestRunner_WaitErrors(t *testing.T) { + expectErrOne := errors.New("fail") + expectErrTwo := errors.New("fail") + runner := NewRunner(func(context.Context) error { + return expectErrOne + }, func(context.Context) error { + return expectErrTwo + }).WaitErrors() + err := runner.Run(context.Background()) + require.False(t, err != expectErrOne && err != expectErrTwo) + require.Len(t, runner.AllErrors(), 2) +} + +func TestRunner_Run(t *testing.T) { + calledFist := false + calledSecond := false + runner := NewRunner(func(context.Context) error { + calledFist = true + return nil + }, func(context.Context) error { + calledSecond = true + return nil + }) + err := runner.Run(context.Background()) + require.Nil(t, err) + require.True(t, calledFist) + require.True(t, calledSecond) +} + +func TestRunner_WithLimit(t *testing.T) { + order := 1 + runner := NewRunner(func(context.Context) error { + require.Equal(t, 1, order) + order++ + return nil + }, func(context.Context) error { + require.Equal(t, 2, order) + order++ + return nil + }).WithLimit(1) + err := runner.Run(context.Background()) + require.Nil(t, err) +} + +func TestRunner_ContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + start := time.Now() + runner := NewRunner(func(context.Context) error { + cancel() + time.Sleep(time.Minute) + return nil + }) + err := runner.Run(ctx) + require.True(t, time.Since(start) < time.Minute) + require.Equal(t, context.Canceled, err) +} + +func TestRunner_ContextTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + start := time.Now() + runner := NewRunner(func(context.Context) error { + time.Sleep(time.Minute) + return nil + }) + err := runner.Run(ctx) + require.True(t, time.Since(start) < time.Minute) + require.Equal(t, context.DeadlineExceeded, err) +}