diff --git a/server_routers_test.go b/server_routers_test.go index 8ac2709..204e347 100644 --- a/server_routers_test.go +++ b/server_routers_test.go @@ -295,3 +295,205 @@ func TestFindProviders(t *testing.T) { require.False(t, it.Next()) }) } + +type mockIter[T any] struct { + ctx context.Context + ch chan iter.Result[T] + waitVal chan time.Time + val iter.Result[T] + done bool +} + +var _ iter.ResultIter[int] = &mockIter[int]{} + +func newMockIter[T any](ctx context.Context) *mockIter[T] { + it := &mockIter[T]{ + ctx: ctx, + ch: make(chan iter.Result[T]), + } + + return it +} + +func newMockIters[T any](ctx context.Context, count int) []*mockIter[T] { + var arr []*mockIter[T] + + for count > 0 { + arr = append(arr, newMockIter[T](ctx)) + count-- + } + + return arr +} + +func (m *mockIter[T]) Next() bool { + if m.done { + return false + } + + select { + case v, ok := <-m.ch: + if !ok { + m.done = true + } else { + m.val = v + } + case <-m.ctx.Done(): + m.done = true + } + + return !m.done +} + +func (m *mockIter[T]) Val() iter.Result[T] { + if m.waitVal != nil { + <-m.waitVal + } + + return m.val +} + +func (m *mockIter[T]) Close() error { + m.done = true + return nil +} + +func mockItersAsInterface[T any](originalSlice []*mockIter[T]) []iter.ResultIter[T] { + var newSlice []iter.ResultIter[T] + + for _, v := range originalSlice { + newSlice = append(newSlice, v) + } + + return newSlice +} + +func TestManyIter(t *testing.T) { + t.Parallel() + + t.Run("Sequence", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + its := newMockIters[int](ctx, 2) + manyIter := newManyIter(ctx, mockItersAsInterface(its)) + + go func() { + its[0].ch <- iter.Result[int]{Val: 0} + time.Sleep(time.Millisecond * 50) + + its[1].ch <- iter.Result[int]{Val: 1} + time.Sleep(time.Millisecond * 50) + + its[0].ch <- iter.Result[int]{Val: 0} + time.Sleep(time.Millisecond * 50) + + its[0].ch <- iter.Result[int]{Val: 0} + close(its[0].ch) + time.Sleep(time.Millisecond * 50) + + its[1].ch <- iter.Result[int]{Val: 1} + time.Sleep(time.Millisecond * 50) + + close(its[1].ch) + }() + + results, err := iter.ReadAllResults(manyIter) + require.NoError(t, err) + require.Equal(t, []int{0, 1, 0, 0, 1}, results) + require.False(t, manyIter.Next()) + require.NoError(t, manyIter.Close()) + }) + + t.Run("Closed Iterator", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + its := newMockIters[int](ctx, 5) + manyIter := newManyIter(ctx, mockItersAsInterface(its)) + + go func() { + close(its[0].ch) + close(its[1].ch) + close(its[2].ch) + close(its[3].ch) + + its[4].ch <- iter.Result[int]{Val: 4} + time.Sleep(time.Millisecond * 50) + + its[4].ch <- iter.Result[int]{Val: 4} + time.Sleep(time.Millisecond * 50) + + close(its[4].ch) + }() + + results, err := iter.ReadAllResults(manyIter) + require.NoError(t, err) + require.Equal(t, []int{4, 4}, results) + require.False(t, manyIter.Next()) + require.NoError(t, manyIter.Close()) + }) + + t.Run("Context Canceled", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + its := newMockIters[int](ctx, 5) + manyIter := newManyIter(ctx, mockItersAsInterface(its)) + + go func() { + its[3].ch <- iter.Result[int]{Val: 3} + time.Sleep(time.Millisecond * 50) + + its[2].ch <- iter.Result[int]{Val: 2} + time.Sleep(time.Millisecond * 50) + + cancel() + }() + + results, err := iter.ReadAllResults(manyIter) + require.NoError(t, err) + require.Equal(t, []int{3, 2}, results) + require.False(t, manyIter.Next()) + require.NoError(t, manyIter.Close()) + }) + + t.Run("Context Canceled After .Next Returns", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + its := newMockIters[int](ctx, 5) + manyIter := newManyIter(ctx, mockItersAsInterface(its)) + + go func() { + its[1].ch <- iter.Result[int]{Val: 1} + time.Sleep(time.Millisecond * 50) + + its[4].ch <- iter.Result[int]{Val: 4} + time.Sleep(time.Millisecond * 50) + + its[3].waitVal = make(chan time.Time) + its[3].ch <- iter.Result[int]{Val: 3} + time.Sleep(time.Millisecond * 50) + + cancel() + time.Sleep(time.Millisecond * 50) + + its[3].waitVal <- time.Now() + }() + + results, err := iter.ReadAllResults(manyIter) + require.NoError(t, err) + require.Equal(t, []int{1, 4}, results) + require.False(t, manyIter.Next()) + require.NoError(t, manyIter.Close()) + }) +}