diff --git a/replit_river/client_transport.py b/replit_river/client_transport.py index 665e103..a478c7b 100644 --- a/replit_river/client_transport.py +++ b/replit_river/client_transport.py @@ -13,12 +13,12 @@ from replit_river.error_schema import ( ERROR_CODE_STREAM_CLOSED, ERROR_HANDSHAKE, - ERROR_SESSION, RiverException, ) from replit_river.messages import ( PROTOCOL_VERSION, FailedSendingMessageException, + WebsocketClosedException, parse_transport_msg, send_transport_message, ) @@ -226,8 +226,8 @@ async def _send_handshake_request( ) stream_id = self.generate_nanoid() - def websocket_closed_callback() -> None: - raise RiverException(ERROR_SESSION, "Session closed while sending") + async def websocket_closed_callback() -> None: + logging.error("websocket closed before handshake response") try: await send_transport_message( @@ -246,7 +246,7 @@ def websocket_closed_callback() -> None: websocket_closed_callback=websocket_closed_callback, ) return handshake_request - except ConnectionClosed: + except (WebsocketClosedException, FailedSendingMessageException): raise RiverException(ERROR_HANDSHAKE, "Hand shake failed") async def _get_handshake_response_msg( diff --git a/replit_river/codegen/schema.py b/replit_river/codegen/schema.py index 6c40989..c7a924c 100644 --- a/replit_river/codegen/schema.py +++ b/replit_river/codegen/schema.py @@ -37,9 +37,9 @@ def message_type( "required": [], } # Non-oneof fields. - oneofs: DefaultDict[ - int, List[descriptor_pb2.FieldDescriptorProto] - ] = collections.defaultdict(list) + oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( + collections.defaultdict(list) + ) for field in m.field: if field.HasField("oneof_index"): oneofs[field.oneof_index].append(field) diff --git a/replit_river/codegen/server.py b/replit_river/codegen/server.py index 02dd197..8a6e775 100644 --- a/replit_river/codegen/server.py +++ b/replit_river/codegen/server.py @@ -46,9 +46,9 @@ def message_decoder( " return m", ] # Non-oneof fields. - oneofs: DefaultDict[ - int, List[descriptor_pb2.FieldDescriptorProto] - ] = collections.defaultdict(list) + oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( + collections.defaultdict(list) + ) for field in m.field: if field.HasField("oneof_index"): oneofs[field.oneof_index].append(field) @@ -157,9 +157,9 @@ def message_encoder( " d: Dict[str, Any] = {}", ] # Non-oneof fields. - oneofs: DefaultDict[ - int, List[descriptor_pb2.FieldDescriptorProto] - ] = collections.defaultdict(list) + oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( + collections.defaultdict(list) + ) for field in m.field: if field.HasField("oneof_index"): oneofs[field.oneof_index].append(field) diff --git a/replit_river/messages.py b/replit_river/messages.py index 2c711e1..9976fc5 100644 --- a/replit_river/messages.py +++ b/replit_river/messages.py @@ -19,6 +19,10 @@ from replit_river.transport_options import TransportOptions +class WebsocketClosedException(Exception): + pass + + class FailedSendingMessageException(Exception): pass @@ -43,11 +47,18 @@ async def send_transport_message( msg.model_dump(by_alias=True, exclude_none=True), datetime=True ) ) - except websockets.exceptions.ConnectionClosed as e: + except websockets.exceptions.ConnectionClosed: await websocket_closed_callback() - raise e + raise WebsocketClosedException() + except RuntimeError: + # RuntimeError: Unexpected ASGI message 'websocket.send', + # after sending 'websocket.close' + await websocket_closed_callback() + raise WebsocketClosedException() except Exception as e: - raise FailedSendingMessageException(f"Exception during send message : {e}") + raise FailedSendingMessageException( + f"Exception during send message : {type(e)} {e}" + ) def formatted_bytes(message: bytes) -> str: diff --git a/replit_river/server_transport.py b/replit_river/server_transport.py index 89c6e83..e02bf35 100644 --- a/replit_river/server_transport.py +++ b/replit_river/server_transport.py @@ -96,7 +96,7 @@ async def _send_handshake_response( procedureName=request_message.procedureName, ) - async def websocket_closed_callback() -> None: + def websocket_closed_callback() -> None: logging.error("websocket closed before handshake response") try: diff --git a/replit_river/session.py b/replit_river/session.py index 90adb9c..5ddc09b 100644 --- a/replit_river/session.py +++ b/replit_river/session.py @@ -12,6 +12,7 @@ from replit_river.message_buffer import MessageBuffer from replit_river.messages import ( FailedSendingMessageException, + WebsocketClosedException, parse_transport_msg, send_transport_message, ) @@ -304,7 +305,7 @@ async def _send_buffered_messages( websocket, prefix_bytes=self._transport_options.get_prefix_bytes(), ) - except ConnectionClosed as e: + except WebsocketClosedException as e: logging.info(f"Connection closed while sending buffered messages : {e}") break except FailedSendingMessageException as e: @@ -321,7 +322,7 @@ async def _send_transport_message( await send_transport_message( msg, ws, self._on_websocket_unexpected_close, prefix_bytes ) - except ConnectionClosed as e: + except WebsocketClosedException as e: raise e except FailedSendingMessageException as e: raise e @@ -336,6 +337,9 @@ async def send_message( procedure_name: str | None = None, ) -> None: """Send serialized messages to the websockets.""" + # if the session is not active, we should not do anything + if self._state != SessionState.ACTIVE: + return msg = TransportMessage( streamId=stream_id, id=nanoid.generate(), @@ -364,7 +368,7 @@ async def send_message( ws, prefix_bytes=self._transport_options.get_prefix_bytes(), ) - except ConnectionClosed as e: + except WebsocketClosedException as e: logging.error( f"Connection closed while sending message : {e}, waiting for " "retry from buffer" diff --git a/tests/conftest.py b/tests/conftest.py index c658d6f..4db1248 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ +import asyncio import logging from typing import Any, AsyncGenerator +from unittest.mock import MagicMock, patch import nanoid # type: ignore import pytest @@ -127,9 +129,15 @@ def server(transport_options: TransportOptions) -> Server: return server +@pytest.fixture +def no_logging_error() -> MagicMock: + with patch("logging.error") as mock_error: + yield mock_error + + @pytest.fixture async def client( - server: Server, transport_options: TransportOptions + server: Server, transport_options: TransportOptions, no_logging_error: MagicMock ) -> AsyncGenerator[Client, None]: try: async with serve(server.serve, "localhost", 8765): @@ -145,5 +153,8 @@ async def client( logging.debug("Start closing test client : %s", "test_client") await client.close() finally: + await asyncio.sleep(1) logging.debug("Start closing test server") await server.close() + # Server should close normally + no_logging_error.assert_not_called()