add lock while dispatching tasks to avoid deadlocks
alitto committed Nov 13, 2024
1 parent 9571160 commit 34e8f75
Showing 7 changed files with 54 additions and 65 deletions.
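The change in this commit: the dispatcher now holds a mutex (dispatcherRunning) for the whole dispatch pass, and an idle worker may only exit after confirming, via TryLock, that no dispatch is in progress, so the dispatcher can never end up blocked on the task channel with no worker left to receive. Below is a minimal, self-contained sketch of that coordination pattern; the names (pool, dispatching, maxWorkers) and the simplified logic are illustrative only and are not pond's actual implementation.

package main

import (
    "fmt"
    "runtime"
    "sync"
    "sync/atomic"
)

type pool struct {
    tasks       chan func()
    dispatching sync.Mutex   // held by dispatch(); workers must TryLock it before exiting
    workers     atomic.Int64 // number of live workers
    maxWorkers  int64
    wg          sync.WaitGroup
}

func (p *pool) dispatch(batch []func()) {
    p.dispatching.Lock()
    defer p.dispatching.Unlock()

    for _, task := range batch {
        select {
        case p.tasks <- task: // an idle worker picked it up
            continue
        default:
        }
        // No receiver ready: add a worker if allowed, then block. While
        // dispatching is held, no worker can exit, so the blocking send
        // below always has a potential receiver.
        if p.workers.Load() < p.maxWorkers {
            p.workers.Add(1)
            p.wg.Add(1)
            go p.worker()
        }
        p.tasks <- task
    }
}

func (p *pool) worker() {
    defer p.wg.Done()
    for {
        select {
        case task := <-p.tasks:
            task()
        default:
            // No queued work. Exit only if the dispatcher is idle; otherwise
            // it may be about to block on p.tasks counting on this worker.
            if p.dispatching.TryLock() {
                p.workers.Add(-1)
                p.dispatching.Unlock()
                return
            }
            runtime.Gosched() // dispatcher is mid-pass; yield and retry
        }
    }
}

func main() {
    p := &pool{tasks: make(chan func()), maxWorkers: 4}
    var done atomic.Int64
    batch := make([]func(), 1000)
    for i := range batch {
        batch[i] = func() { done.Add(1) }
    }
    p.dispatch(batch)
    p.wg.Wait() // workers drain the queue and then exit
    fmt.Println("tasks executed:", done.Load())
}

In the real change below, only the first PERSISTENT_WORKER_COUNT workers perform this exit check during dispatch; workers beyond that count may exit freely, since the persistent ones are enough to keep the dispatcher's blocking sends serviced (see pool.go).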
2 changes: 1 addition & 1 deletion Makefile
@@ -1,5 +1,5 @@
 test:
-    go test -race -v -timeout 1m -count=6 ./...
+    go test -race -v -timeout 1m -count=3 ./...

 coverage:
     go test -race -v -timeout 1m -coverprofile=coverage.out -covermode=atomic ./...
15 changes: 7 additions & 8 deletions group.go
@@ -3,7 +3,6 @@ package pond
 import (
     "context"
     "errors"
-    "fmt"
     "sync"
     "sync/atomic"

@@ -104,13 +103,9 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) {
     index := int(g.nextIndex.Add(1) - 1)

     g.taskWaitGroup.Add(1)
-    fmt.Printf("submitting task %d\n", index)

     err := g.pool.Go(func() {
-        defer func() {
-            fmt.Printf("task %d done\n", index)
-            g.taskWaitGroup.Done()
-        }()
+        defer g.taskWaitGroup.Done()

         // Check if the context has been cancelled to prevent running tasks that are not needed
         if err := g.future.Context().Err(); err != nil {
@@ -156,7 +151,9 @@ func (g *taskGroup) Wait() error {
     _, err := g.future.Wait(int(g.nextIndex.Load()))
     // This wait group could reach zero before the future is resolved if called in between tasks being submitted and the future being resolved.
     // That's why we wait for the future to be resolved before waiting for the wait group.
-    g.taskWaitGroup.Wait()
+    if err == nil {
+        g.taskWaitGroup.Wait()
+    }
     return err
 }

@@ -179,7 +176,9 @@ func (g *resultTaskGroup[O]) Wait() ([]O, error) {

     // This wait group could reach zero before the future is resolved if called in between tasks being submitted and the future being resolved.
     // That's why we wait for the future to be resolved before waiting for the wait group.
-    g.taskWaitGroup.Wait()
+    if err == nil {
+        g.taskWaitGroup.Wait()
+    }

     values := make([]O, len(results))

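To see the adjusted Wait behaviour from the caller's side, here is a small usage sketch modeled directly on the tests in this commit (NewPool, NewGroupContext, Submit, Wait): when the group's context is canceled, Wait returns the context error as soon as the future resolves, without blocking on tasks that will be skipped. The import path is assumed to be pond v2 and may need adjusting for your module.

package main

import (
    "context"
    "fmt"
    "time"

    "github.com/alitto/pond/v2"
)

func main() {
    pool := pond.NewPool(1)

    ctx, cancel := context.WithCancel(context.Background())
    group := pool.NewGroupContext(ctx)

    group.Submit(func() {
        cancel() // cancel while the first task is running
        time.Sleep(10 * time.Millisecond)
    })
    group.Submit(func() {
        time.Sleep(10 * time.Millisecond) // likely skipped after cancellation
    })

    err := group.Wait()
    fmt.Println(err) // expected: context.Canceled
}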
14 changes: 0 additions & 14 deletions group_test.go
@@ -38,8 +38,6 @@ func TestResultTaskGroupWaitWithError(t *testing.T) {

     pool := NewResultPool[int](1)

-    pool.EnableDebug()
-
     group := pool.NewGroup()

     sampleErr := errors.New("sample error")
@@ -147,33 +145,27 @@ func TestResultTaskGroupWaitWithContextCanceledAndOngoingTasks(t *testing.T) {
 func TestTaskGroupWaitWithContextCanceledAndOngoingTasks(t *testing.T) {
     pool := NewPool(1)

-    var executedCount atomic.Int32
-
     ctx, cancel := context.WithCancel(context.Background())

     group := pool.NewGroupContext(ctx)

     group.Submit(func() {
         cancel() // cancel the context after the first task is started
         time.Sleep(10 * time.Millisecond)
-        executedCount.Add(1)
     })

     group.Submit(func() {
         time.Sleep(10 * time.Millisecond)
-        executedCount.Add(1)
     })

     err := group.Wait()

     assert.Equal(t, context.Canceled, err)
-    assert.Equal(t, int32(1), executedCount.Load())
 }

 func TestTaskGroupWithStoppedPool(t *testing.T) {

     pool := NewPool(100)
-    pool.EnableDebug()

     pool.StopAndWait()

@@ -185,7 +177,6 @@ func TestTaskGroupWithStoppedPool(t *testing.T) {
 func TestTaskGroupWithContextCanceled(t *testing.T) {

     pool := NewPool(100)
-    pool.EnableDebug()

     group := pool.NewGroup()

@@ -227,7 +218,6 @@ func TestTaskGroupWithNoTasks(t *testing.T) {
 func TestTaskGroupCanceledShouldSkipRemainingTasks(t *testing.T) {

     pool := NewPool(1)
-    pool.EnableDebug()

     group := pool.NewGroup()

@@ -256,8 +246,6 @@ func TestTaskGroupCanceledShouldSkipRemainingTasks(t *testing.T) {
 func TestTaskGroupWithCustomContext(t *testing.T) {
     pool := NewPool(1)

-    pool.EnableDebug()
-
     ctx, cancel := context.WithCancel(context.Background())

     group := pool.NewGroupContext(ctx)
@@ -284,7 +272,6 @@ func TestTaskGroupWithCustomContext(t *testing.T) {

 func TestTaskGroupStop(t *testing.T) {
     pool := NewPool(1)
-    pool.EnableDebug()

     group := pool.NewGroup()

@@ -310,7 +297,6 @@ func TestTaskGroupStop(t *testing.T) {

 func TestTaskGroupDone(t *testing.T) {
     pool := NewPool(10)
-    pool.EnableDebug()

     group := pool.NewGroup()

7 changes: 1 addition & 6 deletions internal/dispatcher/dispatcher.go
@@ -19,7 +19,6 @@ type Dispatcher[T any] struct {
     waitGroup         sync.WaitGroup
     batchSize         int
     closed            atomic.Bool
-    Debug             bool
 }

// NewDispatcher creates a generic dispatcher that can receive values from multiple goroutines in a thread-safe manner
@@ -28,7 +27,7 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize
     dispatcher := &Dispatcher[T]{
         ctx:               ctx,
         buffer:            linkedbuffer.NewLinkedBuffer[T](10, batchSize),
-        bufferHasElements: make(chan struct{}, 1), // This channel needs to have size 2 in case an element is written to the buffer while the dispatcher is processing elements
+        bufferHasElements: make(chan struct{}, 1),
         dispatchFunc:      dispatchFunc,
         batchSize:         batchSize,
     }
@@ -39,10 +38,6 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize
     return dispatcher
 }

-func (d *Dispatcher[T]) EnableDebug() {
-    d.buffer.Debug = true
-}
-
 // Write writes values to the dispatcher
 func (d *Dispatcher[T]) Write(values ...T) error {
     // Check if the dispatcher has been closed
6 changes: 0 additions & 6 deletions internal/linkedbuffer/linkedbuffer.go
@@ -1,7 +1,6 @@
 package linkedbuffer

 import (
-    "fmt"
     "sync"
     "sync/atomic"
 )
@@ -19,7 +18,6 @@ type LinkedBuffer[T any] struct {
     writeCount atomic.Uint64
     readCount  atomic.Uint64
     mutex      sync.Mutex
-    Debug      bool
 }

func NewLinkedBuffer[T any](initialCapacity, maxCapacity int) *LinkedBuffer[T] {
@@ -91,10 +89,6 @@ func (b *LinkedBuffer[T]) Read(values []T) int {
     // Read element
     n, err := readBuffer.Read(values)

-    if b.Debug {
-        fmt.Printf("read %d elements: %v\n", n, values)
-    }
-
     if err == ErrEOF {
         // Move to next buffer
         if readBuffer.next == nil {
68 changes: 45 additions & 23 deletions pool.go
@@ -14,6 +14,8 @@ import (

 var MAX_TASKS_CHAN_LENGTH = runtime.NumCPU() * 128

+var PERSISTENT_WORKER_COUNT = int64(runtime.NumCPU())
+
 var ErrPoolStopped = errors.New("pool stopped")

 var poolStoppedFuture = func() Task {
@@ -79,8 +81,6 @@ type Pool interface {

     // Creates a new task group with the specified context.
     NewGroupContext(ctx context.Context) TaskGroup
-
-    EnableDebug()
 }

 // pool is an implementation of the Pool interface.
@@ -93,6 +93,7 @@ type pool struct {
     workerCount         atomic.Int64
     workerWaitGroup     sync.WaitGroup
     dispatcher          *dispatcher.Dispatcher[any]
+    dispatcherRunning   sync.Mutex
     successfulTaskCount atomic.Uint64
     failedTaskCount     atomic.Uint64
 }
@@ -141,10 +142,6 @@ func (p *pool) updateMetrics(err error) {
     }
 }

-func (d *pool) EnableDebug() {
-    d.dispatcher.EnableDebug()
-}
-
 func (p *pool) Go(task func()) error {
     if err := p.dispatcher.Write(task); err != nil {
         return ErrPoolStopped
@@ -202,15 +199,16 @@ func (p *pool) NewGroupContext(ctx context.Context) TaskGroup {
 }

 func (p *pool) dispatch(incomingTasks []any) {
+    p.dispatcherRunning.Lock()
+    defer p.dispatcherRunning.Unlock()
+
     // Submit tasks
     for _, task := range incomingTasks {
         p.dispatchTask(task)
     }
 }

 func (p *pool) dispatchTask(task any) {
-    workerCount := int(p.workerCount.Load())
-
     // Attempt to submit task without blocking
     select {
     case p.tasks <- task:
@@ -220,19 +218,13 @@ func (p *pool) dispatchTask(task any) {
         // 1. There are no idle workers (all spawned workers are processing a task)
         // 2. There are no workers in the pool
         // In either case, we should launch a new worker as long as the number of workers is less than the size of the task queue.
-        if workerCount < p.tasksLen {
-            // Launch a new worker
-            p.startWorker()
-        }
+        p.startWorker(p.tasksLen)
         return
     default:
     }

     // Task queue is full, launch a new worker if the number of workers is less than the maximum concurrency
-    if workerCount < p.maxConcurrency {
-        // Launch a new worker
-        p.startWorker()
-    }
+    p.startWorker(p.maxConcurrency)

     // Block until task is submitted
     select {
@@ -244,15 +236,36 @@
     }
 }

-func (p *pool) startWorker() {
+func (p *pool) startWorker(limit int) {
+    if p.workerCount.Load() >= int64(limit) {
+        return
+    }
     p.workerWaitGroup.Add(1)
-    p.workerCount.Add(1)
-    go p.worker()
+    workerNumber := p.workerCount.Add(1)
+    // Guarantee at least PERSISTENT_WORKER_COUNT workers are always running during dispatch to prevent deadlocks
+    canExitDuringDispatch := workerNumber > PERSISTENT_WORKER_COUNT
+    go p.worker(canExitDuringDispatch)
 }

-func (p *pool) worker() {
-    defer func() {
+func (p *pool) workerCanExit(canExitDuringDispatch bool) bool {
+    if canExitDuringDispatch {
+        p.workerCount.Add(-1)
+        return true
+    }
+
+    // Check if the dispatcher is running
+    if !p.dispatcherRunning.TryLock() {
+        // Dispatcher is running, cannot exit yet
+        return false
+    }
+    p.workerCount.Add(-1)
+    p.dispatcherRunning.Unlock()
+
+    return true
+}
+
+func (p *pool) worker(canExitDuringDispatch bool) {
+    defer func() {
         p.workerWaitGroup.Done()
     }()
@@ -261,17 +274,20 @@ func (p *pool) worker() {
         select {
         case <-p.ctx.Done():
             // Context cancelled, exit
+            p.workerCount.Add(-1)
             return
         default:
         }

         select {
         case <-p.ctx.Done():
             // Context cancelled, exit
+            p.workerCount.Add(-1)
             return
         case task, ok := <-p.tasks:
             if !ok || task == nil {
                 // Channel closed or worker killed, exit
+                p.workerCount.Add(-1)
                 return
             }

@@ -282,8 +298,14 @@
             p.updateMetrics(err)

         default:
-            // No tasks left, exit
-            return
+            // No tasks left
+
+            // Check if the worker can exit
+            if p.workerCanExit(canExitDuringDispatch) {
+                return
+            }
+            runtime.Gosched()
+            continue
         }
     }
 }
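As a rough illustration of the scenario the new dispatcher lock guards against, here is a stress-style usage sketch built from API visible in this commit (NewPool, Go, StopAndWait). Assumptions to note: the pond v2 import path, and that Go is exposed on the public Pool interface (it appears here only as a method on the internal pool type; substitute Submit if needed). Before this change, the dispatcher presumably could read a stale worker count, skip spawning a replacement, and then block on the task channel after the last idle worker had exited; with the lock and the exit check, workers cannot disappear mid-dispatch.

package main

import (
    "fmt"
    "sync"
    "sync/atomic"

    "github.com/alitto/pond/v2"
)

func main() {
    pool := pond.NewPool(1) // a single worker forces the dispatcher to block often

    var executed atomic.Int64
    var producers sync.WaitGroup

    // Many producers writing to a small pool while workers come and go.
    for p := 0; p < 8; p++ {
        producers.Add(1)
        go func() {
            defer producers.Done()
            for i := 0; i < 1000; i++ {
                _ = pool.Go(func() { executed.Add(1) })
            }
        }()
    }

    producers.Wait()
    pool.StopAndWait()
    fmt.Println("executed:", executed.Load())
}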
7 changes: 0 additions & 7 deletions result.go
@@ -10,9 +10,6 @@ import (
 type ResultPool[R any] interface {
     basePool

-    // Enables debug mode for the pool.
-    EnableDebug()
-
     // Submits a task to the pool and returns a future that can be used to wait for the task to complete and get the result.
     Submit(task func() R) Result[R]

@@ -33,10 +30,6 @@ type resultPool[R any] struct {
     *pool
 }

-func (d *resultPool[R]) EnableDebug() {
-    d.dispatcher.EnableDebug()
-}
-
 func (p *resultPool[R]) NewGroup() ResultTaskGroup[R] {
     return newResultTaskGroup[R](p.pool, p.Context())
 }
