Skip to content

Commit

Permalink
Fixed write to closed channel in dispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
korotin committed Nov 26, 2024
1 parent fc408b5 commit 0e1af5c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 11 deletions.
12 changes: 12 additions & 0 deletions internal/assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,15 @@ func PanicsWithError(t *testing.T, expected string, f func()) {
}()
f()
}

/**
* Asserts that the function not panics.
*/
func NotPanics(t *testing.T, f func()) {
defer func() {
if r := recover(); r != nil {
t.Errorf("No panic was expected, but got %T(%v)", r, r)
}
}()
f()
}
29 changes: 18 additions & 11 deletions internal/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"sync"
"sync/atomic"

"github.com/alitto/pond/v2/internal/linkedbuffer"
)
Expand All @@ -18,7 +17,7 @@ type Dispatcher[T any] struct {
dispatchFunc func([]T)
waitGroup sync.WaitGroup
batchSize int
closed atomic.Bool
closed chan struct{}
}

// NewDispatcher creates a generic dispatcher that can receive values from multiple goroutines in a thread-safe manner
Expand All @@ -30,6 +29,7 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize
bufferHasElements: make(chan struct{}, 1),
dispatchFunc: dispatchFunc,
batchSize: batchSize,
closed: make(chan struct{}),
}

dispatcher.waitGroup.Add(1)
Expand All @@ -41,15 +41,23 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize
// Write writes values to the dispatcher
func (d *Dispatcher[T]) Write(values ...T) error {
// Check if the dispatcher has been closed
if d.closed.Load() || d.ctx.Err() != nil {
select {
case <-d.ctx.Done():
return ErrDispatcherClosed

case <-d.closed:
return ErrDispatcherClosed

default:

}

// Append elements to the buffer
d.buffer.Write(values)

// Notify there are elements in the buffer
select {
case <-d.closed:
case d.bufferHasElements <- struct{}{}:
default:
}
Expand All @@ -74,8 +82,7 @@ func (d *Dispatcher[T]) Len() uint64 {

// Close closes the dispatcher
func (d *Dispatcher[T]) Close() {
d.closed.Store(true)
close(d.bufferHasElements)
close(d.closed)
}

// CloseAndWait closes the dispatcher and waits for all pending elements to be processed
Expand All @@ -98,14 +105,19 @@ func (d *Dispatcher[T]) run(ctx context.Context) {
case <-ctx.Done():
// Context cancelled, exit
return
case <-d.closed:
return
default:
}

select {
case <-ctx.Done():
// Context cancelled, exit
return
case _, ok := <-d.bufferHasElements:
case <-d.closed:
// Dispatcher closed, exit
return
case <-d.bufferHasElements:

// Attempt to read all pending elements
for {
Expand All @@ -118,11 +130,6 @@ func (d *Dispatcher[T]) run(ctx context.Context) {
// Submit the next batch of values
d.dispatchFunc(batch[0:batchSize])
}

if !ok || d.closed.Load() {
// Channel was closed, read all remaining elements and exit
return
}
}
}
}
25 changes: 25 additions & 0 deletions internal/dispatcher/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,28 @@ func TestDispatcherWriteAfterClose(t *testing.T) {
assert.Equal(t, uint64(0), dispatcher.WriteCount())
assert.Equal(t, uint64(0), dispatcher.ReadCount())
}

func TestDispatcherWriteAfterCloseConcurrent(t *testing.T) {

ctx := context.Background()

dispatcher := NewDispatcher(ctx, func(t []int) {}, 1024)

wg := sync.WaitGroup{}
defer wg.Wait()

for i := 0; i < 1024; i++ {
wg.Add(1)
go func() {
defer wg.Done()

assert.NotPanics(t, func() {
dispatcher.Write(1)
})
}()
}

assert.NotPanics(t, func() {
dispatcher.Close()
})
}

0 comments on commit 0e1af5c

Please sign in to comment.