Skip to content

Commit

Permalink
Include task runner (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigo-brito authored Apr 3, 2020
1 parent af9fc37 commit f699bcd
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 41 deletions.
70 changes: 31 additions & 39 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,41 @@
[![Go Report Card](https://goreportcard.com/badge/github.com/StudioSol/async)](https://goreportcard.com/report/github.com/StudioSol/async)
[![GoDoc](https://godoc.org/github.com/StudioSol/async?status.svg)](https://godoc.org/github.com/StudioSol/async)

Provides a safe way to execute `fns`'s functions asynchronously, recovering them in case of panic. It also provides an error stack aiming to facilitate fail causes discovery.
Provides a safe way to execute functions asynchronously, recovering them in case of panic. It also provides an error stack aiming to facilitate fail causes discovery, and a simple way to control execution flow without `WaitGroup`.

### Usage
```go
func InsertAsynchronously(ctx context.Context) error {
transaction := db.Transaction().Begin()

err := async.Run(ctx,
func(_ context.Context) error {
_, err := transaction.Exec(`
INSERT INTO foo (bar)
VALUES ('Hello')
`)

return err
},

func(_ context.Context) error {
_, err := transaction.Exec(`
INSERT INTO foo (bar)
VALUES ('world')
`)

return err
},

func(_ context.Context) error {
_, err := transaction.Exec(`
INSERT INTO foo (bar)
VALUES ('asynchronously!')
`)

return err
},
)
var (
user User
songs []Songs
photos []Photos
)

err := async.Run(ctx,
func(ctx context.Context) error {
user, err = user.Get(ctx, id)
return err
},
func(ctx context.Context) error {
songs, err = song.GetByUserID(ctx, id)
return err
},
func(ctx context.Context) error {
photos, err = photo.GetByUserID(ctx, id)
return err
},
)

if err != nil {
log.Error(err)
}
```

if err != nil {
e := transaction.Rollback()
log.IfError(e)
return err
}
You can also limit the number of asynchronous tasks

return transaction.Commit()
```go
runner := async.NewRunner(tasks...).WithLimit(3)
if err := runner.Run(ctx); err != nil {
log.Error(e)
}

```
3 changes: 1 addition & 2 deletions async_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package async_test

import (
"context"
"errors"
"sync"
"testing"
"time"

"context"

"github.com/StudioSol/async"
. "github.com/smartystreets/goconvey/convey"
)
Expand Down
106 changes: 106 additions & 0 deletions runner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package async

import (
"context"
"sync"
)

type Runner struct {
sync.Mutex
tasks []Task
errs []error
limit int
waitErrors bool
}

// NewRunner creates a new task manager to control async functions.
func NewRunner(tasks ...Task) *Runner {
return &Runner{
tasks: tasks,
limit: len(tasks),
}
}

// WaitErrors tells the runner to wait for the response from all functions instead of cancelling them all when the first error occurs.
func (r *Runner) WaitErrors() *Runner {
r.waitErrors = true
return r
}

// WithLimit defines a limit for concurrent tasks execution
func (r *Runner) WithLimit(limit int) *Runner {
r.limit = limit
return r
}

// AllErrors returns all errors reported by functions
func (r *Runner) AllErrors() []error {
return r.errs
}

// registerErr store an error to final report
func (r *Runner) registerErr(err error) {
r.Lock()
defer r.Unlock()
if err != nil {
r.errs = append(r.errs, err)
}
}

// wrapperChannel converts a given Task to a channel of errors
func wrapperChannel(ctx context.Context, task Task) chan error {
cerr := make(chan error, 1)
go func() {
cerr <- task(ctx)
close(cerr)
}()
return cerr
}

// Run starts the task manager and returns the first error or nil if succeed
func (r *Runner) Run(parentCtx context.Context) error {
ctx, cancel := context.WithCancel(parentCtx)
cerr := make(chan error, len(r.tasks))
queue := make(chan struct{}, r.limit)
var wg sync.WaitGroup
wg.Add(len(r.tasks))
for _, task := range r.tasks {
queue <- struct{}{}
go func(fn func(context.Context) error) {
defer func() {
<-queue
wg.Done()
safePanic(cerr)
}()

select {
case <-parentCtx.Done():
cerr <- parentCtx.Err()
r.registerErr(parentCtx.Err())
case err := <-wrapperChannel(ctx, fn):
cerr <- err
r.registerErr(err)
}
}(task)
}

go func() {
wg.Wait()
cancel()
close(cerr)
}()

var firstErr error
for err := range cerr {
if err != nil && firstErr == nil {
firstErr = err
if r.waitErrors {
continue
}
cancel()
return firstErr
}
}

return firstErr
}
93 changes: 93 additions & 0 deletions runner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package async

import (
"context"
"errors"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestRunner_AllErrors(t *testing.T) {
expectErr := errors.New("fail")
runner := NewRunner(func(context.Context) error {
return expectErr
}).WaitErrors()
err := runner.Run(context.Background())
require.Equal(t, expectErr, err)
require.Len(t, runner.AllErrors(), 1)
require.Equal(t, expectErr, runner.AllErrors()[0])
}

func TestRunner_WaitErrors(t *testing.T) {
expectErrOne := errors.New("fail")
expectErrTwo := errors.New("fail")
runner := NewRunner(func(context.Context) error {
return expectErrOne
}, func(context.Context) error {
return expectErrTwo
}).WaitErrors()
err := runner.Run(context.Background())
require.False(t, err != expectErrOne && err != expectErrTwo)
require.Len(t, runner.AllErrors(), 2)
}

func TestRunner_Run(t *testing.T) {
calledFist := false
calledSecond := false
runner := NewRunner(func(context.Context) error {
calledFist = true
return nil
}, func(context.Context) error {
calledSecond = true
return nil
})
err := runner.Run(context.Background())
require.Nil(t, err)
require.True(t, calledFist)
require.True(t, calledSecond)
}

func TestRunner_WithLimit(t *testing.T) {
order := 1
runner := NewRunner(func(context.Context) error {
require.Equal(t, 1, order)
order++
return nil
}, func(context.Context) error {
require.Equal(t, 2, order)
order++
return nil
}).WithLimit(1)
err := runner.Run(context.Background())
require.Nil(t, err)
}

func TestRunner_ContextCancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())

start := time.Now()
runner := NewRunner(func(context.Context) error {
cancel()
time.Sleep(time.Minute)
return nil
})
err := runner.Run(ctx)
require.True(t, time.Since(start) < time.Minute)
require.Equal(t, context.Canceled, err)
}

func TestRunner_ContextTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

start := time.Now()
runner := NewRunner(func(context.Context) error {
time.Sleep(time.Minute)
return nil
})
err := runner.Run(ctx)
require.True(t, time.Since(start) < time.Minute)
require.Equal(t, context.DeadlineExceeded, err)
}

0 comments on commit f699bcd

Please sign in to comment.