Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): make ChatCompletionStreamState public #1898

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/openai/lib/streaming/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ._completions import (
ChatCompletionStream as ChatCompletionStream,
AsyncChatCompletionStream as AsyncChatCompletionStream,
ChatCompletionStreamState as ChatCompletionStreamState,
ChatCompletionStreamManager as ChatCompletionStreamManager,
AsyncChatCompletionStreamManager as AsyncChatCompletionStreamManager,
)
33 changes: 29 additions & 4 deletions src/openai/lib/streaming/chat/_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,31 @@ async def __aexit__(


class ChatCompletionStreamState(Generic[ResponseFormatT]):
"""Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.

This is useful in cases where you can't always use the `.stream()` method, e.g.

```py
from openai.lib.streaming.chat import ChatCompletionStreamState

state = ChatCompletionStreamState()

stream = client.chat.completions.create(..., stream=True)
for chunk in stream:
state.handle_chunk(chunk)

# can also access the accumulated `ChatCompletion` mid-stream
state.current_completion_snapshot

print(state.get_final_completion())
```
"""

def __init__(
self,
*,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven = NOT_GIVEN,
) -> None:
self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
self.__choice_event_states: list[ChoiceEventState] = []
Expand All @@ -301,6 +321,11 @@ def __init__(
self._rich_response_format: type | NotGiven = response_format if inspect.isclass(response_format) else NOT_GIVEN

def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Parse the final completion object.

Note that this does not guarantee that the stream has actually finished; you must
only call this method once the stream is finished.
"""
return parse_chat_completion(
chat_completion=self.current_completion_snapshot,
response_format=self._rich_response_format,
Expand All @@ -312,8 +337,8 @@ def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
assert self.__current_completion_snapshot is not None
return self.__current_completion_snapshot

def handle_chunk(self, chunk: ChatCompletionChunk) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
"""Accumulate a new chunk into the snapshot and returns a list of events to yield."""
def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
"""Accumulate a new chunk into the snapshot and return an iterable of events to yield."""
self.__current_completion_snapshot = self._accumulate_chunk(chunk)

return self._build_events(
Expand Down
94 changes: 93 additions & 1 deletion tests/lib/chat/test_completions_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@

import openai
from openai import OpenAI, AsyncOpenAI
from openai._utils import assert_signatures_in_sync
from openai._utils import consume_sync_iterator, assert_signatures_in_sync
from openai._compat import model_copy
from openai.types.chat import ChatCompletionChunk
from openai.lib.streaming.chat import (
ContentDoneEvent,
ChatCompletionStream,
ChatCompletionStreamEvent,
ChatCompletionStreamState,
ChatCompletionStreamManager,
ParsedChatCompletionSnapshot,
)
Expand Down Expand Up @@ -997,6 +999,55 @@ def test_allows_non_strict_tools_but_no_parsing(
)


@pytest.mark.respx(base_url=base_url)
def test_chat_completion_state_helper(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that `ChatCompletionStreamState` can accumulate raw chunks by hand.

    The chunks come from a plain `create(..., stream=True)` call rather than
    the `.stream()` helper; each one is fed to `state.handle_chunk()` and the
    choices of `state.get_final_completion()` are compared to an inline snapshot.
    """
    # built with no arguments, so neither input tools nor a response format are supplied
    state = ChatCompletionStreamState()

    def streamer(client: OpenAI) -> Iterator[ChatCompletionChunk]:
        # raw streaming request; the chunks are not wrapped by the `.stream()` helper
        stream = client.chat.completions.create(
            model="gpt-4o-2024-08-06",
            messages=[
                {
                    "role": "user",
                    "content": "What's the weather like in SF?",
                },
            ],
            stream=True,
        )
        for chunk in stream:
            # accumulate the chunk into `state` before handing it on to the consumer
            state.handle_chunk(chunk)
            yield chunk

    # the helper calls `streamer` and drains the iterator it returns, so by the
    # time it comes back every chunk has been passed through `state.handle_chunk`
    _make_raw_stream_snapshot_request(
        streamer,
        content_snapshot=snapshot(external("e2aad469b71d*.bin")),
        mock_client=client,
        respx_mock=respx_mock,
    )

    # NOTE(review): the whitespace inside the expected literal below may have been
    # altered when this page was extracted - confirm against the repository file
    assert print_obj(state.get_final_completion().choices, monkeypatch) == snapshot(
        """\
[
ParsedChoice[NoneType](
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[NoneType](
audio=None,
content="I'm unable to provide real-time weather updates. To get the current weather in San Francisco, I
recommend checking a reliable weather website or a weather app.",
function_call=None,
parsed=None,
refusal=None,
role='assistant',
tool_calls=[]
)
)
]
"""
    )


@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_stream_method_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
checking_client: OpenAI | AsyncOpenAI = client if sync else async_client
Expand Down Expand Up @@ -1075,3 +1126,44 @@ def _on_response(response: httpx.Response) -> None:
client.close()

return listener


def _make_raw_stream_snapshot_request(
    func: Callable[[OpenAI], Iterator[ChatCompletionChunk]],
    *,
    content_snapshot: Any,
    respx_mock: MockRouter,
    mock_client: OpenAI,
) -> None:
    """Run `func` against a live or mocked client and drain the chunk iterator it returns.

    Args:
        func: Callable that takes a client and returns an iterator of raw
            `ChatCompletionChunk`s. The iterator is fully consumed here.
        content_snapshot: External snapshot holding the raw response body.
        respx_mock: Router used to mock `POST /chat/completions` when not live.
        mock_client: Client used when not running live.

    When the `OPENAI_LIVE` environment variable equals `"1"`, the mock router
    is stopped and a new `OpenAI` client is created whose response hook
    compares the body that was read against `content_snapshot`. Otherwise
    `/chat/completions` is mocked to return the stored snapshot content as a
    `text/event-stream` response and `mock_client` is used.
    """
    live = os.environ.get("OPENAI_LIVE") == "1"
    if live:

        def _on_response(response: httpx.Response) -> None:
            # update the content snapshot
            assert outsource(response.read()) == content_snapshot

        respx_mock.stop()

        client = OpenAI(
            http_client=httpx.Client(
                event_hooks={
                    "response": [_on_response],
                }
            )
        )
    else:
        respx_mock.post("/chat/completions").mock(
            return_value=httpx.Response(
                200,
                # NOTE(review): relies on private attributes of the snapshot object;
                # presumably this loads the previously recorded bytes - confirm
                content=content_snapshot._old_value._load_value(),
                headers={"content-type": "text/event-stream"},
            )
        )

        client = mock_client

    try:
        stream = func(client)
        consume_sync_iterator(stream)
    finally:
        # only the client created here is closed; `mock_client` belongs to the caller.
        # Closing in `finally` keeps a failing stream or snapshot assertion from
        # leaking the live HTTP client.
        if live:
            client.close()