fix(pool): fix race condition with small pool sizes #83

Merged · 7 commits · Nov 13, 2024
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -23,7 +23,7 @@ jobs:
go-version: ${{ matrix.go-version }}

- name: Test
run: make test
run: make test-ci
codecov:
name: Coverage report
runs-on: ubuntu-latest
5 changes: 4 additions & 1 deletion Makefile
@@ -1,5 +1,8 @@
test:
go test -race -v -timeout 1m ./...
go test -race -v -timeout 15s -count=1 ./...

test-ci:
go test -race -v -timeout 1m -count=3 ./...

coverage:
go test -race -v -timeout 1m -coverprofile=coverage.out -covermode=atomic ./...
5 changes: 3 additions & 2 deletions group_test.go
@@ -36,8 +36,9 @@ func TestResultTaskGroupWait(t *testing.T) {

func TestResultTaskGroupWaitWithError(t *testing.T) {

group := NewResultPool[int](1).
NewGroup()
pool := NewResultPool[int](1)

group := pool.NewGroup()

sampleErr := errors.New("sample error")

4 changes: 0 additions & 4 deletions internal/dispatcher/dispatcher.go
@@ -30,7 +30,6 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize
bufferHasElements: make(chan struct{}, 1),
dispatchFunc: dispatchFunc,
batchSize: batchSize,
closed: atomic.Bool{},
}

dispatcher.waitGroup.Add(1)
@@ -118,9 +117,6 @@ func (d *Dispatcher[T]) run(ctx context.Context) {

// Submit the next batch of values
d.dispatchFunc(batch[0:batchSize])

// Reset batch
batch = batch[:0]
}

if !ok || d.closed.Load() {
9 changes: 3 additions & 6 deletions internal/linkedbuffer/linkedbuffer.go
@@ -17,7 +17,7 @@ type LinkedBuffer[T any] struct {
maxCapacity int
writeCount atomic.Uint64
readCount atomic.Uint64
mutex sync.RWMutex
mutex sync.Mutex
}

func NewLinkedBuffer[T any](initialCapacity, maxCapacity int) *LinkedBuffer[T] {
@@ -78,28 +78,25 @@ func (b *LinkedBuffer[T]) Write(values []T) {

// Read reads values from the buffer and returns the number of elements read
func (b *LinkedBuffer[T]) Read(values []T) int {
b.mutex.Lock()
defer b.mutex.Unlock()

var readBuffer *Buffer[T]

for {
b.mutex.RLock()
readBuffer = b.readBuffer
b.mutex.RUnlock()

// Read element
n, err := readBuffer.Read(values)

if err == ErrEOF {
// Move to next buffer
b.mutex.Lock()
if readBuffer.next == nil {
b.mutex.Unlock()
return n
}
if b.readBuffer != readBuffer.next {
b.readBuffer = readBuffer.next
}
b.mutex.Unlock()
continue
}

39 changes: 39 additions & 0 deletions internal/linkedbuffer/linkedbuffer_test.go
@@ -91,3 +91,42 @@ func TestLinkedBufferLen(t *testing.T) {
buf.readCount.Add(1)
assert.Equal(t, uint64(0), buf.Len())
}

func TestLinkedBufferWithReusedBuffer(t *testing.T) {

buf := NewLinkedBuffer[int](2, 1)

values := make([]int, 1)

buf.Write([]int{1})
buf.Write([]int{2})

n := buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, values[0])

assert.Equal(t, 1, len(values))
assert.Equal(t, 1, cap(values))

n = buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, len(values))
assert.Equal(t, 2, values[0])

buf.Write([]int{3})
buf.Write([]int{4})

n = buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, len(values))
assert.Equal(t, 3, values[0])

n = buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, len(values))
assert.Equal(t, 4, values[0])
}
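
TestLinkedBufferWithReusedBuffer exercises the buffer-reuse path from a single goroutine. The race that motivated the locking change above only shows up when a writer and a reader run concurrently, so a stress test under the race detector is the natural companion. The sketch below is a hypothetical harness, not part of this PR; it assumes only the NewLinkedBuffer, Write, and Read signatures visible in the diffs, and that Read is non-blocking and returns 0 when the buffer is empty, as the implementation above suggests.

```go
package linkedbuffer

import (
	"sync"
	"testing"
)

// Hypothetical stress test: one writer and one reader hammer the same
// LinkedBuffer so `go test -race` can catch unsynchronized access to the
// internal read/write buffers.
func TestLinkedBufferConcurrentReadWrite(t *testing.T) {
	buf := NewLinkedBuffer[int](2, 4)

	const total = 1000

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < total; i++ {
			buf.Write([]int{i})
		}
	}()

	// Drain concurrently; Read returns 0 when no element is available yet.
	values := make([]int, 1)
	read := 0
	for read < total {
		read += buf.Read(values)
	}

	wg.Wait()
}
```

Repeating the run, as the new `test-ci` target does with `-count=3`, gives interleavings like this more chances to trip the race detector.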
66 changes: 49 additions & 17 deletions pool.go
@@ -14,6 +14,8 @@

var MAX_TASKS_CHAN_LENGTH = runtime.NumCPU() * 128

var PERSISTENT_WORKER_COUNT = int64(runtime.NumCPU())

var ErrPoolStopped = errors.New("pool stopped")

var poolStoppedFuture = func() Task {
Expand Down Expand Up @@ -91,6 +93,7 @@
workerCount atomic.Int64
workerWaitGroup sync.WaitGroup
dispatcher *dispatcher.Dispatcher[any]
dispatcherRunning sync.Mutex
successfulTaskCount atomic.Uint64
failedTaskCount atomic.Uint64
}
Expand Down Expand Up @@ -196,15 +199,16 @@
}

func (p *pool) dispatch(incomingTasks []any) {
p.dispatcherRunning.Lock()
defer p.dispatcherRunning.Unlock()

// Submit tasks
for _, task := range incomingTasks {
p.dispatchTask(task)
}
}

func (p *pool) dispatchTask(task any) {
workerCount := int(p.workerCount.Load())

// Attempt to submit task without blocking
select {
case p.tasks <- task:
@@ -214,19 +218,13 @@
// 1. There are no idle workers (all spawned workers are processing a task)
// 2. There are no workers in the pool
// In either case, we should launch a new worker as long as the number of workers is less than the size of the task queue.
if workerCount < p.tasksLen {
// Launch a new worker
p.startWorker()
}
p.startWorker(p.tasksLen)
return
default:
}

// Task queue is full, launch a new worker if the number of workers is less than the maximum concurrency
if workerCount < p.maxConcurrency {
// Launch a new worker
p.startWorker()
}
p.startWorker(p.maxConcurrency)

// Block until task is submitted
select {
Expand All @@ -238,15 +236,41 @@
}
}

func (p *pool) startWorker() {
func (p *pool) startWorker(limit int) {
if p.workerCount.Load() >= int64(limit) {
return
}
p.workerWaitGroup.Add(1)
p.workerCount.Add(1)
go p.worker()
workerNumber := p.workerCount.Add(1)
// Guarantee at least PERSISTENT_WORKER_COUNT workers are always running during dispatch to prevent deadlocks
canExitDuringDispatch := workerNumber > PERSISTENT_WORKER_COUNT
go p.worker(canExitDuringDispatch)
}

func (p *pool) worker() {
defer func() {
func (p *pool) workerCanExit(canExitDuringDispatch bool) bool {
if canExitDuringDispatch {
p.workerCount.Add(-1)
return true
}

// Check if the dispatcher is running
if !p.dispatcherRunning.TryLock() {
// Dispatcher is running, cannot exit yet
return false
}
if len(p.tasks) > 0 {
// There are tasks in the queue, cannot exit yet
p.dispatcherRunning.Unlock()
return false
}
p.workerCount.Add(-1)
p.dispatcherRunning.Unlock()

return true
}

func (p *pool) worker(canExitDuringDispatch bool) {
defer func() {
p.workerWaitGroup.Done()
}()

@@ -255,17 +279,20 @@
select {
case <-p.ctx.Done():
// Context cancelled, exit
p.workerCount.Add(-1)
return
default:
}

select {
case <-p.ctx.Done():
// Context cancelled, exit
p.workerCount.Add(-1)

return
case task, ok := <-p.tasks:
if !ok || task == nil {
// Channel closed or worker killed, exit
p.workerCount.Add(-1)
return
}

@@ -276,8 +303,13 @@
p.updateMetrics(err)

default:
// No tasks left, exit
return
// No tasks left

// Check if the worker can exit
if p.workerCanExit(canExitDuringDispatch) {
return
}
continue
}
}
}
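
Taken together, the pool.go changes make worker exit a negotiated step: startWorker now refuses to spawn past the given limit, workers numbered beyond PERSISTENT_WORKER_COUNT may retire at any time, and the remaining persistent workers only retire once the dispatcher is idle and the task queue is empty. A minimal sketch of that exit protocol, with simplified types rather than the library's actual ones:

```go
package poolsketch

import (
	"sync"
	"sync/atomic"
)

// Simplified stand-in for the pool fields the exit check touches.
type pool struct {
	tasks             chan any
	workerCount       atomic.Int64
	dispatcherRunning sync.Mutex // held by dispatch() for its whole run
}

// workerCanExit mirrors the check added in this PR: workers beyond the
// persistent set may always retire, while persistent workers retire only
// when no dispatch is in flight and the task queue is drained. Per the
// comment in the diff, this is what prevents a deadlock where tasks are
// enqueued into a pool whose last workers have just exited.
func (p *pool) workerCanExit(canExitDuringDispatch bool) bool {
	if canExitDuringDispatch {
		p.workerCount.Add(-1)
		return true
	}

	// TryLock fails while dispatch() holds the mutex, i.e. while tasks
	// may still be on their way into the queue.
	if !p.dispatcherRunning.TryLock() {
		return false
	}
	defer p.dispatcherRunning.Unlock()

	if len(p.tasks) > 0 {
		return false
	}

	p.workerCount.Add(-1)
	return true
}
```

Using TryLock rather than Lock keeps an idle worker from blocking behind a long dispatch: when the check fails, it simply falls through to the `continue` added in the worker loop and polls the queue again.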