Skip to content

Commit

Permalink
fix(dispatcher): some tasks are misse
Browse files Browse the repository at this point in the history
  • Loading branch information
alitto committed Nov 11, 2024
1 parent f1d2a44 commit 3832cf8
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test:
go test -race -v -timeout 1m ./...
go test -race -v -timeout 1m -count=3 ./...

coverage:
go test -race -v -timeout 1m -coverprofile=coverage.out -covermode=atomic ./...
7 changes: 6 additions & 1 deletion group.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pond
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"

Expand Down Expand Up @@ -103,9 +104,13 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) {
index := int(g.nextIndex.Add(1) - 1)

g.taskWaitGroup.Add(1)
fmt.Printf("submitting task %d\n", index)

err := g.pool.Go(func() {
defer g.taskWaitGroup.Done()
defer func() {
fmt.Printf("task %d done\n", index)
g.taskWaitGroup.Done()
}()

// Check if the context has been cancelled to prevent running tasks that are not needed
if err := g.future.Context().Err(); err != nil {
Expand Down
14 changes: 12 additions & 2 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ func TestResultTaskGroupWait(t *testing.T) {

func TestResultTaskGroupWaitWithError(t *testing.T) {

group := NewResultPool[int](1).
NewGroup()
pool := NewResultPool[int](1)

pool.EnableDebug()

group := pool.NewGroup()

sampleErr := errors.New("sample error")

Expand Down Expand Up @@ -175,6 +178,7 @@ func TestTaskGroupWaitWithContextCanceledAndOngoingTasks(t *testing.T) {
func TestTaskGroupWithStoppedPool(t *testing.T) {

pool := NewPool(100)
pool.EnableDebug()

pool.StopAndWait()

Expand All @@ -186,6 +190,7 @@ func TestTaskGroupWithStoppedPool(t *testing.T) {
func TestTaskGroupWithContextCanceled(t *testing.T) {

pool := NewPool(100)
pool.EnableDebug()

group := pool.NewGroup()

Expand Down Expand Up @@ -227,6 +232,7 @@ func TestTaskGroupWithNoTasks(t *testing.T) {
func TestTaskGroupCanceledShouldSkipRemainingTasks(t *testing.T) {

pool := NewPool(1)
pool.EnableDebug()

group := pool.NewGroup()

Expand Down Expand Up @@ -255,6 +261,8 @@ func TestTaskGroupCanceledShouldSkipRemainingTasks(t *testing.T) {
func TestTaskGroupWithCustomContext(t *testing.T) {
pool := NewPool(1)

pool.EnableDebug()

ctx, cancel := context.WithCancel(context.Background())

group := pool.NewGroupContext(ctx)
Expand All @@ -281,6 +289,7 @@ func TestTaskGroupWithCustomContext(t *testing.T) {

func TestTaskGroupStop(t *testing.T) {
pool := NewPool(1)
pool.EnableDebug()

group := pool.NewGroup()

Expand All @@ -306,6 +315,7 @@ func TestTaskGroupStop(t *testing.T) {

func TestTaskGroupDone(t *testing.T) {
pool := NewPool(10)
pool.EnableDebug()

group := pool.NewGroup()

Expand Down
10 changes: 7 additions & 3 deletions internal/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ type Dispatcher[T any] struct {
waitGroup sync.WaitGroup
batchSize int
closed atomic.Bool
Debug bool
}

// NewDispatcher creates a generic dispatcher that can receive values from multiple goroutines in a thread-safe manner
// and process each element serially using the dispatchFunc
func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize int) *Dispatcher[T] {
dispatcher := &Dispatcher[T]{
ctx: ctx,
buffer: linkedbuffer.NewLinkedBuffer[T](10, batchSize),
bufferHasElements: make(chan struct{}, 1),
buffer: linkedbuffer.NewLinkedBuffer[T](1, batchSize),
bufferHasElements: make(chan struct{}, 2), // This channel needs to have size 2 in case an element is written to the buffer while the dispatcher is processing elements
dispatchFunc: dispatchFunc,
batchSize: batchSize,
closed: atomic.Bool{},
}

dispatcher.waitGroup.Add(1)
Expand All @@ -39,6 +39,10 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize
return dispatcher
}

func (d *Dispatcher[T]) EnableDebug() {
d.buffer.Debug = true

Check warning on line 43 in internal/dispatcher/dispatcher.go

View check run for this annotation

Codecov / codecov/patch

internal/dispatcher/dispatcher.go#L42-L43

Added lines #L42 - L43 were not covered by tests
}

// Write writes values to the dispatcher
func (d *Dispatcher[T]) Write(values ...T) error {
// Check if the dispatcher has been closed
Expand Down
15 changes: 9 additions & 6 deletions internal/linkedbuffer/linkedbuffer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package linkedbuffer

import (
"fmt"
"sync"
"sync/atomic"
)
Expand All @@ -17,7 +18,8 @@ type LinkedBuffer[T any] struct {
maxCapacity int
writeCount atomic.Uint64
readCount atomic.Uint64
mutex sync.RWMutex
mutex sync.Mutex
Debug bool
}

func NewLinkedBuffer[T any](initialCapacity, maxCapacity int) *LinkedBuffer[T] {
Expand Down Expand Up @@ -78,28 +80,29 @@ func (b *LinkedBuffer[T]) Write(values []T) {

// Read reads values from the buffer and returns the number of elements read
func (b *LinkedBuffer[T]) Read(values []T) int {
b.mutex.Lock()
defer b.mutex.Unlock()

var readBuffer *Buffer[T]

for {
b.mutex.RLock()
readBuffer = b.readBuffer
b.mutex.RUnlock()

// Read element
n, err := readBuffer.Read(values)

if b.Debug {
fmt.Printf("read %d elements: %v\n", n, values)
}

Check warning on line 96 in internal/linkedbuffer/linkedbuffer.go

View check run for this annotation

Codecov / codecov/patch

internal/linkedbuffer/linkedbuffer.go#L95-L96

Added lines #L95 - L96 were not covered by tests

if err == ErrEOF {
// Move to next buffer
b.mutex.Lock()
if readBuffer.next == nil {
b.mutex.Unlock()
return n
}
if b.readBuffer != readBuffer.next {
b.readBuffer = readBuffer.next
}
b.mutex.Unlock()
continue
}

Expand Down
39 changes: 39 additions & 0 deletions internal/linkedbuffer/linkedbuffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,42 @@ func TestLinkedBufferLen(t *testing.T) {
buf.readCount.Add(1)
assert.Equal(t, uint64(0), buf.Len())
}

func TestLinkedBufferWithReusedBuffer(t *testing.T) {

buf := NewLinkedBuffer[int](2, 1)

values := make([]int, 1)

buf.Write([]int{1})
buf.Write([]int{2})

n := buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, values[0])

assert.Equal(t, 1, len(values))
assert.Equal(t, 1, cap(values))

n = buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, len(values))
assert.Equal(t, 2, values[0])

buf.Write([]int{3})
buf.Write([]int{4})

n = buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, len(values))
assert.Equal(t, 3, values[0])

n = buf.Read(values)

assert.Equal(t, 1, n)
assert.Equal(t, 1, len(values))
assert.Equal(t, 4, values[0])
}
6 changes: 6 additions & 0 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ type Pool interface {

// Creates a new task group with the specified context.
NewGroupContext(ctx context.Context) TaskGroup

EnableDebug()
}

// pool is an implementation of the Pool interface.
Expand Down Expand Up @@ -139,6 +141,10 @@ func (p *pool) updateMetrics(err error) {
}
}

func (d *pool) EnableDebug() {
d.dispatcher.EnableDebug()
}

func (p *pool) Go(task func()) error {
if err := p.dispatcher.Write(task); err != nil {
return ErrPoolStopped
Expand Down
7 changes: 7 additions & 0 deletions result.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
type ResultPool[R any] interface {
basePool

// Enables debug mode for the pool.
EnableDebug()

// Submits a task to the pool and returns a future that can be used to wait for the task to complete and get the result.
Submit(task func() R) Result[R]

Expand All @@ -30,6 +33,10 @@ type resultPool[R any] struct {
*pool
}

func (d *resultPool[R]) EnableDebug() {
d.dispatcher.EnableDebug()
}

func (p *resultPool[R]) NewGroup() ResultTaskGroup[R] {
return newResultTaskGroup[R](p.pool, p.Context())
}
Expand Down

0 comments on commit 3832cf8

Please sign in to comment.