Skip to content

Commit

Permalink
Fix frankie567#40: handle large message buffering (frankie567#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankie567 authored and T-256 committed Jun 5, 2024
1 parent dc723f8 commit b6bdb36
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
40 changes: 40 additions & 0 deletions httpx_ws/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ def _background_receive(self, max_bytes: int) -> None:
Args:
max_bytes: The maximum chunk size to read at each iteration.
"""
partial_message_buffer: typing.Union[str, bytes, None] = None
try:
while not self._should_close.is_set():
data = self._wait_until_closed(self.stream.read, max_bytes)
Expand All @@ -488,6 +489,25 @@ def _background_receive(self, max_bytes: int) -> None:
continue
if isinstance(event, wsproto.events.CloseConnection):
self._should_close.set()
if isinstance(event, wsproto.events.Message):
# Unfinished message: bufferize
if not event.message_finished:
if partial_message_buffer is None:
partial_message_buffer = event.data
else:
partial_message_buffer += event.data
# Finished message but no buffer: just emit the event
elif partial_message_buffer is None:
self._events.put(event)
# Finished message with buffer: emit the full event
else:
event_type = type(event)
full_message_event = event_type(
partial_message_buffer + event.data
)
partial_message_buffer = None
self._events.put(full_message_event)
continue
self._events.put(event)
except (httpcore.ReadError, httpcore.WriteError):
self.close(CloseReason.INTERNAL_ERROR, "Stream error")
Expand Down Expand Up @@ -925,6 +945,7 @@ async def _background_receive(self, max_bytes: int) -> None:
Args:
max_bytes: The maximum chunk size to read at each iteration.
"""
partial_message_buffer: typing.Union[str, bytes, None] = None
try:
while not self._should_close.is_set():
data = await self._wait_until_closed(
Expand All @@ -941,6 +962,25 @@ async def _background_receive(self, max_bytes: int) -> None:
continue
if isinstance(event, wsproto.events.CloseConnection):
self._should_close.set()
if isinstance(event, wsproto.events.Message):
# Unfinished message: bufferize
if not event.message_finished:
if partial_message_buffer is None:
partial_message_buffer = event.data
else:
partial_message_buffer += event.data
# Finished message but no buffer: just emit the event
elif partial_message_buffer is None:
await self._events.put(event)
# Finished message with buffer: emit the full event
else:
event_type = type(event)
full_message_event = event_type(
partial_message_buffer + event.data
)
partial_message_buffer = None
await self._events.put(full_message_event)
continue
await self._events.put(event)
except (httpcore.ReadError, httpcore.WriteError):
await self.close(CloseReason.INTERNAL_ERROR, "Stream error")
Expand Down
45 changes: 45 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,51 @@ async def websocket_endpoint(websocket: WebSocket):
except WebSocketDisconnect:
pass

@pytest.mark.parametrize(
"full_message,send_method",
[
(b"A" * 1024 * 1024, "send_bytes"),
("A" * 1024 * 1024, "send_text"),
],
)
async def test_receive_oversized_message(
self,
full_message: typing.Union[str, bytes],
send_method: str,
server_factory: ServerFactoryFixture,
):
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
await asyncio.sleep(0.1) # FIXME: see #7

method = getattr(websocket, send_method)
await method(full_message)

await websocket.close()

with server_factory(websocket_endpoint) as socket:
with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client:
try:
with connect_ws(
"http://socket/ws", client, max_message_size_bytes=1024
) as ws:
event = ws.receive()
assert isinstance(event, wsproto.events.Message)
assert event.data == full_message
except WebSocketDisconnect:
pass

async with httpx.AsyncClient(
transport=httpx.AsyncHTTPTransport(uds=socket)
) as aclient:
try:
async with aconnect_ws("http://socket/ws", aclient) as aws:
event = await aws.receive()
assert isinstance(event, wsproto.events.Message)
assert event.data == full_message
except WebSocketDisconnect:
pass

async def test_receive_text(self, server_factory: ServerFactoryFixture):
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
Expand Down

0 comments on commit b6bdb36

Please sign in to comment.