diff --git a/resourcemanager/pooltask/BUILD.bazel b/resourcemanager/pooltask/BUILD.bazel
index c4113b69dd141..6171e1fa3598d 100644
--- a/resourcemanager/pooltask/BUILD.bazel
+++ b/resourcemanager/pooltask/BUILD.bazel
@@ -10,7 +10,10 @@ go_library(
     ],
     importpath = "github.com/pingcap/tidb/resourcemanager/pooltask",
     visibility = ["//visibility:public"],
-    deps = ["@org_uber_go_atomic//:atomic"],
+    deps = [
+        "//util/channel",
+        "@org_uber_go_atomic//:atomic",
+    ],
 )
 
 go_test(
diff --git a/resourcemanager/pooltask/task.go b/resourcemanager/pooltask/task.go
index e166e24f76b4c..88134e684065b 100644
--- a/resourcemanager/pooltask/task.go
+++ b/resourcemanager/pooltask/task.go
@@ -17,6 +17,8 @@ package pooltask
 import (
 	"sync"
 	"sync/atomic"
+
+	"github.com/pingcap/tidb/util/channel"
 )
 
 // Context is a interface that can be used to create a context.
@@ -127,27 +129,28 @@ type GPool[T any, U any, C any, CT any, TF Context[CT]] interface {
 
 // TaskController is a controller that can control or watch the pool.
 type TaskController[T any, U any, C any, CT any, TF Context[CT]] struct {
-	pool     GPool[T, U, C, CT, TF]
-	close    chan struct{}
-	wg       *sync.WaitGroup
-	taskID   uint64
-	resultCh chan U
+	pool          GPool[T, U, C, CT, TF]
+	productExitCh chan struct{}
+	wg            *sync.WaitGroup
+	taskID        uint64
+	resultCh      chan U
+	inputCh       chan Task[T]
 }
 
 // NewTaskController create a controller to deal with pooltask's status.
-func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, closeCh chan struct{}, wg *sync.WaitGroup, resultCh chan U) TaskController[T, U, C, CT, TF] {
+func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, productExitCh chan struct{}, wg *sync.WaitGroup, inputCh chan Task[T], resultCh chan U) TaskController[T, U, C, CT, TF] {
 	return TaskController[T, U, C, CT, TF]{
-		pool:     p,
-		taskID:   taskID,
-		close:    closeCh,
-		wg:       wg,
-		resultCh: resultCh,
+		pool:          p,
+		taskID:        taskID,
+		productExitCh: productExitCh,
+		wg:            wg,
+		resultCh:      resultCh,
+		inputCh:       inputCh,
 	}
 }
 
 // Wait is to wait the pool task to stop.
 func (t *TaskController[T, U, C, CT, TF]) Wait() {
-	<-t.close
 	t.wg.Wait()
 	close(t.resultCh)
 	t.pool.DeleteTask(t.taskID)
@@ -155,7 +158,15 @@ func (t *TaskController[T, U, C, CT, TF]) Wait() {
 
 // Stop is to send stop command to the task. But you still need to wait the task to stop.
 func (t *TaskController[T, U, C, CT, TF]) Stop() {
+	close(t.productExitCh)
+	// Drain all tasks left in the task queue and mark them complete,
+	// so that t.Wait is able to close resultCh.
+	for range t.inputCh {
+		t.wg.Done()
+	}
 	t.pool.StopTask(t.TaskID())
+	// Clear resultCh so that consumers still putting results into the channel do not block and can exit.
+	channel.Clear(t.resultCh)
 }
 
 // TaskID is to get the task id.
diff --git a/util/gpool/spmc/option.go b/util/gpool/spmc/option.go
index e317ce157b93d..af456e3c79772 100644
--- a/util/gpool/spmc/option.go
+++ b/util/gpool/spmc/option.go
@@ -18,6 +18,8 @@ import (
 	"time"
 )
 
+const defaultTaskChanLen = 1
+
 // Option represents the optional function.
 type Option func(opts *Options)
 
@@ -103,8 +105,8 @@ func loadTaskOptions(options ...TaskOption) *TaskOptions {
 	if opts.ResultChanLen == 0 {
 		opts.ResultChanLen = uint64(opts.Concurrency)
 	}
-	if opts.ResultChanLen == 0 {
-		opts.ResultChanLen = uint64(opts.Concurrency)
+	if opts.TaskChanLen == 0 {
+		opts.TaskChanLen = defaultTaskChanLen
 	}
 	return opts
 }
diff --git a/util/gpool/spmc/spmcpool.go b/util/gpool/spmc/spmcpool.go
index 6f65ca98aba01..eaa10c3b9a53a 100644
--- a/util/gpool/spmc/spmcpool.go
+++ b/util/gpool/spmc/spmcpool.go
@@ -219,7 +219,6 @@ func (p *Pool[T, U, C, CT, TF]) release() {
 	// There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent
 	// those callers blocking infinitely.
 	p.cond.Broadcast()
-	close(p.taskCh)
 }
 
 func isClose(exitCh chan struct{}) bool {
@@ -260,9 +259,9 @@ func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error),
 	taskID := p.NewTaskID()
 	var wg sync.WaitGroup
 	result := make(chan U, opt.ResultChanLen)
-	closeCh := make(chan struct{})
+	productExitCh := make(chan struct{})
 	inputCh := make(chan pooltask.Task[T], opt.TaskChanLen)
-	tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result)
+	tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productExitCh, &wg, inputCh, result)
 	p.taskManager.RegisterTask(taskID, int32(opt.Concurrency))
 	for i := 0; i < opt.Concurrency; i++ {
 		err := p.run()
@@ -274,15 +273,19 @@ func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error),
 		p.taskManager.AddSubTask(taskID, &taskBox)
 		p.taskCh <- &taskBox
 	}
+	wg.Add(1)
 	go func() {
 		defer func() {
 			if r := recover(); r != nil {
 				logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack"))
 			}
-			close(closeCh)
 			close(inputCh)
+			wg.Done()
 		}()
 		for {
+			if isClose(productExitCh) {
+				return
+			}
 			tasks, err := producer()
 			if err != nil {
 				if errors.Is(err, gpool.ErrProducerClosed) {
@@ -310,10 +313,10 @@ func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg
 	taskID := p.NewTaskID()
 	var wg sync.WaitGroup
 	result := make(chan U, opt.ResultChanLen)
-	closeCh := make(chan struct{})
+	productExitCh := make(chan struct{})
 	inputCh := make(chan pooltask.Task[T], opt.TaskChanLen)
 	p.taskManager.RegisterTask(taskID, int32(opt.Concurrency))
-	tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result)
+	tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productExitCh, &wg, inputCh, result)
 	for i := 0; i < opt.Concurrency; i++ {
 		err := p.run()
 		if err == gpool.ErrPoolClosed {
@@ -324,15 +327,19 @@ func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg
 		p.taskManager.AddSubTask(taskID, &taskBox)
 		p.taskCh <- &taskBox
 	}
+	wg.Add(1)
 	go func() {
 		defer func() {
 			if r := recover(); r != nil {
 				logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack"))
 			}
-			close(closeCh)
 			close(inputCh)
+			wg.Done()
 		}()
 		for {
+			if isClose(productExitCh) {
+				return
+			}
 			task, err := producer()
 			if err != nil {
 				if errors.Is(err, gpool.ErrProducerClosed) {
diff --git a/util/gpool/spmc/spmcpool_test.go b/util/gpool/spmc/spmcpool_test.go
index 83d02d2d47ac2..25fb62aaeb0ca 100644
--- a/util/gpool/spmc/spmcpool_test.go
+++ b/util/gpool/spmc/spmcpool_test.go
@@ -53,7 +53,7 @@ func TestPool(t *testing.T) {
 		}
 	}
 	// add new task
-	resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4))
+	resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(5))
 	var count atomic.Uint32
 	var wg sync.WaitGroup
@@ -112,12 +112,55 @@ func TestStopPool(t *testing.T) {
 			require.Greater(t, result, 10)
 		}
 	}()
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		control.Stop()
+	}()
+	// Waiting task finishing
+	control.Wait()
+	wg.Wait()
+	// close pool
+	pool.ReleaseAndWait()
+}
+
+func TestStopPoolWithSlice(t *testing.T) {
+	type ConstArgs struct {
+		a int
+	}
+	myArgs := ConstArgs{a: 10}
+	// init the pool
+	// input type, output type, constArgs type
+	pool, err := NewSPMCPool[int, int, ConstArgs, any, pooltask.NilContext]("TestStopPoolWithSlice", 3, rmutil.UNKNOWN)
+	require.NoError(t, err)
+	pool.SetConsumerFunc(func(task int, constArgs ConstArgs, ctx any) int {
+		return task + constArgs.a
+	})
+
+	exit := make(chan struct{})
+
+	pfunc := func() ([]int, error) {
+		select {
+		case <-exit:
+			return nil, gpool.ErrProducerClosed
+		default:
+			return []int{1, 2, 3}, nil
+		}
+	}
+	// add new task
+	resultCh, control := pool.AddProduceBySlice(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4))
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for result := range resultCh {
+			require.Greater(t, result, 10)
+			control.Stop()
+		}
+	}()
 	// Waiting task finishing
-	control.Stop()
-	close(exit)
 	control.Wait()
-	// it should pass. Stop can be used after the pool is closed. we should prevent it from panic.
-	control.Stop()
 	wg.Wait()
 	// close pool
 	pool.ReleaseAndWait()
@@ -191,9 +234,12 @@ func testTunePool(t *testing.T, name string) {
 	for n := pool.Cap(); n > 1; n-- {
 		downclockPool(t, pool, tid)
 	}
-
-	// exit test
-	close(exit)
+	wg.Add(1)
+	go func() {
+		// exit test
+		control.Stop()
+		wg.Done()
+	}()
 	control.Wait()
 	wg.Wait()
 	// close pool
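
Note: the sketch below is a minimal caller-side usage example of the Stop/Wait semantics changed by this patch, adapted from TestStopPool and TestStopPoolWithSlice above. The pooltask and spmc identifiers come from the files in this diff; the resourcemanager/util (rmutil) import path, the pool name, the sizes, and the main wrapper are illustrative assumptions, not part of the change.

package main

import (
	"fmt"
	"sync"

	"github.com/pingcap/tidb/resourcemanager/pooltask"
	rmutil "github.com/pingcap/tidb/resourcemanager/util" // assumed path for rmutil.UNKNOWN
	"github.com/pingcap/tidb/util/gpool/spmc"
)

type constArgs struct{ a int }

func main() {
	// Single-producer, multi-consumer pool: int tasks in, int results out.
	pool, err := spmc.NewSPMCPool[int, int, constArgs, any, pooltask.NilContext]("demo", 3, rmutil.UNKNOWN)
	if err != nil {
		panic(err)
	}
	pool.SetConsumerFunc(func(task int, args constArgs, _ any) int {
		return task + args.a
	})

	// A producer that never returns ErrProducerClosed on its own; with this
	// change it is terminated by control.Stop() via productExitCh instead.
	producer := func() ([]int, error) {
		return []int{1, 2, 3}, nil
	}
	resultCh, control := pool.AddProduceBySlice(producer, constArgs{a: 10}, pooltask.NilContext{}, spmc.WithConcurrency(4))

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// resultCh is closed by control.Wait() once the task's workers are done.
		for result := range resultCh {
			fmt.Println(result)
		}
	}()

	// Stop closes productExitCh, drains the pending input tasks, and clears
	// resultCh, so neither the producer nor the workers can block forever.
	control.Stop()
	// Wait now only waits on the task's WaitGroup before closing resultCh;
	// it no longer blocks on the producer's old close channel.
	control.Wait()
	wg.Wait()
	pool.ReleaseAndWait()
}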