Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(internal): minor core client restructuring #1199

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
RAW_RESPONSE_HEADER,
OVERRIDE_CAST_TO_HEADER,
)
from ._streaming import Stream, AsyncStream
from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder
from ._exceptions import (
APIStatusError,
APITimeoutError,
Expand Down Expand Up @@ -431,6 +431,9 @@ def _prepare_url(self, url: str) -> URL:

return merge_url

def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
return SSEDecoder()

def _build_request(
self,
options: FinalRequestOptions,
Expand Down
34 changes: 28 additions & 6 deletions src/openai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
from typing_extensions import Self, TypeGuard, override, get_origin
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable

import httpx

Expand All @@ -24,6 +24,8 @@ class Stream(Generic[_T]):

response: httpx.Response

_decoder: SSEDecoder | SSEBytesDecoder

def __init__(
self,
*,
Expand All @@ -34,7 +36,7 @@ def __init__(
self.response = response
self._cast_to = cast_to
self._client = client
self._decoder = SSEDecoder()
self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()

def __next__(self) -> _T:
Expand All @@ -45,7 +47,10 @@ def __iter__(self) -> Iterator[_T]:
yield item

def _iter_events(self) -> Iterator[ServerSentEvent]:
yield from self._decoder.iter(self.response.iter_lines())
if isinstance(self._decoder, SSEBytesDecoder):
yield from self._decoder.iter_bytes(self.response.iter_bytes())
else:
yield from self._decoder.iter(self.response.iter_lines())

def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
Expand Down Expand Up @@ -97,6 +102,8 @@ class AsyncStream(Generic[_T]):

response: httpx.Response

_decoder: SSEDecoder | SSEBytesDecoder

def __init__(
self,
*,
Expand All @@ -107,7 +114,7 @@ def __init__(
self.response = response
self._cast_to = cast_to
self._client = client
self._decoder = SSEDecoder()
self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()

async def __anext__(self) -> _T:
Expand All @@ -118,8 +125,12 @@ async def __aiter__(self) -> AsyncIterator[_T]:
yield item

async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
async for sse in self._decoder.aiter(self.response.aiter_lines()):
yield sse
if isinstance(self._decoder, SSEBytesDecoder):
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
yield sse
else:
async for sse in self._decoder.aiter(self.response.aiter_lines()):
yield sse

async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
Expand Down Expand Up @@ -284,6 +295,17 @@ def decode(self, line: str) -> ServerSentEvent | None:
return None


@runtime_checkable
class SSEBytesDecoder(Protocol):
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
...

def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
...


def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
origin = get_origin(typ) or typ
Expand Down