-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
535 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
package kutils | ||
|
||
import ( | ||
"context" | ||
"math" | ||
"runtime" | ||
"runtime/debug" | ||
"sync/atomic" | ||
"time" | ||
|
||
"github.com/KyberNetwork/logger" | ||
"github.com/pkg/errors" | ||
) | ||
|
||
//go:generate mockgen -source=batcher.go -destination mocks/mocks.go -package mocks | ||
|
||
var ( | ||
ErrBatcherClosed = errors.New("batcher closed") | ||
) | ||
|
||
// BatchableTask represents a batchable task | ||
type BatchableTask[R any] interface { | ||
Ctx() context.Context // The context of this task | ||
Done() <-chan struct{} // Signals if this task was already resolved | ||
IsDone() bool // Checks (non-blocking) if this task was already resolved | ||
Result() (R, error) // Blocks until this task is resolved and returns result and error | ||
Resolve(ret R, err error) // Resolves this task with return value and error | ||
} | ||
|
||
// ChanTask uses a done channel to signal resolution of return value and error | ||
type ChanTask[R any] struct { | ||
ctx context.Context | ||
done chan struct{} | ||
Ret R | ||
Err error | ||
} | ||
|
||
func NewChanTask[R any](ctx context.Context) *ChanTask[R] { | ||
if ctx == nil { | ||
ctx = context.Background() | ||
} | ||
return &ChanTask[R]{ | ||
ctx: ctx, | ||
done: make(chan struct{}), | ||
} | ||
} | ||
|
||
func (c *ChanTask[R]) Ctx() context.Context { | ||
return c.ctx | ||
} | ||
|
||
func (c *ChanTask[R]) Done() <-chan struct{} { | ||
return c.done | ||
} | ||
|
||
func (c *ChanTask[R]) IsDone() bool { | ||
select { | ||
case <-c.done: | ||
return true | ||
default: | ||
return false | ||
} | ||
} | ||
|
||
func (c *ChanTask[R]) Result() (R, error) { | ||
if c.IsDone() { | ||
return c.Ret, c.Err | ||
} | ||
select { | ||
case <-c.done: | ||
return c.Ret, c.Err | ||
case <-c.ctx.Done(): | ||
return *new(R), c.ctx.Err() | ||
} | ||
} | ||
|
||
func (c *ChanTask[R]) Resolve(ret R, err error) { | ||
select { | ||
case <-c.done: | ||
logger.Errorf("ChanTask.Resolve|called twice, ignored|c.Ret=%v,c.Err=%v|Ret=%v,Err=%v", c.Ret, c.Err, ret, err) | ||
default: | ||
c.Ret, c.Err = ret, err | ||
close(c.done) | ||
} | ||
} | ||
|
||
// Batcher batches together n BatchableTask's together and executes a logic for a batch of BatchableTask's. | ||
// It skips BatchableTask's with cancelled Ctx and resolve those tasks with the context's error. | ||
// Batch logic execution should signal each BatchableTask as done by using its Resolve method. | ||
type Batcher[T BatchableTask[R], R any] interface { | ||
// Batch submits a BatchableTask to the batcher. | ||
Batch(task T) | ||
// Close should stop Batch from being called and clean up any background resources. | ||
Close() | ||
} | ||
|
||
// BatchCfg provides batchRate and batchCnt configs for a ChanBatcher. ChanBatcher will trigger a batch processing | ||
// either if no more task is queued after batchRate, or batchCnt BatchableTask's are already queued. | ||
type BatchCfg func() (batchRate time.Duration, batchCnt int) | ||
|
||
// BatchFn is called for a batch of tasks collected and triggered by a ChanBatcher per its batchCfg. | ||
type BatchFn[T any] func([]T) | ||
|
||
// ChanBatcher implements Batcher using golang channel. | ||
type ChanBatcher[T BatchableTask[R], R any] struct { | ||
batchCfg BatchCfg | ||
batchFn BatchFn[T] | ||
taskCh chan T | ||
closed atomic.Bool | ||
} | ||
|
||
func NewChanBatcher[T BatchableTask[R], R any](batchCfg BatchCfg, batchFn BatchFn[T]) *ChanBatcher[T, R] { | ||
_, batchCnt := batchCfg() | ||
chanBatcher := &ChanBatcher[T, R]{ | ||
batchCfg: batchCfg, | ||
batchFn: batchFn, | ||
taskCh: make(chan T, 16*batchCnt), | ||
} | ||
go chanBatcher.worker() | ||
return chanBatcher | ||
} | ||
|
||
// Batch submits a BatchableTask to the channel if this chanBatcher hasn't been closed. | ||
func (b *ChanBatcher[T, R]) Batch(task T) { | ||
if !b.closed.Load() { | ||
b.taskCh <- task | ||
} else { | ||
task.Resolve(*new(R), ErrBatcherClosed) | ||
} | ||
} | ||
|
||
// Close closes this chanBatcher to prevents Batch-ing new BatchableTask's and tell the worker goroutine to finish up. | ||
func (b *ChanBatcher[_, _]) Close() { | ||
if !b.closed.Swap(true) { | ||
close(b.taskCh) | ||
} | ||
} | ||
|
||
// goBatchFn | ||
func (b *ChanBatcher[T, R]) batchFnWithRecover(tasks []T) { | ||
defer func() { | ||
p := recover() | ||
if p == nil { | ||
return | ||
} | ||
logger.Errorf("ChanBatcher.goBatchFn|recovered from panic: %v\n%s", p, string(debug.Stack())) | ||
var ret R | ||
for _, task := range tasks { | ||
if task.IsDone() { | ||
continue | ||
} | ||
if err, ok := p.(error); ok { | ||
task.Resolve(ret, errors.Wrap(err, "batchFn panicked")) | ||
} else { | ||
task.Resolve(ret, errors.Errorf("batchFn panicked: %v", p)) | ||
} | ||
} | ||
}() | ||
b.batchFn(tasks) | ||
} | ||
|
||
// worker batches up BatchableTask's in taskCh per batchCfg (per at most batchRate ns and at most batchCnt BatchableTask's) | ||
// and triggers batchFn with each batch. | ||
func (b *ChanBatcher[T, R]) worker() { | ||
defer func() { | ||
if p := recover(); p != nil { | ||
logger.Errorf("ChanBatcher.worker|recovered from panic: %v\n%s", p, string(debug.Stack())) | ||
} | ||
}() | ||
var tasks []T | ||
batchTimer := time.NewTimer(time.Duration(math.MaxInt64)) | ||
for { | ||
runtime.Gosched() // in case GOMAXPROCS is 1, we need to cooperatively yield | ||
select { | ||
case <-batchTimer.C: | ||
if len(tasks) == 0 { | ||
break | ||
} | ||
logger.Debugf("ChanBatcher.worker|timer|%d tasks", len(tasks)) | ||
go b.batchFnWithRecover(tasks) | ||
tasks = tasks[:0:0] | ||
case task, ok := <-b.taskCh: | ||
if !ok { | ||
logger.Debugf("ChanBatcher.worker|closed|%d tasks", len(tasks)) | ||
if len(tasks) > 0 { | ||
go b.batchFnWithRecover(tasks) | ||
} | ||
return | ||
} | ||
if !task.IsDone() { | ||
select { | ||
case <-task.Ctx().Done(): | ||
logger.Infof("ChanBatcher.worker|skip|task=%v", task) | ||
task.Resolve(*new(R), task.Ctx().Err()) | ||
continue | ||
default: | ||
} | ||
} | ||
duration, batchCount := b.batchCfg() | ||
if len(tasks) == 0 { | ||
logger.Debugf("ChanBatcher.worker|timer start|duration=%s", duration) | ||
if !batchTimer.Stop() { | ||
select { | ||
case <-batchTimer.C: | ||
default: | ||
} | ||
} | ||
batchTimer.Reset(duration) | ||
} | ||
tasks = append(tasks, task) | ||
if len(tasks) >= batchCount { | ||
logger.Debugf("ChanBatcher.worker|max|%d tasks", len(tasks)) | ||
go b.batchFnWithRecover(tasks) | ||
tasks = tasks[:0:0] | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
package kutils | ||
|
||
import ( | ||
"context" | ||
"runtime" | ||
"sync/atomic" | ||
"testing" | ||
"time" | ||
|
||
"github.com/KyberNetwork/logger" | ||
"github.com/pkg/errors" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestChanBatcher(t *testing.T) { | ||
ctx := context.Background() | ||
batchRate := 10 * time.Millisecond | ||
batchFn := func(_ []*ChanTask[time.Duration]) {} | ||
batcher := NewChanBatcher[*ChanTask[time.Duration], time.Duration](func() (time.Duration, int) { | ||
return batchRate, 2 | ||
}, func(tasks []*ChanTask[time.Duration]) { batchFn(tasks) }) | ||
var cnt atomic.Uint32 | ||
start := time.Now() | ||
batchFn = func(tasks []*ChanTask[time.Duration]) { | ||
cnt.Add(1) | ||
for _, task := range tasks { | ||
task.Resolve(time.Since(start), nil) | ||
} | ||
} | ||
task0 := NewChanTask[time.Duration](ctx) | ||
task1 := NewChanTask[time.Duration](ctx) | ||
task2 := NewChanTask[time.Duration](ctx) | ||
|
||
t.Run("happy", func(t *testing.T) { | ||
batcher.Batch(task0) | ||
batcher.Batch(task1) | ||
_, _ = task0.Result() | ||
assert.EqualValues(t, 1, cnt.Load()) | ||
assert.NoError(t, task0.Err) | ||
assert.Less(t, task0.Ret, batchRate) | ||
ret, err := task1.Result() | ||
assert.NoError(t, err) | ||
assert.Less(t, ret, batchRate) | ||
time.Sleep(batchRate * 11 / 10) | ||
runtime.Gosched() | ||
|
||
batcher.Batch(task2) | ||
assert.False(t, task2.IsDone()) | ||
ret, err = task2.Result() | ||
assert.True(t, task2.IsDone()) | ||
assert.EqualValues(t, 2, cnt.Load()) | ||
assert.Equal(t, task2.Err, err) | ||
assert.NoError(t, task2.Err) | ||
assert.Equal(t, task2.Ret, ret) | ||
assert.Greater(t, ret, batchRate) | ||
}) | ||
|
||
t.Run("spam", func(t *testing.T) { | ||
batcher := NewChanBatcher[*ChanTask[int], int](func() (time.Duration, int) { return 0, 0 }, | ||
func(tasks []*ChanTask[int]) { | ||
for _, task := range tasks { | ||
task.Resolve(0, nil) | ||
} | ||
}) | ||
const taskCnt = 1000 | ||
tasks := make([]*ChanTask[int], taskCnt) | ||
start := time.Now() | ||
for i := 0; i < taskCnt; i++ { | ||
tasks[i] = NewChanTask[int](ctx) | ||
batcher.Batch(tasks[i]) | ||
} | ||
// 1k: 2.561804ms; 1M: 2.62s - average overhead per task = 2.6µs | ||
logger.Warnf("done %d tasks in %v", taskCnt, time.Since(start)) | ||
for i := 0; i < taskCnt; i++ { | ||
ret, err := tasks[i].Result() | ||
assert.NoError(t, err) | ||
assert.EqualValues(t, 0, ret) | ||
} | ||
batcher.Close() | ||
}) | ||
|
||
t.Run("resolve twice", func(t *testing.T) { | ||
task0.Resolve(batchRate, nil) | ||
assert.NoError(t, task0.Err) | ||
assert.Less(t, task0.Ret, batchRate) | ||
}) | ||
|
||
t.Run("recover from panic", func(t *testing.T) { | ||
oldBatchFn := batchFn | ||
batchFn = func(tasks []*ChanTask[time.Duration]) { | ||
panic("test panic") | ||
} | ||
task0 = NewChanTask[time.Duration](ctx) | ||
task1 = NewChanTask[time.Duration](ctx) | ||
task0.Resolve(0, nil) | ||
batcher.Batch(task0) | ||
batcher.Batch(task1) | ||
<-task1.Done() | ||
assert.ErrorContains(t, task1.Err, "test panic") | ||
|
||
panicErr := errors.New("test panic error") | ||
batchFn = func(tasks []*ChanTask[time.Duration]) { | ||
panic(panicErr) | ||
} | ||
task0 = NewChanTask[time.Duration](ctx) | ||
task1 = NewChanTask[time.Duration](ctx) | ||
batcher.Batch(task0) | ||
batcher.Batch(task1) | ||
<-task1.Done() | ||
assert.ErrorIs(t, task0.Err, panicErr) | ||
assert.ErrorIs(t, task1.Err, panicErr) | ||
|
||
batchFn = oldBatchFn | ||
task2 = NewChanTask[time.Duration](nil) // nolint:staticcheck | ||
batcher.Batch(task2) | ||
batcher.Batch(task2) | ||
ret, err := task2.Result() | ||
assert.NoError(t, err) | ||
assert.Greater(t, ret, batchRate) | ||
}) | ||
|
||
t.Run("cancelled task", func(t *testing.T) { | ||
ctx, cancel := context.WithCancel(ctx) | ||
task0 = NewChanTask[time.Duration](ctx) | ||
batcher.Batch(task0) | ||
cancel() | ||
_, err := task0.Result() | ||
assert.ErrorIs(t, err, context.Canceled) | ||
}) | ||
|
||
t.Run("close", func(t *testing.T) { | ||
batcher.Batch(task2) | ||
batcher.Close() | ||
task3 := NewChanTask[time.Duration](ctx) | ||
batcher.Batch(task3) | ||
assert.ErrorIs(t, task3.Err, ErrBatcherClosed) | ||
}) | ||
|
||
t.Run("invalid task", func(t *testing.T) { | ||
NewChanBatcher[*ChanTask[int], int](func() (time.Duration, int) { return 0, 0 }, | ||
nil).Batch(&ChanTask[int]{}) | ||
}) | ||
} |
Oops, something went wrong.