diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 73bd2411fd..dda280f6aa 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -79,7 +79,7 @@ RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER, ) -from ._streaming import Stream, AsyncStream +from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder from ._exceptions import ( APIStatusError, APITimeoutError, @@ -431,6 +431,9 @@ def _prepare_url(self, url: str) -> URL: return merge_url + def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder: + return SSEDecoder() + def _build_request( self, options: FinalRequestOptions, diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py index 74878fd0a0..2bc8d6a14d 100644 --- a/src/openai/_streaming.py +++ b/src/openai/_streaming.py @@ -5,7 +5,7 @@ import inspect from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast -from typing_extensions import Self, TypeGuard, override, get_origin +from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable import httpx @@ -24,6 +24,8 @@ class Stream(Generic[_T]): response: httpx.Response + _decoder: SSEDecoder | SSEBytesDecoder + def __init__( self, *, @@ -34,7 +36,7 @@ def __init__( self.response = response self._cast_to = cast_to self._client = client - self._decoder = SSEDecoder() + self._decoder = client._make_sse_decoder() self._iterator = self.__stream__() def __next__(self) -> _T: @@ -45,7 +47,10 @@ def __iter__(self) -> Iterator[_T]: yield item def _iter_events(self) -> Iterator[ServerSentEvent]: - yield from self._decoder.iter(self.response.iter_lines()) + if isinstance(self._decoder, SSEBytesDecoder): + yield from self._decoder.iter_bytes(self.response.iter_bytes()) + else: + yield from self._decoder.iter(self.response.iter_lines()) def __stream__(self) -> Iterator[_T]: cast_to = cast(Any, self._cast_to) @@ -97,6 +102,8 @@ class AsyncStream(Generic[_T]): response: httpx.Response + _decoder: SSEDecoder | SSEBytesDecoder + def __init__( self, *, @@ -107,7 +114,7 @@ def __init__( self.response = response self._cast_to = cast_to self._client = client - self._decoder = SSEDecoder() + self._decoder = client._make_sse_decoder() self._iterator = self.__stream__() async def __anext__(self) -> _T: @@ -118,8 +125,12 @@ async def __aiter__(self) -> AsyncIterator[_T]: yield item async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: - async for sse in self._decoder.aiter(self.response.aiter_lines()): - yield sse + if isinstance(self._decoder, SSEBytesDecoder): + async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): + yield sse + else: + async for sse in self._decoder.aiter(self.response.aiter_lines()): + yield sse async def __stream__(self) -> AsyncIterator[_T]: cast_to = cast(Any, self._cast_to) @@ -284,6 +295,17 @@ def decode(self, line: str) -> ServerSentEvent | None: return None +@runtime_checkable +class SSEBytesDecoder(Protocol): + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + ... + + def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: + """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered""" + ... + + def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]: """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" origin = get_origin(typ) or typ