diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 951b7be26b..59feae103b 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -183,11 +183,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator": return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): """ Invoke the text generation inference based on the provided messages and generation parameters. :param messages: A list of ChatMessage instances representing the input messages. + :param streaming_callback: A callback function that is called when a new token is received from the stream. :param generation_kwargs: Additional keyword arguments for text generation. These parameters will potentially override the parameters passed in the `__init__` method. For more details on the parameters supported by the OpenAI API, refer to the @@ -200,13 +206,16 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + # check if streaming_callback is passed + streaming_callback = streaming_callback or self.streaming_callback + # adapt ChatMessage(s) to the format expected by the OpenAI API openai_formatted_messages = [message.to_openai_format() for message in messages] chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create( model=self.model, messages=openai_formatted_messages, # type: ignore # openai expects list of specific message types - stream=self.streaming_callback is not None, + stream=streaming_callback is not None, **generation_kwargs, ) @@ -221,10 +230,10 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, # pylint: disable=not-an-iterable for chunk in chat_completion: - if chunk.choices and self.streaming_callback: + if chunk.choices and streaming_callback: chunk_delta: StreamingChunk = self._build_chunk(chunk) chunks.append(chunk_delta) - self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta + streaming_callback(chunk_delta) # invoke callback with the chunk_delta completions = [self._connect_chunks(chunk, chunks)] # if streaming is disabled, the completion is a ChatCompletion elif isinstance(chat_completion, ChatCompletion): diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index f364ce15ae..ef21fbd226 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -170,12 +170,19 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIGenerator": return default_from_dict(cls, data) @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) - def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + def run( + self, + prompt: str, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): """ Invoke the text generation inference based on the provided messages and generation parameters. :param prompt: The string prompt to use for text generation. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. :param generation_kwargs: Additional keyword arguments for text generation. These parameters will potentially override the parameters passed in the `__init__` method. For more details on the parameters supported by the OpenAI API, refer to @@ -193,13 +200,16 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + # check if streaming_callback is passed + streaming_callback = streaming_callback or self.streaming_callback + # adapt ChatMessage(s) to the format expected by the OpenAI API openai_formatted_messages = [message.to_openai_format() for message in messages] completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create( model=self.model, messages=openai_formatted_messages, # type: ignore - stream=self.streaming_callback is not None, + stream=streaming_callback is not None, **generation_kwargs, ) @@ -213,10 +223,10 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): # pylint: disable=not-an-iterable for chunk in completion: - if chunk.choices and self.streaming_callback: + if chunk.choices and streaming_callback: chunk_delta: StreamingChunk = self._build_chunk(chunk) chunks.append(chunk_delta) - self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta + streaming_callback(chunk_delta) # invoke callback with the chunk_delta completions = [self._connect_chunks(chunk, chunks)] elif isinstance(completion, ChatCompletion): completions = [self._build_message(completion, choice) for choice in completion.choices] diff --git a/releasenotes/notes/improve-streaming-callbacks-openai-b6c0b108f2de4142.yaml b/releasenotes/notes/improve-streaming-callbacks-openai-b6c0b108f2de4142.yaml new file mode 100644 index 0000000000..3df9559f64 --- /dev/null +++ b/releasenotes/notes/improve-streaming-callbacks-openai-b6c0b108f2de4142.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + The `streaming_callback` parameter can be passed to OpenAIGenerator and OpenAIChatGenerator during pipeline run. This prevents the need to recreate pipelines for streaming callbacks. diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index e945121755..335cc253bb 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -219,6 +219,27 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert [isinstance(reply, ChatMessage) for reply in response["replies"]] assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk + def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk): + streaming_callback_called = False + + def streaming_callback(chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) + response = component.run(chat_messages, streaming_callback=streaming_callback) + + # check we called the streaming callback + assert streaming_callback_called + + # check that the component still returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk + def test_check_abnormal_completions(self, caplog): caplog.set_level(logging.INFO) component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py index b34a5399b4..d3a1fbb860 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -196,6 +196,27 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert len(response["replies"]) == 1 assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk + def test_run_with_streaming_callback_in_run_method(self, mock_chat_completion_chunk): + streaming_callback_called = False + + def streaming_callback(chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + # pass streaming_callback to run() + component = OpenAIGenerator(api_key=Secret.from_token("test-api-key")) + response = component.run("Come on, stream!", streaming_callback=streaming_callback) + + # check we called the streaming callback + assert streaming_callback_called + + # check that the component still returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk + def test_run_with_params(self, mock_chat_completion): component = OpenAIGenerator( api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}