Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(client): correct logic for line decoding in streaming #1293

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 47 additions & 26 deletions src/openai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Stream(Generic[_T]):

response: httpx.Response

_decoder: SSEDecoder | SSEBytesDecoder
_decoder: SSEBytesDecoder

def __init__(
self,
Expand All @@ -47,10 +47,7 @@ def __iter__(self) -> Iterator[_T]:
yield item

def _iter_events(self) -> Iterator[ServerSentEvent]:
if isinstance(self._decoder, SSEBytesDecoder):
yield from self._decoder.iter_bytes(self.response.iter_bytes())
else:
yield from self._decoder.iter(self.response.iter_lines())
yield from self._decoder.iter_bytes(self.response.iter_bytes())

def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
Expand Down Expand Up @@ -151,12 +148,8 @@ async def __aiter__(self) -> AsyncIterator[_T]:
yield item

async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
if isinstance(self._decoder, SSEBytesDecoder):
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
yield sse
else:
async for sse in self._decoder.aiter(self.response.aiter_lines()):
yield sse
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
yield sse

async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
Expand Down Expand Up @@ -282,21 +275,49 @@ def __init__(self) -> None:
self._last_event_id = None
self._retry = None

def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]:
    """Turn an iterator of text lines into a stream of server-sent events.

    Each line has its trailing newline characters stripped and is then handed
    to ``self.decode``; whenever that call returns something other than
    ``None`` the result is yielded.
    """
    for raw in iterator:
        event = self.decode(raw.rstrip("\n"))
        if event is None:
            # The decoder has not completed an event with this line.
            continue
        yield event

async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]:
    """Turn an async iterator of text lines into a stream of server-sent events.

    Each line has its trailing newline characters stripped and is then handed
    to ``self.decode``; whenever that call returns something other than
    ``None`` the result is yielded.
    """
    async for raw in iterator:
        event = self.decode(raw.rstrip("\n"))
        if event is None:
            # The decoder has not completed an event with this line.
            continue
        yield event
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
    """Decode a stream of raw bytes into server-sent events.

    The byte stream is first regrouped by ``self._iter_chunks``; every group is
    split into lines, each line is decoded as UTF-8 and passed to
    ``self.decode``, and every truthy result is yielded.
    """
    for event_block in self._iter_chunks(iterator):
        # Split while the data is still bytes: bytes.splitlines() only breaks
        # on \r and \n, whereas str.splitlines() also breaks on other
        # separator characters.
        for raw in event_block.splitlines():
            event = self.decode(raw.decode("utf-8"))
            if not event:
                continue
            yield event

def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
    """Regroup a raw byte stream into individual SSE chunks.

    Incoming data is accumulated line by line (line endings kept). As soon as
    the accumulated bytes end with a doubled line ending (CR CR, LF LF or
    CRLF CRLF) they are yielded and the buffer is reset. Whatever is left in
    the buffer when the input is exhausted is yielded last.
    """
    terminators = (b"\r\r", b"\n\n", b"\r\n\r\n")
    pending = b""
    for received in iterator:
        for piece in received.splitlines(keepends=True):
            pending = pending + piece
            if not pending.endswith(terminators):
                continue
            yield pending
            pending = b""
    # Flush a trailing, unterminated remainder instead of dropping it.
    if pending:
        yield pending

async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
    """Decode an async stream of raw bytes into server-sent events.

    The byte stream is first regrouped by ``self._aiter_chunks``; every group
    is split into lines, each line is decoded as UTF-8 and passed to
    ``self.decode``, and every truthy result is yielded.
    """
    async for event_block in self._aiter_chunks(iterator):
        # Split while the data is still bytes: bytes.splitlines() only breaks
        # on \r and \n, whereas str.splitlines() also breaks on other
        # separator characters.
        for raw in event_block.splitlines():
            event = self.decode(raw.decode("utf-8"))
            if not event:
                continue
            yield event

async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
    """Regroup an async raw byte stream into individual SSE chunks.

    Incoming data is accumulated line by line (line endings kept). As soon as
    the accumulated bytes end with a doubled line ending (CR CR, LF LF or
    CRLF CRLF) they are yielded and the buffer is reset. Whatever is left in
    the buffer when the input is exhausted is yielded last.
    """
    terminators = (b"\r\r", b"\n\n", b"\r\n\r\n")
    pending = b""
    async for received in iterator:
        for piece in received.splitlines(keepends=True):
            pending = pending + piece
            if not pending.endswith(terminators):
                continue
            yield pending
            pending = b""
    # Flush a trailing, unterminated remainder instead of dropping it.
    if pending:
        yield pending

def decode(self, line: str) -> ServerSentEvent | None:
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
Expand Down
Loading