diff --git a/README.md b/README.md index effcf67..a6c53c4 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,46 @@ func main() { [Example wrapper function](https://go.dev/play/p/BWnRhJYarZ1) to show start and finish time of submitted function. + +In case you want panic-safe concurrent use of the worker pool that handles potential stops gracefully, you can use TrySubmit and TrySubmitWait. +These methods will return an error instead of panic if the worker pool has been stopped, allowing you to handle such situations appropriately. + +```go +package main + +import ( + "fmt" + "github.com/gammazero/workerpool" +) + +func main() { + wp := workerpool.New(2) + requests := []string{"alpha", "beta", "gamma", "delta", "epsilon"} + + for _, r := range requests { + r := r + if err := wp.TrySubmit(func() { + fmt.Println("Handling request:", r) + }); err != nil { + fmt.Printf("Failed to submit task for request %s: %v", r, err) + } + } + + wp.StopWait() + + wp = workerpool.New(2) + for _, r := range requests { + r := r + if err := wp.TrySubmitWait(func() { + fmt.Println("Handling request with wait:", r) + }); err != nil { + fmt.Printf("Failed to submit and wait for task for request %s: %v", r, err) + } + } +} + +``` + ## Usage Note There is no upper limit on the number of tasks queued, other than the limits of system resources. If the number of inbound tasks is too many to even queue for pending processing, then the solution is outside the scope of workerpool. It should be solved by distributing workload over multiple systems, and/or storing input for pending processing in intermediate storage such as a file system, distributed message queue, etc. diff --git a/workerpool.go b/workerpool.go index b7ac375..ceca004 100644 --- a/workerpool.go +++ b/workerpool.go @@ -2,6 +2,7 @@ package workerpool import ( "context" + "errors" "sync" "sync/atomic" "time" @@ -14,6 +15,10 @@ const ( idleTimeout = 2 * time.Second ) +var ( + ErrWorkerStopped = errors.New("worker stopped") +) + // New creates and starts a pool of worker goroutines. // // The maxWorkers parameter specifies the maximum number of workers that can @@ -110,6 +115,18 @@ func (p *WorkerPool) Submit(task func()) { } } +// TrySubmit tries to enqueue a function for a worker to execute. +// It will return ErrWorkerStopped if the worker pool has been stopped. +// +// Refer to Submit for more information. +func (p *WorkerPool) TrySubmit(task func()) error { + if p.Stopped() { + return ErrWorkerStopped + } + p.Submit(task) + return nil +} + // SubmitWait enqueues the given function and waits for it to be executed. func (p *WorkerPool) SubmitWait(task func()) { if task == nil { @@ -123,6 +140,18 @@ func (p *WorkerPool) SubmitWait(task func()) { <-doneChan } +// TrySubmitWait tries to enqueue the given function and waits for it to be executed. +// It will return ErrWorkerStopped if the worker pool has been stopped. +// +// Refer to SubmitWait for more information. +func (p *WorkerPool) TrySubmitWait(task func()) error { + if p.Stopped() { + return ErrWorkerStopped + } + p.SubmitWait(task) + return nil +} + // WaitingQueueSize returns the count of tasks in the waiting queue. func (p *WorkerPool) WaitingQueueSize() int { return int(atomic.LoadInt32(&p.waiting)) diff --git a/workerpool_test.go b/workerpool_test.go index 5a982e2..0a3345b 100644 --- a/workerpool_test.go +++ b/workerpool_test.go @@ -2,6 +2,7 @@ package workerpool import ( "context" + "errors" "sync" "testing" "time" @@ -598,6 +599,89 @@ func TestWorkerLeak(t *testing.T) { wp.Stop() } +func TestTrySubmit(t *testing.T) { + defer goleak.VerifyNone(t) + + wp := New(1) + + doneCh := make(chan struct{}) + defer close(doneCh) + + err := wp.TrySubmit(func() { + doneCh <- struct{}{} + }) + if err != nil { + t.Error("expected no error") + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + + select { + case <-doneCh: + break + case <-ctx.Done(): + t.Error("timed out waiting for function to execute") + } + + wp.Stop() + err = wp.TrySubmit(func() { + doneCh <- struct{}{} + }) + if !errors.Is(err, ErrWorkerStopped) { + t.Error("expected ErrWorkerStopped") + } + + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel2() + + select { + case <-doneCh: + t.Error("function should not have executed") + case <-ctx2.Done(): + break + } +} + +func TestTrySubmitWait(t *testing.T) { + defer goleak.VerifyNone(t) + + wp := New(1) + + done := make(chan struct{}) + err := wp.TrySubmitWait(func() { + time.Sleep(100 * time.Millisecond) + close(done) + }) + select { + case <-done: + if err != nil { + t.Error("expected no error") + } + default: + t.Error("TrySubmitWait did not wait for function to execute") + } + + wp.Stop() + + done2 := make(chan struct{}) + defer close(done2) + + err = wp.TrySubmitWait(func() { + time.Sleep(100 * time.Millisecond) + close(done2) + }) + + select { + case <-done2: + t.Error("no execution expected") + default: + if !errors.Is(err, ErrWorkerStopped) { + t.Error("expected ErrWorkerStopped") + } + } +} + func anyReady(w *WorkerPool) bool { release := make(chan struct{}) wait := func() {