From 2cc851851cb9fe3bf6f3e4d045ecf91677154308 Mon Sep 17 00:00:00 2001 From: Camden Cheek Date: Mon, 1 Jan 2024 08:59:40 -0700 Subject: [PATCH] make pools reusable --- pool/error_pool.go | 10 +++++++--- pool/error_pool_test.go | 15 +++++++++++++++ pool/result_context_pool.go | 4 +++- pool/result_context_pool_test.go | 16 ++++++++++++++++ pool/result_error_pool.go | 4 +++- pool/result_error_pool_test.go | 16 ++++++++++++++++ pool/result_pool.go | 4 +++- pool/result_pool_test.go | 13 +++++++++++++ 8 files changed, 76 insertions(+), 6 deletions(-) diff --git a/pool/error_pool.go b/pool/error_pool.go index 2e999a5..939fb29 100644 --- a/pool/error_pool.go +++ b/pool/error_pool.go @@ -35,12 +35,16 @@ func (p *ErrorPool) Go(f func() error) { // returning any errors from tasks. func (p *ErrorPool) Wait() error { p.pool.Wait() - if len(p.errs) == 0 { + + errs := p.errs + p.errs = nil // reset errs + + if len(errs) == 0 { return nil } else if p.onlyFirstError { - return p.errs[0] + return errs[0] } else { - return multierror.Join(p.errs...) + return multierror.Join(errs...) } } diff --git a/pool/error_pool_test.go b/pool/error_pool_test.go index 5eb8b20..814f90b 100644 --- a/pool/error_pool_test.go +++ b/pool/error_pool_test.go @@ -117,4 +117,19 @@ func TestErrorPool(t *testing.T) { }) } }) + + t.Run("reuse", func(t *testing.T) { + // Test for https://github.com/sourcegraph/conc/issues/128 + p := pool.New().WithErrors() + + p.Go(func() error { return err1 }) + wait1 := p.Wait() + require.ErrorIs(t, wait1, err1) + + p.Go(func() error { return err2 }) + wait2 := p.Wait() + // On reuse, only the new error should be returned + require.ErrorIs(t, wait2, err2) + require.NotErrorIs(t, wait1, err2) + }) } diff --git a/pool/result_context_pool.go b/pool/result_context_pool.go index 8560c6a..6bc30dd 100644 --- a/pool/result_context_pool.go +++ b/pool/result_context_pool.go @@ -32,7 +32,9 @@ func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) { // returns an error if any of the tasks errored. func (p *ResultContextPool[T]) Wait() ([]T, error) { err := p.contextPool.Wait() - return p.agg.collect(p.collectErrored), err + results := p.agg.collect(p.collectErrored) + p.agg = resultAggregator[T]{} + return results, err } // WithCollectErrored configures the pool to still collect the result of a task diff --git a/pool/result_context_pool_test.go b/pool/result_context_pool_test.go index ceae5e8..fc3b68a 100644 --- a/pool/result_context_pool_test.go +++ b/pool/result_context_pool_test.go @@ -228,4 +228,20 @@ func TestResultContextPool(t *testing.T) { }) } }) + + t.Run("reuse", func(t *testing.T) { + // Test for https://github.com/sourcegraph/conc/issues/128 + p := pool.NewWithResults[int]().WithContext(context.Background()) + + p.Go(func(context.Context) (int, error) { return 1, err1 }) + results1, errs1 := p.Wait() + require.Empty(t, results1) + require.ErrorIs(t, errs1, err1) + + p.Go(func(context.Context) (int, error) { return 2, err2 }) + results2, errs2 := p.Wait() + require.Empty(t, results2) + require.ErrorIs(t, errs2, err2) + require.NotErrorIs(t, errs2, err1) + }) } diff --git a/pool/result_error_pool.go b/pool/result_error_pool.go index 5a0bfb9..832cd9b 100644 --- a/pool/result_error_pool.go +++ b/pool/result_error_pool.go @@ -34,7 +34,9 @@ func (p *ResultErrorPool[T]) Go(f func() (T, error)) { // returning the results and any errors from tasks. func (p *ResultErrorPool[T]) Wait() ([]T, error) { err := p.errorPool.Wait() - return p.agg.collect(p.collectErrored), err + results := p.agg.collect(p.collectErrored) + p.agg = resultAggregator[T]{} // reset for reuse + return results, err } // WithCollectErrored configures the pool to still collect the result of a task diff --git a/pool/result_error_pool_test.go b/pool/result_error_pool_test.go index c9b1b08..7326639 100644 --- a/pool/result_error_pool_test.go +++ b/pool/result_error_pool_test.go @@ -130,4 +130,20 @@ func TestResultErrorPool(t *testing.T) { }) } }) + + t.Run("reuse", func(t *testing.T) { + // Test for https://github.com/sourcegraph/conc/issues/128 + p := pool.NewWithResults[int]().WithErrors() + + p.Go(func() (int, error) { return 1, err1 }) + results1, errs1 := p.Wait() + require.Empty(t, results1) + require.ErrorIs(t, errs1, err1) + + p.Go(func() (int, error) { return 2, err2 }) + results2, errs2 := p.Wait() + require.Empty(t, results2) + require.ErrorIs(t, errs2, err2) + require.NotErrorIs(t, errs2, err1) + }) } diff --git a/pool/result_pool.go b/pool/result_pool.go index 16d8b46..f73a772 100644 --- a/pool/result_pool.go +++ b/pool/result_pool.go @@ -40,7 +40,9 @@ func (p *ResultPool[T]) Go(f func() T) { // a slice of results from tasks that did not panic. func (p *ResultPool[T]) Wait() []T { p.pool.Wait() - return p.agg.collect(true) + results := p.agg.collect(true) + p.agg = resultAggregator[T]{} // reset for reuse + return results } // MaxGoroutines returns the maximum size of the pool. diff --git a/pool/result_pool_test.go b/pool/result_pool_test.go index ccd7892..69b9de4 100644 --- a/pool/result_pool_test.go +++ b/pool/result_pool_test.go @@ -113,4 +113,17 @@ func TestResultGroup(t *testing.T) { }) } }) + + t.Run("reuse", func(t *testing.T) { + // Test for https://github.com/sourcegraph/conc/issues/128 + p := pool.NewWithResults[int]() + + p.Go(func() int { return 1 }) + results1 := p.Wait() + require.Equal(t, []int{1}, results1) + + p.Go(func() int { return 2 }) + results2 := p.Wait() + require.Equal(t, []int{2}, results2) + }) }