diff --git a/group.go b/group.go index 0ba530e..5c56199 100644 --- a/group.go +++ b/group.go @@ -104,7 +104,7 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) { g.taskWaitGroup.Add(1) - err := g.pool.Go(func() { + err := g.pool.dispatcher.Write(func() error { defer g.taskWaitGroup.Done() // Check if the context has been cancelled to prevent running tasks that are not needed @@ -112,7 +112,7 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) { g.futureResolver(index, &result[O]{ Err: err, }, err) - return + return err } // Invoke the task @@ -122,6 +122,8 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) { Output: output, Err: err, }, err) + + return err }) if err != nil { diff --git a/group_test.go b/group_test.go index 3a36525..8e900d2 100644 --- a/group_test.go +++ b/group_test.go @@ -323,3 +323,58 @@ func TestTaskGroupDone(t *testing.T) { assert.Equal(t, int32(5), executedCount.Load()) } + +func TestTaskGroupMetrics(t *testing.T) { + pool := NewPool(1) + + group := pool.NewGroup() + + for i := 0; i < 9; i++ { + group.Submit(func() { + time.Sleep(1 * time.Millisecond) + }) + } + + // The last task will return an error + sampleErr := errors.New("sample error") + group.SubmitErr(func() error { + time.Sleep(1 * time.Millisecond) + return sampleErr + }) + + err := group.Wait() + + time.Sleep(10 * time.Millisecond) + + assert.Equal(t, sampleErr, err) + assert.Equal(t, uint64(10), pool.SubmittedTasks()) + assert.Equal(t, uint64(9), pool.SuccessfulTasks()) + assert.Equal(t, uint64(1), pool.FailedTasks()) +} + +func TestTaskGroupMetricsWithCancelledContext(t *testing.T) { + pool := NewPool(1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + group := pool.NewGroupContext(ctx) + + for i := 0; i < 10; i++ { + i := i + group.Submit(func() { + time.Sleep(20 * time.Millisecond) + if i == 4 { + cancel() + } + }) + } + err := group.Wait() + + time.Sleep(10 * time.Millisecond) + + assert.Equal(t, err, context.Canceled) + assert.Equal(t, uint64(10), pool.SubmittedTasks()) + assert.Equal(t, uint64(5), pool.SuccessfulTasks()) + assert.Equal(t, uint64(5), pool.FailedTasks()) +}