
resourcemanager: fix TaskController.Stop() can't make producer exit in spmcpool #41016

Merged · 11 commits · Feb 3, 2023
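In outline, the fix gives each task a dedicated `productExitCh` that the producer goroutine polls, and `Stop()` now closes that channel and drains the new `inputCh` (and later `resultCh`) so that neither the producer nor the waiters can block forever. A minimal standalone sketch of that pattern, with simplified names rather than the actual pool API:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	productExitCh := make(chan struct{})
	inputCh := make(chan int, 1)
	var pending sync.WaitGroup

	// Producer: stops queueing work as soon as productExitCh is closed.
	go func() {
		defer close(inputCh)
		for i := 0; ; i++ {
			select {
			case <-productExitCh:
				return
			default:
			}
			pending.Add(1)
			inputCh <- i
		}
	}()

	// Stop(): signal the producer, then drain whatever was already queued so
	// that a later Wait() cannot block on undelivered tasks.
	close(productExitCh)
	for range inputCh {
		pending.Done()
	}
	pending.Wait()
	fmt.Println("producer exited, queue drained")
}
```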
5 changes: 4 additions & 1 deletion resourcemanager/pooltask/BUILD.bazel
@@ -10,7 +10,10 @@ go_library(
],
importpath = "github.com/pingcap/tidb/resourcemanager/pooltask",
visibility = ["//visibility:public"],
deps = ["@org_uber_go_atomic//:atomic"],
deps = [
"//util/channel",
"@org_uber_go_atomic//:atomic",
],
)

go_test(
35 changes: 23 additions & 12 deletions resourcemanager/pooltask/task.go
@@ -17,6 +17,8 @@ package pooltask
import (
"sync"
"sync/atomic"

"github.com/pingcap/tidb/util/channel"
)

// Context is a interface that can be used to create a context.
@@ -127,35 +129,44 @@ type GPool[T any, U any, C any, CT any, TF Context[CT]] interface {

// TaskController is a controller that can control or watch the pool.
type TaskController[T any, U any, C any, CT any, TF Context[CT]] struct {
pool GPool[T, U, C, CT, TF]
close chan struct{}
wg *sync.WaitGroup
taskID uint64
resultCh chan U
pool GPool[T, U, C, CT, TF]
productExitCh chan struct{}
wg *sync.WaitGroup
taskID uint64
resultCh chan U
inputCh chan Task[T]
}

// NewTaskController create a controller to deal with pooltask's status.
func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, closeCh chan struct{}, wg *sync.WaitGroup, resultCh chan U) TaskController[T, U, C, CT, TF] {
func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, productExitCh chan struct{}, wg *sync.WaitGroup, inputCh chan Task[T], resultCh chan U) TaskController[T, U, C, CT, TF] {
return TaskController[T, U, C, CT, TF]{
pool: p,
taskID: taskID,
close: closeCh,
wg: wg,
resultCh: resultCh,
pool: p,
taskID: taskID,
productExitCh: productExitCh,
wg: wg,
resultCh: resultCh,
inputCh: inputCh,
}
}

// Wait is to wait the pool task to stop.
func (t *TaskController[T, U, C, CT, TF]) Wait() {
<-t.close
t.wg.Wait()
close(t.resultCh)
t.pool.DeleteTask(t.taskID)
}

// Stop is to send stop command to the task. But you still need to wait the task to stop.
func (t *TaskController[T, U, C, CT, TF]) Stop() {
close(t.productExitCh)
// Clear all the task in the task queue and mark all task complete.
// so that ```t.Wait``` is able to close resultCh
for range t.inputCh {
t.wg.Done()
}
t.pool.StopTask(t.TaskID())
// Clear the resultCh to avoid blocking the consumer put result into the channel and cannot exit.
channel.Clear(t.resultCh)
}

// TaskID is to get the task id.
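Taken together, the `Stop`/`Wait` pair above gives callers this contract. The following is a fragment condensed from `TestStopPool` later in this diff, not a complete program; `pool`, `pfunc`, and `myArgs` are assumed to be set up as in that test:

```go
resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4))

var wg sync.WaitGroup
wg.Add(1)
go func() {
	defer wg.Done()
	for result := range resultCh { // resultCh is closed by control.Wait()
		_ = result
	}
}()

control.Stop() // closes productExitCh, drains inputCh, clears resultCh
control.Wait() // waits for outstanding work, closes resultCh, deletes the task
wg.Wait()
```

Note that `Stop()` only signals and drains; `Wait()` is still required to close `resultCh` and delete the task, which is why the tests below call both.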
6 changes: 4 additions & 2 deletions util/gpool/spmc/option.go
@@ -18,6 +18,8 @@ import (
"time"
)

const defaultTaskChanLen = 1

// Option represents the optional function.
type Option func(opts *Options)

@@ -103,8 +105,8 @@ func loadTaskOptions(options ...TaskOption) *TaskOptions {
if opts.ResultChanLen == 0 {
opts.ResultChanLen = uint64(opts.Concurrency)
}
if opts.ResultChanLen == 0 {
opts.ResultChanLen = uint64(opts.Concurrency)
if opts.TaskChanLen == 0 {
opts.TaskChanLen = defaultTaskChanLen
}
return opts
}
21 changes: 14 additions & 7 deletions util/gpool/spmc/spmcpool.go
@@ -219,7 +219,6 @@ func (p *Pool[T, U, C, CT, TF]) release() {
// There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent
// those callers blocking infinitely.
p.cond.Broadcast()
close(p.taskCh)
Review note (author): this close(p.taskCh) is unnecessary and leads to a data race — goroutines adding tasks may still be sending on p.taskCh, and a send on a closed channel panics. A standalone illustration follows this file's diff.

}

func isClose(exitCh chan struct{}) bool {
@@ -260,9 +259,9 @@ func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error),
taskID := p.NewTaskID()
var wg sync.WaitGroup
result := make(chan U, opt.ResultChanLen)
closeCh := make(chan struct{})
productExitCh := make(chan struct{})
inputCh := make(chan pooltask.Task[T], opt.TaskChanLen)
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result)
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productExitCh, &wg, inputCh, result)
p.taskManager.RegisterTask(taskID, int32(opt.Concurrency))
for i := 0; i < opt.Concurrency; i++ {
err := p.run()
Expand All @@ -274,15 +273,19 @@ func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error),
p.taskManager.AddSubTask(taskID, &taskBox)
p.taskCh <- &taskBox
}
wg.Add(1)
go func() {
defer func() {
if r := recover(); r != nil {
logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack"))
}
close(closeCh)
close(inputCh)
wg.Done()
}()
for {
if isClose(productExitCh) {
return
}
tasks, err := producer()
if err != nil {
if errors.Is(err, gpool.ErrProducerClosed) {
@@ -310,10 +313,10 @@ func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg
taskID := p.NewTaskID()
var wg sync.WaitGroup
result := make(chan U, opt.ResultChanLen)
closeCh := make(chan struct{})
productExitCh := make(chan struct{})
inputCh := make(chan pooltask.Task[T], opt.TaskChanLen)
p.taskManager.RegisterTask(taskID, int32(opt.Concurrency))
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result)
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productExitCh, &wg, inputCh, result)
for i := 0; i < opt.Concurrency; i++ {
err := p.run()
if err == gpool.ErrPoolClosed {
Expand All @@ -324,15 +327,19 @@ func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg
p.taskManager.AddSubTask(taskID, &taskBox)
p.taskCh <- &taskBox
}
wg.Add(1)
go func() {
defer func() {
if r := recover(); r != nil {
logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack"))
}
close(closeCh)
close(inputCh)
wg.Done()
}()
for {
if isClose(productExitCh) {
return
}
task, err := producer()
if err != nil {
if errors.Is(err, gpool.ErrProducerClosed) {
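As mentioned in the review note above, here is a standalone illustration (plain Go, not TiDB code) of why release() must not close a channel that other goroutines may still send on: the send panics with "send on closed channel".

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	taskCh := make(chan int, 1)
	var wg sync.WaitGroup
	wg.Add(1)

	// A goroutine that keeps submitting, analogous to the code paths that
	// still send task boxes on p.taskCh.
	go func() {
		defer wg.Done()
		defer func() {
			if r := recover(); r != nil {
				fmt.Println("sender panicked:", r) // "send on closed channel"
			}
		}()
		for i := 0; ; i++ {
			taskCh <- i // once taskCh is closed, this send panics
		}
	}()

	close(taskCh) // closing while a sender is still active — what the removed close(p.taskCh) risked
	wg.Wait()
}
```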
62 changes: 54 additions & 8 deletions util/gpool/spmc/spmcpool_test.go
@@ -53,7 +53,7 @@ func TestPool(t *testing.T) {
}
}
// add new task
resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4))
resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(5))

var count atomic.Uint32
var wg sync.WaitGroup
@@ -112,12 +112,55 @@ func TestStopPool(t *testing.T) {
require.Greater(t, result, 10)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
control.Stop()
}()
// Waiting task finishing
control.Wait()
wg.Wait()
// close pool
pool.ReleaseAndWait()
}

func TestStopPoolWithSlice(t *testing.T) {
type ConstArgs struct {
a int
}
myArgs := ConstArgs{a: 10}
// init the pool
// input type, output type, constArgs type
pool, err := NewSPMCPool[int, int, ConstArgs, any, pooltask.NilContext]("TestStopPoolWithSlice", 3, rmutil.UNKNOWN)
require.NoError(t, err)
pool.SetConsumerFunc(func(task int, constArgs ConstArgs, ctx any) int {
return task + constArgs.a
})

exit := make(chan struct{})

pfunc := func() ([]int, error) {
select {
case <-exit:
return nil, gpool.ErrProducerClosed
default:
return []int{1, 2, 3}, nil
}
}
// add new task
resultCh, control := pool.AddProduceBySlice(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4))

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for result := range resultCh {
require.Greater(t, result, 10)
control.Stop()
}
}()
// Waiting task finishing
control.Stop()
close(exit)
control.Wait()
// it should pass. Stop can be used after the pool is closed. we should prevent it from panic.
control.Stop()
wg.Wait()
// close pool
pool.ReleaseAndWait()
@@ -191,9 +234,12 @@ func testTunePool(t *testing.T, name string) {
for n := pool.Cap(); n > 1; n-- {
downclockPool(t, pool, tid)
}

// exit test
close(exit)
wg.Add(1)
go func() {
// exit test
control.Stop()
wg.Done()
}()
control.Wait()
wg.Wait()
// close pool