From b15a1b5b4c67f65970d3ead3c227bcb80f4010c7 Mon Sep 17 00:00:00 2001 From: Alejandro Durante Date: Sun, 10 Nov 2024 19:32:38 -0300 Subject: [PATCH] fix(linkedbuffer): fix bufferHasElements channel length --- Makefile | 2 +- group.go | 7 ++++++- group_test.go | 7 +++++++ internal/dispatcher/dispatcher.go | 8 ++++++-- internal/linkedbuffer/linkedbuffer.go | 15 +++++++++------ pool.go | 6 ++++++ 6 files changed, 35 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index 6ce0725..180eb66 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ test: - go test -race -v -timeout 1m ./... + go test -race -v -timeout 1m -count=3 ./... coverage: go test -race -v -timeout 1m -coverprofile=coverage.out -covermode=atomic ./... \ No newline at end of file diff --git a/group.go b/group.go index 0ba530e..19ee5ef 100644 --- a/group.go +++ b/group.go @@ -3,6 +3,7 @@ package pond import ( "context" "errors" + "fmt" "sync" "sync/atomic" @@ -103,9 +104,13 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) { index := int(g.nextIndex.Add(1) - 1) g.taskWaitGroup.Add(1) + fmt.Printf("submitting task %d\n", index) err := g.pool.Go(func() { - defer g.taskWaitGroup.Done() + defer func() { + fmt.Printf("task %d done\n", index) + g.taskWaitGroup.Done() + }() // Check if the context has been cancelled to prevent running tasks that are not needed if err := g.future.Context().Err(); err != nil { diff --git a/group_test.go b/group_test.go index e2dadae..49df2d6 100644 --- a/group_test.go +++ b/group_test.go @@ -175,6 +175,7 @@ func TestTaskGroupWaitWithContextCanceledAndOngoingTasks(t *testing.T) { func TestTaskGroupWithStoppedPool(t *testing.T) { pool := NewPool(100) + pool.EnableDebug() pool.StopAndWait() @@ -186,6 +187,7 @@ func TestTaskGroupWithStoppedPool(t *testing.T) { func TestTaskGroupWithContextCanceled(t *testing.T) { pool := NewPool(100) + pool.EnableDebug() group := pool.NewGroup() @@ -227,6 +229,7 @@ func TestTaskGroupWithNoTasks(t *testing.T) { func TestTaskGroupCanceledShouldSkipRemainingTasks(t *testing.T) { pool := NewPool(1) + pool.EnableDebug() group := pool.NewGroup() @@ -255,6 +258,8 @@ func TestTaskGroupCanceledShouldSkipRemainingTasks(t *testing.T) { func TestTaskGroupWithCustomContext(t *testing.T) { pool := NewPool(1) + pool.EnableDebug() + ctx, cancel := context.WithCancel(context.Background()) group := pool.NewGroupContext(ctx) @@ -281,6 +286,7 @@ func TestTaskGroupWithCustomContext(t *testing.T) { func TestTaskGroupStop(t *testing.T) { pool := NewPool(1) + pool.EnableDebug() group := pool.NewGroup() @@ -306,6 +312,7 @@ func TestTaskGroupStop(t *testing.T) { func TestTaskGroupDone(t *testing.T) { pool := NewPool(10) + pool.EnableDebug() group := pool.NewGroup() diff --git a/internal/dispatcher/dispatcher.go b/internal/dispatcher/dispatcher.go index 470485a..bd894f8 100644 --- a/internal/dispatcher/dispatcher.go +++ b/internal/dispatcher/dispatcher.go @@ -19,6 +19,7 @@ type Dispatcher[T any] struct { waitGroup sync.WaitGroup batchSize int closed atomic.Bool + Debug bool } // NewDispatcher creates a generic dispatcher that can receive values from multiple goroutines in a thread-safe manner @@ -27,10 +28,9 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize dispatcher := &Dispatcher[T]{ ctx: ctx, buffer: linkedbuffer.NewLinkedBuffer[T](10, batchSize), - bufferHasElements: make(chan struct{}, 1), + bufferHasElements: make(chan struct{}, 2), // This channel needs to have size 2 in case an element is written to the buffer while the dispatcher is processing elements dispatchFunc: dispatchFunc, batchSize: batchSize, - closed: atomic.Bool{}, } dispatcher.waitGroup.Add(1) @@ -39,6 +39,10 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize return dispatcher } +func (d *Dispatcher[T]) EnableDebug() { + d.buffer.Debug = true +} + // Write writes values to the dispatcher func (d *Dispatcher[T]) Write(values ...T) error { // Check if the dispatcher has been closed diff --git a/internal/linkedbuffer/linkedbuffer.go b/internal/linkedbuffer/linkedbuffer.go index e4da61d..872bdfd 100644 --- a/internal/linkedbuffer/linkedbuffer.go +++ b/internal/linkedbuffer/linkedbuffer.go @@ -1,6 +1,7 @@ package linkedbuffer import ( + "fmt" "sync" "sync/atomic" ) @@ -17,7 +18,8 @@ type LinkedBuffer[T any] struct { maxCapacity int writeCount atomic.Uint64 readCount atomic.Uint64 - mutex sync.RWMutex + mutex sync.Mutex + Debug bool } func NewLinkedBuffer[T any](initialCapacity, maxCapacity int) *LinkedBuffer[T] { @@ -78,28 +80,29 @@ func (b *LinkedBuffer[T]) Write(values []T) { // Read reads values from the buffer and returns the number of elements read func (b *LinkedBuffer[T]) Read(values []T) int { + b.mutex.Lock() + defer b.mutex.Unlock() var readBuffer *Buffer[T] for { - b.mutex.RLock() readBuffer = b.readBuffer - b.mutex.RUnlock() // Read element n, err := readBuffer.Read(values) + if b.Debug { + fmt.Printf("read %d elements: %v\n", n, values) + } + if err == ErrEOF { // Move to next buffer - b.mutex.Lock() if readBuffer.next == nil { - b.mutex.Unlock() return n } if b.readBuffer != readBuffer.next { b.readBuffer = readBuffer.next } - b.mutex.Unlock() continue } diff --git a/pool.go b/pool.go index 3d1a38d..f6d077f 100644 --- a/pool.go +++ b/pool.go @@ -79,6 +79,8 @@ type Pool interface { // Creates a new task group with the specified context. NewGroupContext(ctx context.Context) TaskGroup + + EnableDebug() } // pool is an implementation of the Pool interface. @@ -139,6 +141,10 @@ func (p *pool) updateMetrics(err error) { } } +func (d *pool) EnableDebug() { + d.dispatcher.EnableDebug() +} + func (p *pool) Go(task func()) error { if err := p.dispatcher.Write(task); err != nil { return ErrPoolStopped