diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/providers/openai/chat_stream_test.go index 31373579..e331cb1d 100644 --- a/pkg/providers/openai/chat_stream_test.go +++ b/pkg/providers/openai/chat_stream_test.go @@ -85,6 +85,10 @@ func TestOpenAIClient_ChatStreamRequest(t *testing.T) { for { chunk, err := stream.Recv() + if err == io.EOF { + return + } + require.NoError(t, err) require.NotNil(t, chunk) } diff --git a/pkg/providers/testing/lang.go b/pkg/providers/testing/lang.go index d14c439e..982bad8e 100644 --- a/pkg/providers/testing/lang.go +++ b/pkg/providers/testing/lang.go @@ -60,7 +60,7 @@ func (m *RespStreamMock) Open() error { } func (m *RespStreamMock) Recv() (*schemas.ChatStreamChunk, error) { - if m.idx < len(m.Chunks) { + if m.idx >= len(m.Chunks) { return nil, io.EOF } @@ -124,7 +124,7 @@ func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatRequest) (*schemas } func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) { - if c.chatStreams == nil || c.idx < len(*c.chatStreams) { + if c.chatStreams == nil || c.idx >= len(*c.chatStreams) { return nil, clients.ErrProviderUnavailable } diff --git a/pkg/routers/router_test.go b/pkg/routers/router_test.go index e0b28c90..92e0ab3e 100644 --- a/pkg/routers/router_test.go +++ b/pkg/routers/router_test.go @@ -260,14 +260,28 @@ func TestLangRouter_ChatStream(t *testing.T) { langModels := []*providers.LanguageModel{ providers.NewLangModel( "first", - ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{ + ptesting.NewRespStreamMock([]ptesting.RespMock{ + {Msg: "Bill"}, + {Msg: "Gates"}, + {Msg: "entered"}, + {Msg: "the"}, + {Msg: "bar"}, + }), + }), budget, *latConfig, 1, ), providers.NewLangModel( "second", - ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}}), + ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{ + ptesting.NewRespStreamMock([]ptesting.RespMock{ + {Msg: "Knock"}, + {Msg: "Knock"}, + {Msg: "joke"}, + }), + }), budget, *latConfig, 1, @@ -298,13 +312,19 @@ func TestLangRouter_ChatStream(t *testing.T) { go router.ChatStream(ctx, req, respC) - select { - case chunkResult := <-respC: - require.Nil(t, chunkResult.Error()) - require.NotNil(t, chunkResult.Chunk().ModelResponse.Message.Content) - case <-time.Tick(5 * time.Second): - t.Error("Timeout while waiting for stream chat chunk") + chunks := make([]string, 0, 5) + + for range 5 { + select { //nolint:gosimple + case chunk := <-respC: + require.Nil(t, chunk.Error()) + require.NotNil(t, chunk.Chunk().ModelResponse.Message.Content) + + chunks = append(chunks, chunk.Chunk().ModelResponse.Message.Content) + } } + + require.Equal(t, []string{"Bill", "Gates", "entered", "the", "bar"}, chunks) } func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { @@ -355,18 +375,16 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { respC := make(chan *schemas.ChatStreamResult) defer close(respC) - actualErrReasons := make([]string, 0, 3) - go router.ChatStream(context.Background(), schemas.NewChatFromStr("tell me a dad joke"), respC) + errs := make([]string, 0, 3) + for range 3 { result := <-respC require.Nil(t, result.Chunk()) - actualErrReasons = append(actualErrReasons, result.Error().Reason) + errs = append(errs, result.Error().Reason) } - // TODO: We should not send the error message if the model was unavailable at the very begging, - // so we have not started any streaming yet using it - require.Equal(t, []string{"modelUnavailable", "modelUnavailable", "allModelsUnavailable"}, actualErrReasons) + require.Equal(t, []string{"modelUnavailable", "modelUnavailable", "allModelsUnavailable"}, errs) }