
fix: assign streaming_callback to OpenAIGenerator in run() method #8054

Merged (11 commits) on Jul 24, 2024
5 changes: 5 additions & 0 deletions haystack/components/generators/chat/openai.py
@@ -200,6 +200,11 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

streaming_callback = generation_kwargs.pop("streaming_callback", None)
# check if streaming_callback is passed to run()
if streaming_callback:
self.streaming_callback = streaming_callback

Contributor Author
  1. This is one possible and simple fix in the current run() method.
  2. The other option is to separate run() and invoke(), as implemented for OpenAIGenerator. I proposed this for code modularity, but if it's unnecessary, we can just add the above checks for both generators.

In any case, we can choose the same approach for both, and I'll update the PR.

Contributor
Not a great idea; this is quite dangerous.

If you call run passing a streaming_callback, all subsequent calls will reuse that same callback even if it is not explicitly set. That's extremely confusing in my opinion.

The best choice would be something like this:

streaming_callback = streaming_callback or self.streaming_callback

This also means adding streaming_callback as another input with a None default.
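The difference between the two approaches can be shown with toy classes (the names here are illustrative sketches, not the actual Haystack components):

```python
class LeakyGenerator:
    """Models the original patch: run() writes the callback onto self."""

    def __init__(self, streaming_callback=None):
        self.streaming_callback = streaming_callback

    def run(self, prompt, generation_kwargs=None):
        kwargs = dict(generation_kwargs or {})
        callback = kwargs.pop("streaming_callback", None)
        if callback:
            self.streaming_callback = callback  # mutates shared instance state
        return self.streaming_callback


class ScopedGenerator:
    """Models the suggested fix: the per-run callback stays local."""

    def __init__(self, streaming_callback=None):
        self.streaming_callback = streaming_callback

    def run(self, prompt, streaming_callback=None):
        # Per-call callback wins; otherwise fall back to the one from __init__.
        streaming_callback = streaming_callback or self.streaming_callback
        return streaming_callback


leaky = LeakyGenerator()
leaky.run("a", {"streaming_callback": print})
assert leaky.run("b") is print  # callback leaked into the second call

scoped = ScopedGenerator()
assert scoped.run("a", streaming_callback=print) is print
assert scoped.run("b") is None  # no leak between calls
```

The second class is why the reviewer asks for an explicit `streaming_callback=None` parameter: the override is visible in the signature and never outlives the call.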

Contributor Author
Good point!

# adapt ChatMessage(s) to the format expected by the OpenAI API
openai_formatted_messages = [message.to_openai_format() for message in messages]

45 changes: 31 additions & 14 deletions haystack/components/generators/openai.py
@@ -169,30 +169,29 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIGenerator":
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)

@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
silvanocerza marked this conversation as resolved.
def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
def invoke(self, **kwargs):
"""
Invoke the text generation inference based on the provided messages and generation parameters.
Invokes the model with the given prompt.

:param prompt:
The string prompt to use for text generation.
:param generation_kwargs:
Additional keyword arguments for text generation. These parameters will potentially override the parameters
passed in the `__init__` method. For more details on the parameters supported by the OpenAI API, refer to
the OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create).
:returns:
A list of strings containing the generated responses and a list of dictionaries containing the metadata
for each response.
:param kwargs: Additional keyword arguments passed to the generator.
:returns: A list of responses.
"""
kwargs = kwargs.copy()
prompt: str = kwargs.pop("prompt", None)
streaming_callback = kwargs.pop("streaming_callback", None)

message = ChatMessage.from_user(prompt)
if self.system_prompt:
messages = [ChatMessage.from_system(self.system_prompt), message]
else:
messages = [message]

# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
# check if streaming_callback is passed to run()
if streaming_callback:
self.streaming_callback = streaming_callback

# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **kwargs}
# adapt ChatMessage(s) to the format expected by the OpenAI API
openai_formatted_messages = [message.to_openai_format() for message in messages]

@@ -225,6 +224,24 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
for response in completions:
self._check_finish_reason(response)

return completions

@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Generate a list of responses for the given prompt.

:param prompt:
The string prompt to use for text generation.
:param generation_kwargs:
Additional keyword arguments for text generation. These parameters will potentially override the parameters
passed in the `__init__` method. For more details on the parameters supported by the OpenAI API, refer to
the OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create).
:returns: A dictionary with the following keys:
- `replies`: A list of generated responses.
- `meta`: A list of dictionaries containing the metadata.
"""
completions = self.invoke(prompt=prompt, **(generation_kwargs or {}))
return {
"replies": [message.content for message in completions],
"meta": [message.meta for message in completions],
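The shape of the run()/invoke() split in the diff can be sketched with a self-contained stub; `Message` and `StubGenerator` below are hypothetical stand-ins for Haystack's ChatMessage and generator, used only to show the output contract:

```python
from typing import Any, Dict, List, Optional


class Message:
    """Minimal stand-in for a chat message with content and metadata."""

    def __init__(self, content: str, meta: Dict[str, Any]):
        self.content = content
        self.meta = meta


class StubGenerator:
    def invoke(self, **kwargs) -> List[Message]:
        # Does the actual "model call"; here it just echoes the prompt.
        prompt = kwargs.pop("prompt", None)
        return [Message(f"echo: {prompt}", {"model": "stub"})]

    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        # Thin component entry point: delegate to invoke(), then adapt
        # the messages into the replies/meta output dictionary.
        completions = self.invoke(prompt=prompt, **(generation_kwargs or {}))
        return {
            "replies": [m.content for m in completions],
            "meta": [m.meta for m in completions],
        }


result = StubGenerator().run("hello")
assert result["replies"] == ["echo: hello"]
```

Keeping run() as a thin adapter over invoke() is the modularity the PR author argued for: the API-facing output shape and the model call can evolve independently.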
@@ -0,0 +1,4 @@
---
fixes:
- |
    Allows passing the `streaming_callback` parameter to OpenAIGenerator and OpenAIChatGenerator during a pipeline run. This removes the need to recreate the pipeline just to change the streaming callback.
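The behavior the release note describes can be exercised end to end with a stub that streams chunks to a per-run callback; `StreamingStub` and its chunking are illustrative assumptions, not the Haystack API:

```python
from typing import Callable, List, Optional


class StreamingStub:
    def __init__(self, streaming_callback: Optional[Callable[[str], None]] = None):
        self.streaming_callback = streaming_callback

    def run(self, prompt: str, streaming_callback=None):
        # Per-run callback takes precedence over the one set at init time.
        callback = streaming_callback or self.streaming_callback
        chunks = [f"{prompt}-part{i}" for i in range(3)]
        if callback:
            for chunk in chunks:
                callback(chunk)  # stream each chunk as it is "generated"
        return {"replies": ["".join(chunks)]}


seen: List[str] = []
out = StreamingStub().run("x", streaming_callback=seen.append)
assert seen == ["x-part0", "x-part1", "x-part2"]
```

The same instance can then be run again with a different callback, or none, without being rebuilt, which is the point of the fix.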