diff --git a/filebeat/input/filestream/internal/task/group_test.go b/filebeat/input/filestream/internal/task/group_test.go index db50ef3ccabe..6ba0ac2cf1db 100644 --- a/filebeat/input/filestream/internal/task/group_test.go +++ b/filebeat/input/filestream/internal/task/group_test.go @@ -36,15 +36,21 @@ type noopLogger struct{} func (n noopLogger) Errorf(string, ...interface{}) {} -type testLogger strings.Builder +type testLogger struct { + mu sync.Mutex + b strings.Builder +} func (tl *testLogger) Errorf(format string, args ...interface{}) { - sb := (*strings.Builder)(tl) - sb.WriteString(fmt.Sprintf(format, args...)) - sb.WriteString("\n") + tl.mu.Lock() + defer tl.mu.Unlock() + tl.b.WriteString(fmt.Sprintf(format, args...)) + tl.b.WriteString("\n") } func (tl *testLogger) String() string { - return (*strings.Builder)(tl).String() + tl.mu.Lock() + defer tl.mu.Unlock() + return tl.b.String() } func TestNewGroup(t *testing.T) { @@ -67,7 +73,6 @@ func TestNewGroup(t *testing.T) { } func TestGroup_Go(t *testing.T) { - t.Skip("Flaky tests: https://github.com/elastic/beats/issues/41218") t.Run("don't run more than limit goroutines", func(t *testing.T) { done := make(chan struct{}) defer close(done) @@ -227,14 +232,12 @@ func TestGroup_Go(t *testing.T) { t.Run("all workloads return an error", func(t *testing.T) { logger := &testLogger{} - runCunt := atomic.Uint64{} - wg := sync.WaitGroup{} + var count atomic.Uint64 wantErr := errors.New("a error") workload := func(i int) func(context.Context) error { return func(_ context.Context) error { - defer runCunt.Add(1) - defer wg.Done() + defer count.Add(1) return fmt.Errorf("[%d]: %w", i, wantErr) } } @@ -242,23 +245,24 @@ func TestGroup_Go(t *testing.T) { want := uint64(2) g := NewGroup(want, time.Second, logger, "errorPrefix") - wg.Add(1) err := g.Go(workload(1)) require.NoError(t, err) - wg.Wait() - wg.Add(1) err = g.Go(workload(2)) require.NoError(t, err) - wg.Wait() - err = g.Stop() + assert.Eventually(t, func() bool { + return count.Load() == want && logger.String() != "" + }, 100*time.Millisecond, time.Millisecond) + err = g.Stop() require.NoError(t, err) + logs := logger.String() assert.Contains(t, logs, wantErr.Error()) assert.Contains(t, logs, "[2]") assert.Contains(t, logs, "[1]") + }) t.Run("some workloads return an error", func(t *testing.T) { @@ -268,17 +272,26 @@ func TestGroup_Go(t *testing.T) { g := NewGroup(want, time.Second, logger, "") - err := g.Go(func(_ context.Context) error { return nil }) + var count atomic.Uint64 + err := g.Go(func(_ context.Context) error { + count.Add(1) + return nil + }) require.NoError(t, err) - err = g.Go(func(_ context.Context) error { return wantErr }) + err = g.Go(func(_ context.Context) error { + count.Add(1) + return wantErr + }) require.NoError(t, err) - time.Sleep(time.Millisecond) + assert.Eventually(t, func() bool { + return count.Load() == want && logger.String() != "" + }, 100*time.Millisecond, time.Millisecond, "not all workloads finished") - err = g.Stop() + assert.Contains(t, logger.String(), wantErr.Error()) + err = g.Stop() assert.NoError(t, err) - assert.Contains(t, logger.String(), wantErr.Error()) }) t.Run("workload returns no error", func(t *testing.T) {