
fix: assign streaming_callback to OpenAIGenerator in run() method #8054

Merged · 11 commits · Jul 24, 2024
17 changes: 13 additions & 4 deletions haystack/components/generators/chat/openai.py
@@ -183,11 +183,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator":
         return default_from_dict(cls, data)

     @component.output_types(replies=List[ChatMessage])
-    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Invoke the text generation inference based on the provided messages and generation parameters.

         :param messages: A list of ChatMessage instances representing the input messages.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
         :param generation_kwargs: Additional keyword arguments for text generation. These parameters will
             potentially override the parameters passed in the `__init__` method.
             For more details on the parameters supported by the OpenAI API, refer to the
@@ -200,13 +206,16 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
         # update generation kwargs by merging with the generation kwargs passed to the run method
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback
+
         # adapt ChatMessage(s) to the format expected by the OpenAI API
         openai_formatted_messages = [message.to_openai_format() for message in messages]

         chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
             model=self.model,
             messages=openai_formatted_messages,  # type: ignore # openai expects list of specific message types
-            stream=self.streaming_callback is not None,
+            stream=streaming_callback is not None,
             **generation_kwargs,
         )
@@ -221,10 +230,10 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):

             # pylint: disable=not-an-iterable
             for chunk in chat_completion:
-                if chunk.choices and self.streaming_callback:
+                if chunk.choices and streaming_callback:
                     chunk_delta: StreamingChunk = self._build_chunk(chunk)
                     chunks.append(chunk_delta)
-                    self.streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
+                    streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
             completions = [self._connect_chunks(chunk, chunks)]
         # if streaming is disabled, the completion is a ChatCompletion
         elif isinstance(chat_completion, ChatCompletion):
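With this change the callback can be supplied per call, and the run-time argument wins over one set in `__init__` (the `or` fallback above). A minimal usage sketch, assuming OPENAI_API_KEY is set in the environment; the prompt and callback body are illustrative, not part of the PR:

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

def print_chunk(chunk: StreamingChunk) -> None:
    # print each streamed token as it arrives
    print(chunk.content, end="", flush=True)

# no callback at construction time, so this instance does not stream by default
generator = OpenAIChatGenerator()

# streaming is enabled for this call only, because run() now accepts the callback
result = generator.run([ChatMessage.from_user("Tell me a joke")], streaming_callback=print_chunk)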
18 changes: 14 additions & 4 deletions haystack/components/generators/openai.py
@@ -170,12 +170,19 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIGenerator":
         return default_from_dict(cls, data)

     @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        prompt: str,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Invoke the text generation inference based on the provided prompt and generation parameters.

         :param prompt:
             The string prompt to use for text generation.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
             Additional keyword arguments for text generation. These parameters will potentially override the parameters
             passed in the `__init__` method. For more details on the parameters supported by the OpenAI API, refer to
@@ -193,13 +200,16 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         # update generation kwargs by merging with the generation kwargs passed to the run method
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback
+
         # adapt ChatMessage(s) to the format expected by the OpenAI API
         openai_formatted_messages = [message.to_openai_format() for message in messages]

         completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
             model=self.model,
             messages=openai_formatted_messages,  # type: ignore
-            stream=self.streaming_callback is not None,
+            stream=streaming_callback is not None,
             **generation_kwargs,
         )
@@ -213,10 +223,10 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):

             # pylint: disable=not-an-iterable
             for chunk in completion:
-                if chunk.choices and self.streaming_callback:
+                if chunk.choices and streaming_callback:
                     chunk_delta: StreamingChunk = self._build_chunk(chunk)
                     chunks.append(chunk_delta)
-                    self.streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
+                    streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
             completions = [self._connect_chunks(chunk, chunks)]
         elif isinstance(completion, ChatCompletion):
             completions = [self._build_message(completion, choice) for choice in completion.choices]
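The non-chat generator gets the same override rule. A hedged sketch showing a per-run callback taking precedence over the one given at init (both callbacks here are made up for illustration):

from haystack.components.generators import OpenAIGenerator
from haystack.dataclasses import StreamingChunk

def quiet_callback(chunk: StreamingChunk) -> None:
    pass  # instance default: swallow chunks

def loud_callback(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)  # used for this call only

generator = OpenAIGenerator(streaming_callback=quiet_callback)
# loud_callback wins for this invocation, per `streaming_callback or self.streaming_callback`
generator.run("Summarize streaming in one sentence.", streaming_callback=loud_callback)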
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    The `streaming_callback` parameter can now be passed to `OpenAIGenerator` and `OpenAIChatGenerator` at run time. This removes the need to recreate a pipeline just to change the streaming callback.
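As the note says, a pipeline can now be built once and stream only when asked. A sketch of what that enables, assuming OPENAI_API_KEY is set; the component name "llm" and the prompt are arbitrary:

from haystack import Pipeline
from haystack.components.generators import OpenAIGenerator

pipe = Pipeline()
pipe.add_component("llm", OpenAIGenerator())  # built once, no callback baked in

# route the callback to the component's run() for this invocation only
pipe.run({"llm": {"prompt": "Hi!", "streaming_callback": lambda chunk: print(chunk.content, end="")}})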
21 changes: 21 additions & 0 deletions test/components/generators/chat/test_openai.py
@@ -219,6 +219,27 @@ def streaming_callback(chunk: StreamingChunk) -> None:
         assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
         assert "Hello" in response["replies"][0].content  # see mock_chat_completion_chunk

+    def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk):
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
+        response = component.run(chat_messages, streaming_callback=streaming_callback)
+
+        # check we called the streaming callback
+        assert streaming_callback_called
+
+        # check that the component still returns the correct response
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
+        assert "Hello" in response["replies"][0].content  # see mock_chat_completion_chunk
+
     def test_check_abnormal_completions(self, caplog):
         caplog.set_level(logging.INFO)
         component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
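Beyond flipping a boolean, the same fixtures would also support asserting on the chunks themselves; a hypothetical variant of the test above, not part of the PR:

def test_streaming_chunks_are_collected(self, chat_messages, mock_chat_completion_chunk):
    received = []

    def streaming_callback(chunk: StreamingChunk) -> None:
        received.append(chunk)  # collect every chunk instead of flipping a flag

    component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
    component.run(chat_messages, streaming_callback=streaming_callback)

    assert received  # at least one chunk reached the callback
    assert all(isinstance(chunk, StreamingChunk) for chunk in received)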
21 changes: 21 additions & 0 deletions test/components/generators/test_openai.py
@@ -196,6 +196,27 @@ def streaming_callback(chunk: StreamingChunk) -> None:
         assert len(response["replies"]) == 1
         assert "Hello" in response["replies"][0]  # see mock_chat_completion_chunk

+    def test_run_with_streaming_callback_in_run_method(self, mock_chat_completion_chunk):
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        # pass streaming_callback to run()
+        component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
+        response = component.run("Come on, stream!", streaming_callback=streaming_callback)
+
+        # check we called the streaming callback
+        assert streaming_callback_called
+
+        # check that the component still returns the correct response
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert "Hello" in response["replies"][0]  # see mock_chat_completion_chunk
+
     def test_run_with_params(self, mock_chat_completion):
         component = OpenAIGenerator(
             api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}