Skip to content

Commit

Permalink
Assign streaming_callback to OpenAIGenerator and OpenAIChatGenerator …
Browse files Browse the repository at this point in the history
…in run() method (#8054)

* Add optional parameter for streaming_callback in run() method
  • Loading branch information
Amnah199 authored Jul 24, 2024
1 parent baed478 commit b374c52
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 8 deletions.
17 changes: 13 additions & 4 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator":
return default_from_dict(cls, data)

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
def run(
self,
messages: List[ChatMessage],
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Invoke the text generation inference based on the provided messages and generation parameters.
:param messages: A list of ChatMessage instances representing the input messages.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
:param generation_kwargs: Additional keyword arguments for text generation. These parameters will
potentially override the parameters passed in the `__init__` method.
For more details on the parameters supported by the OpenAI API, refer to the
Expand All @@ -200,13 +206,16 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

# check if streaming_callback is passed
streaming_callback = streaming_callback or self.streaming_callback

# adapt ChatMessage(s) to the format expected by the OpenAI API
openai_formatted_messages = [message.to_openai_format() for message in messages]

chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
model=self.model,
messages=openai_formatted_messages, # type: ignore # openai expects list of specific message types
stream=self.streaming_callback is not None,
stream=streaming_callback is not None,
**generation_kwargs,
)

Expand All @@ -221,10 +230,10 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,

# pylint: disable=not-an-iterable
for chunk in chat_completion:
if chunk.choices and self.streaming_callback:
if chunk.choices and streaming_callback:
chunk_delta: StreamingChunk = self._build_chunk(chunk)
chunks.append(chunk_delta)
self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta
streaming_callback(chunk_delta) # invoke callback with the chunk_delta
completions = [self._connect_chunks(chunk, chunks)]
# if streaming is disabled, the completion is a ChatCompletion
elif isinstance(chat_completion, ChatCompletion):
Expand Down
18 changes: 14 additions & 4 deletions haystack/components/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,19 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIGenerator":
return default_from_dict(cls, data)

@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
def run(
self,
prompt: str,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Invoke the text generation inference based on the provided messages and generation parameters.
:param prompt:
The string prompt to use for text generation.
:param streaming_callback:
A callback function that is called when a new token is received from the stream.
:param generation_kwargs:
Additional keyword arguments for text generation. These parameters will potentially override the parameters
passed in the `__init__` method. For more details on the parameters supported by the OpenAI API, refer to
Expand All @@ -193,13 +200,16 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

# check if streaming_callback is passed
streaming_callback = streaming_callback or self.streaming_callback

# adapt ChatMessage(s) to the format expected by the OpenAI API
openai_formatted_messages = [message.to_openai_format() for message in messages]

completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
model=self.model,
messages=openai_formatted_messages, # type: ignore
stream=self.streaming_callback is not None,
stream=streaming_callback is not None,
**generation_kwargs,
)

Expand All @@ -213,10 +223,10 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):

# pylint: disable=not-an-iterable
for chunk in completion:
if chunk.choices and self.streaming_callback:
if chunk.choices and streaming_callback:
chunk_delta: StreamingChunk = self._build_chunk(chunk)
chunks.append(chunk_delta)
self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta
streaming_callback(chunk_delta) # invoke callback with the chunk_delta
completions = [self._connect_chunks(chunk, chunks)]
elif isinstance(completion, ChatCompletion):
completions = [self._build_message(completion, choice) for choice in completion.choices]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
    The `streaming_callback` parameter can now be passed to `OpenAIGenerator` and `OpenAIChatGenerator` at pipeline run time. This removes the need to rebuild the pipeline just to change or set the streaming callback.
21 changes: 21 additions & 0 deletions test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,27 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk

def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk):
streaming_callback_called = False

def streaming_callback(chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True

component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
response = component.run(chat_messages, streaming_callback=streaming_callback)

# check we called the streaming callback
assert streaming_callback_called

# check that the component still returns the correct response
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk

def test_check_abnormal_completions(self, caplog):
caplog.set_level(logging.INFO)
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
Expand Down
21 changes: 21 additions & 0 deletions test/components/generators/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,27 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert len(response["replies"]) == 1
assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk

def test_run_with_streaming_callback_in_run_method(self, mock_chat_completion_chunk):
streaming_callback_called = False

def streaming_callback(chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True

# pass streaming_callback to run()
component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
response = component.run("Come on, stream!", streaming_callback=streaming_callback)

# check we called the streaming callback
assert streaming_callback_called

# check that the component still returns the correct response
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk

def test_run_with_params(self, mock_chat_completion):
component = OpenAIGenerator(
api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
Expand Down

0 comments on commit b374c52

Please sign in to comment.