feat: add batcher to batch tasks
NgoKimPhu committed Dec 20, 2023
1 parent 85a7ed5 commit b4a777d
Showing 5 changed files with 535 additions and 16 deletions.
218 changes: 218 additions & 0 deletions batcher.go
@@ -0,0 +1,218 @@
package kutils

import (
"context"
"math"
"runtime"
"runtime/debug"
"sync/atomic"
"time"

"github.com/KyberNetwork/logger"
"github.com/pkg/errors"
)

//go:generate mockgen -source=batcher.go -destination mocks/mocks.go -package mocks

var (
ErrBatcherClosed = errors.New("batcher closed")
)

// BatchableTask represents a batchable task
type BatchableTask[R any] interface {
Ctx() context.Context // The context of this task
Done() <-chan struct{} // Signals if this task was already resolved
IsDone() bool // Checks (non-blocking) if this task was already resolved
Result() (R, error) // Blocks until this task is resolved and returns result and error
Resolve(ret R, err error) // Resolves this task with return value and error
}

// ChanTask uses a done channel to signal resolution of its return value and error
type ChanTask[R any] struct {
ctx context.Context
done chan struct{}
Ret R
Err error
}

func NewChanTask[R any](ctx context.Context) *ChanTask[R] {
if ctx == nil {
ctx = context.Background()
}
return &ChanTask[R]{
ctx: ctx,
done: make(chan struct{}),
}
}

func (c *ChanTask[R]) Ctx() context.Context {
return c.ctx
}

func (c *ChanTask[R]) Done() <-chan struct{} {
return c.done
}

func (c *ChanTask[R]) IsDone() bool {
select {
case <-c.done:
return true
default:
return false
}
}

func (c *ChanTask[R]) Result() (R, error) {
if c.IsDone() {
return c.Ret, c.Err
}
select {
case <-c.done:
return c.Ret, c.Err
case <-c.ctx.Done():
return *new(R), c.ctx.Err()
}
}

func (c *ChanTask[R]) Resolve(ret R, err error) {
select {
case <-c.done:
logger.Errorf("ChanTask.Resolve|called twice, ignored|c.Ret=%v,c.Err=%v|Ret=%v,Err=%v", c.Ret, c.Err, ret, err)
default:
c.Ret, c.Err = ret, err
close(c.done)
}
}

// Batcher batches n BatchableTask's together and executes logic on each batch of BatchableTask's.
// It skips BatchableTask's whose Ctx is already cancelled and resolves those tasks with the context's error.
// The batch logic should signal each BatchableTask as done by calling its Resolve method.
type Batcher[T BatchableTask[R], R any] interface {
// Batch submits a BatchableTask to the batcher.
Batch(task T)
// Close should stop Batch from being called and clean up any background resources.
Close()
}

// BatchCfg provides the batchRate and batchCnt configs for a ChanBatcher. ChanBatcher triggers batch processing
// either batchRate after the first task of a batch is queued, or as soon as batchCnt BatchableTask's are queued.
type BatchCfg func() (batchRate time.Duration, batchCnt int)

// BatchFn is called with each batch of tasks that a ChanBatcher collects and triggers per its batchCfg.
type BatchFn[T any] func([]T)

// ChanBatcher implements Batcher using a Go channel.
type ChanBatcher[T BatchableTask[R], R any] struct {
batchCfg BatchCfg
batchFn BatchFn[T]
taskCh chan T
closed atomic.Bool
}

func NewChanBatcher[T BatchableTask[R], R any](batchCfg BatchCfg, batchFn BatchFn[T]) *ChanBatcher[T, R] {
_, batchCnt := batchCfg()
chanBatcher := &ChanBatcher[T, R]{
batchCfg: batchCfg,
batchFn: batchFn,
taskCh: make(chan T, 16*batchCnt),
}
go chanBatcher.worker()
return chanBatcher
}

// Batch submits a BatchableTask to the channel if this ChanBatcher hasn't been closed; otherwise it resolves the task with ErrBatcherClosed.
func (b *ChanBatcher[T, R]) Batch(task T) {
if !b.closed.Load() {
b.taskCh <- task
} else {
task.Resolve(*new(R), ErrBatcherClosed)
}
}

// Close closes this ChanBatcher to prevent Batch-ing new BatchableTask's and tells the worker goroutine to finish up.
func (b *ChanBatcher[_, _]) Close() {
if !b.closed.Swap(true) {
close(b.taskCh)
}
}

// batchFnWithRecover runs batchFn on tasks and recovers from any panic, resolving each unresolved task with the panic as an error.
func (b *ChanBatcher[T, R]) batchFnWithRecover(tasks []T) {
defer func() {
p := recover()
if p == nil {
return
}
logger.Errorf("ChanBatcher.goBatchFn|recovered from panic: %v\n%s", p, string(debug.Stack()))
var ret R
for _, task := range tasks {
if task.IsDone() {
continue
}
if err, ok := p.(error); ok {
task.Resolve(ret, errors.Wrap(err, "batchFn panicked"))
} else {
task.Resolve(ret, errors.Errorf("batchFn panicked: %v", p))
}
}
}()
b.batchFn(tasks)
}

// worker batches up BatchableTask's from taskCh per batchCfg (at most batchRate after each batch's first task,
// and at most batchCnt BatchableTask's per batch) and triggers batchFn with each batch.
func (b *ChanBatcher[T, R]) worker() {
defer func() {
if p := recover(); p != nil {
logger.Errorf("ChanBatcher.worker|recovered from panic: %v\n%s", p, string(debug.Stack()))
}
}()
var tasks []T
batchTimer := time.NewTimer(time.Duration(math.MaxInt64)) // effectively never fires until armed by a batch's first task
for {
runtime.Gosched() // in case GOMAXPROCS is 1, we need to cooperatively yield
select {
case <-batchTimer.C:
if len(tasks) == 0 {
break
}
logger.Debugf("ChanBatcher.worker|timer|%d tasks", len(tasks))
go b.batchFnWithRecover(tasks)
tasks = tasks[:0:0] // zero len and cap so the next append allocates a new backing array; batchFn owns the old one
case task, ok := <-b.taskCh:
if !ok {
logger.Debugf("ChanBatcher.worker|closed|%d tasks", len(tasks))
if len(tasks) > 0 {
go b.batchFnWithRecover(tasks)
}
return
}
if !task.IsDone() {
select {
case <-task.Ctx().Done():
logger.Infof("ChanBatcher.worker|skip|task=%v", task)
task.Resolve(*new(R), task.Ctx().Err())
continue
default:
}
}
duration, batchCount := b.batchCfg()
if len(tasks) == 0 {
logger.Debugf("ChanBatcher.worker|timer start|duration=%s", duration)
if !batchTimer.Stop() {
select {
case <-batchTimer.C:
default:
}
}
batchTimer.Reset(duration)
}
tasks = append(tasks, task)
if len(tasks) >= batchCount {
logger.Debugf("ChanBatcher.worker|max|%d tasks", len(tasks))
go b.batchFnWithRecover(tasks)
tasks = tasks[:0:0] // zero len and cap so the next append allocates a new backing array; batchFn owns the old one
}
}
}
}
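
For reviewers, a minimal usage sketch of the API above (not part of this commit). The module path github.com/KyberNetwork/kutils is assumed, and fetchPrices, priceTask and the price values are hypothetical stand-ins for a real batched backend call:

package main

import (
    "context"
    "fmt"
    "time"

    "github.com/KyberNetwork/kutils" // assumed module path for this package
)

// fetchPrices stands in for a backend call that resolves many symbols in one round trip.
func fetchPrices(symbols []string) map[string]float64 {
    prices := make(map[string]float64, len(symbols))
    for _, s := range symbols {
        prices[s] = 42 // placeholder value
    }
    return prices
}

// priceTask pairs a symbol with a ChanTask so the batch function knows what each task asked for.
type priceTask struct {
    *kutils.ChanTask[float64]
    symbol string
}

func main() {
    // Flush a batch 5ms after its first task is queued, or as soon as 100 tasks are queued.
    cfg := func() (time.Duration, int) { return 5 * time.Millisecond, 100 }

    batcher := kutils.NewChanBatcher[*priceTask, float64](cfg, func(tasks []*priceTask) {
        symbols := make([]string, 0, len(tasks))
        for _, t := range tasks {
            symbols = append(symbols, t.symbol)
        }
        prices := fetchPrices(symbols) // one backend call for the whole batch
        for _, t := range tasks {
            t.Resolve(prices[t.symbol], nil)
        }
    })
    defer batcher.Close()

    task := &priceTask{ChanTask: kutils.NewChanTask[float64](context.Background()), symbol: "ETH"}
    batcher.Batch(task)
    price, err := task.Result() // blocks until the batch containing this task runs
    fmt.Println(price, err)
}

Because batchFn receives the whole slice of queued tasks, one backend round trip can resolve many concurrent callers, while each caller blocks only on its own task's Result.
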
143 changes: 143 additions & 0 deletions batcher_test.go
@@ -0,0 +1,143 @@
package kutils

import (
"context"
"runtime"
"sync/atomic"
"testing"
"time"

"github.com/KyberNetwork/logger"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

func TestChanBatcher(t *testing.T) {
ctx := context.Background()
batchRate := 10 * time.Millisecond
batchFn := func(_ []*ChanTask[time.Duration]) {}
batcher := NewChanBatcher[*ChanTask[time.Duration], time.Duration](func() (time.Duration, int) {
return batchRate, 2
}, func(tasks []*ChanTask[time.Duration]) { batchFn(tasks) })
var cnt atomic.Uint32
start := time.Now()
batchFn = func(tasks []*ChanTask[time.Duration]) {
cnt.Add(1)
for _, task := range tasks {
task.Resolve(time.Since(start), nil)
}
}
task0 := NewChanTask[time.Duration](ctx)
task1 := NewChanTask[time.Duration](ctx)
task2 := NewChanTask[time.Duration](ctx)

t.Run("happy", func(t *testing.T) {
batcher.Batch(task0)
batcher.Batch(task1)
_, _ = task0.Result()
assert.EqualValues(t, 1, cnt.Load())
assert.NoError(t, task0.Err)
assert.Less(t, task0.Ret, batchRate)
ret, err := task1.Result()
assert.NoError(t, err)
assert.Less(t, ret, batchRate)
time.Sleep(batchRate * 11 / 10)
runtime.Gosched()

batcher.Batch(task2)
assert.False(t, task2.IsDone())
ret, err = task2.Result()
assert.True(t, task2.IsDone())
assert.EqualValues(t, 2, cnt.Load())
assert.Equal(t, task2.Err, err)
assert.NoError(t, task2.Err)
assert.Equal(t, task2.Ret, ret)
assert.Greater(t, ret, batchRate)
})

t.Run("spam", func(t *testing.T) {
batcher := NewChanBatcher[*ChanTask[int], int](func() (time.Duration, int) { return 0, 0 },
func(tasks []*ChanTask[int]) {
for _, task := range tasks {
task.Resolve(0, nil)
}
})
const taskCnt = 1000
tasks := make([]*ChanTask[int], taskCnt)
start := time.Now()
for i := 0; i < taskCnt; i++ {
tasks[i] = NewChanTask[int](ctx)
batcher.Batch(tasks[i])
}
// 1k: 2.561804ms; 1M: 2.62s - average overhead per task = 2.6µs
logger.Warnf("done %d tasks in %v", taskCnt, time.Since(start))
for i := 0; i < taskCnt; i++ {
ret, err := tasks[i].Result()
assert.NoError(t, err)
assert.EqualValues(t, 0, ret)
}
batcher.Close()
})

t.Run("resolve twice", func(t *testing.T) {
task0.Resolve(batchRate, nil)
assert.NoError(t, task0.Err)
assert.Less(t, task0.Ret, batchRate)
})

t.Run("recover from panic", func(t *testing.T) {
oldBatchFn := batchFn
batchFn = func(tasks []*ChanTask[time.Duration]) {
panic("test panic")
}
task0 = NewChanTask[time.Duration](ctx)
task1 = NewChanTask[time.Duration](ctx)
task0.Resolve(0, nil)
batcher.Batch(task0)
batcher.Batch(task1)
<-task1.Done()
assert.ErrorContains(t, task1.Err, "test panic")

panicErr := errors.New("test panic error")
batchFn = func(tasks []*ChanTask[time.Duration]) {
panic(panicErr)
}
task0 = NewChanTask[time.Duration](ctx)
task1 = NewChanTask[time.Duration](ctx)
batcher.Batch(task0)
batcher.Batch(task1)
<-task1.Done()
assert.ErrorIs(t, task0.Err, panicErr)
assert.ErrorIs(t, task1.Err, panicErr)

batchFn = oldBatchFn
task2 = NewChanTask[time.Duration](nil) // nolint:staticcheck
batcher.Batch(task2)
batcher.Batch(task2)
ret, err := task2.Result()
assert.NoError(t, err)
assert.Greater(t, ret, batchRate)
})

t.Run("cancelled task", func(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
task0 = NewChanTask[time.Duration](ctx)
batcher.Batch(task0)
cancel()
_, err := task0.Result()
assert.ErrorIs(t, err, context.Canceled)
})

t.Run("close", func(t *testing.T) {
batcher.Batch(task2)
batcher.Close()
task3 := NewChanTask[time.Duration](ctx)
batcher.Batch(task3)
assert.ErrorIs(t, task3.Err, ErrBatcherClosed)
})

t.Run("invalid task", func(t *testing.T) {
NewChanBatcher[*ChanTask[int], int](func() (time.Duration, int) { return 0, 0 },
nil).Batch(&ChanTask[int]{})
})
}
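
The "cancelled task" subtest above hinges on per-task contexts: Result unblocks with the context's error if the task is still unresolved when its context ends, and the worker skips tasks whose context is already done when it dequeues them. A caller-side sketch of that pattern, with the same assumed import path as the earlier sketch and a hypothetical timeoutLookup helper (imports of context, time and kutils elided):

// timeoutLookup gives a single task its own deadline, independent of the batch schedule.
func timeoutLookup(batcher *kutils.ChanBatcher[*kutils.ChanTask[string], string]) (string, error) {
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()

    task := kutils.NewChanTask[string](ctx)
    batcher.Batch(task)

    // Returns ctx.Err() (context.DeadlineExceeded here) if no batch resolves the task before the deadline.
    return task.Result()
}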