diff --git a/internal/assert/assert.go b/internal/assert/assert.go index 53b76bd..e2a199d 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -54,3 +54,15 @@ func PanicsWithError(t *testing.T, expected string, f func()) { }() f() } + +/** + * Asserts that the function not panics. + */ +func NotPanics(t *testing.T, f func()) { + defer func() { + if r := recover(); r != nil { + t.Errorf("No panic was expected, but got %T(%v)", r, r) + } + }() + f() +} diff --git a/internal/dispatcher/dispatcher.go b/internal/dispatcher/dispatcher.go index 0cc31a4..207f63f 100644 --- a/internal/dispatcher/dispatcher.go +++ b/internal/dispatcher/dispatcher.go @@ -4,7 +4,6 @@ import ( "context" "errors" "sync" - "sync/atomic" "github.com/alitto/pond/v2/internal/linkedbuffer" ) @@ -18,7 +17,7 @@ type Dispatcher[T any] struct { dispatchFunc func([]T) waitGroup sync.WaitGroup batchSize int - closed atomic.Bool + closed chan struct{} } // NewDispatcher creates a generic dispatcher that can receive values from multiple goroutines in a thread-safe manner @@ -30,6 +29,7 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize bufferHasElements: make(chan struct{}, 1), dispatchFunc: dispatchFunc, batchSize: batchSize, + closed: make(chan struct{}), } dispatcher.waitGroup.Add(1) @@ -41,8 +41,15 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize // Write writes values to the dispatcher func (d *Dispatcher[T]) Write(values ...T) error { // Check if the dispatcher has been closed - if d.closed.Load() || d.ctx.Err() != nil { + select { + case <-d.ctx.Done(): + return ErrDispatcherClosed + + case <-d.closed: return ErrDispatcherClosed + + default: + } // Append elements to the buffer @@ -50,6 +57,7 @@ func (d *Dispatcher[T]) Write(values ...T) error { // Notify there are elements in the buffer select { + case <-d.closed: case d.bufferHasElements <- struct{}{}: default: } @@ -74,8 +82,7 @@ func (d *Dispatcher[T]) Len() uint64 { // Close closes the dispatcher func (d *Dispatcher[T]) Close() { - d.closed.Store(true) - close(d.bufferHasElements) + close(d.closed) } // CloseAndWait closes the dispatcher and waits for all pending elements to be processed @@ -98,6 +105,8 @@ func (d *Dispatcher[T]) run(ctx context.Context) { case <-ctx.Done(): // Context cancelled, exit return + case <-d.closed: + return default: } @@ -105,7 +114,10 @@ func (d *Dispatcher[T]) run(ctx context.Context) { case <-ctx.Done(): // Context cancelled, exit return - case _, ok := <-d.bufferHasElements: + case <-d.closed: + // Dispatcher closed, exit + return + case <-d.bufferHasElements: // Attempt to read all pending elements for { @@ -118,11 +130,6 @@ func (d *Dispatcher[T]) run(ctx context.Context) { // Submit the next batch of values d.dispatchFunc(batch[0:batchSize]) } - - if !ok || d.closed.Load() { - // Channel was closed, read all remaining elements and exit - return - } } } } diff --git a/internal/dispatcher/dispatcher_test.go b/internal/dispatcher/dispatcher_test.go index c6dbfe7..ad5c930 100644 --- a/internal/dispatcher/dispatcher_test.go +++ b/internal/dispatcher/dispatcher_test.go @@ -151,3 +151,28 @@ func TestDispatcherWriteAfterClose(t *testing.T) { assert.Equal(t, uint64(0), dispatcher.WriteCount()) assert.Equal(t, uint64(0), dispatcher.ReadCount()) } + +func TestDispatcherWriteAfterCloseConcurrent(t *testing.T) { + + ctx := context.Background() + + dispatcher := NewDispatcher(ctx, func(t []int) {}, 1024) + + wg := sync.WaitGroup{} + defer wg.Wait() + + for i := 0; i < 1024; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + assert.NotPanics(t, func() { + dispatcher.Write(1) + }) + }() + } + + assert.NotPanics(t, func() { + dispatcher.Close() + }) +}