Skip to content

Commit

Permalink
Dont send message on session closing, and don't error in log when retry
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenthebuilder committed Apr 29, 2024
1 parent a0872b8 commit bda0dcb
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 21 deletions.
8 changes: 4 additions & 4 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from replit_river.error_schema import (
ERROR_CODE_STREAM_CLOSED,
ERROR_HANDSHAKE,
ERROR_SESSION,
RiverException,
)
from replit_river.messages import (
PROTOCOL_VERSION,
FailedSendingMessageException,
WebsocketClosedException,
parse_transport_msg,
send_transport_message,
)
Expand Down Expand Up @@ -226,8 +226,8 @@ async def _send_handshake_request(
)
stream_id = self.generate_nanoid()

def websocket_closed_callback() -> None:
raise RiverException(ERROR_SESSION, "Session closed while sending")
async def websocket_closed_callback() -> None:
logging.error("websocket closed before handshake response")

try:
await send_transport_message(
Expand All @@ -246,7 +246,7 @@ def websocket_closed_callback() -> None:
websocket_closed_callback=websocket_closed_callback,
)
return handshake_request
except ConnectionClosed:
except (WebsocketClosedException, FailedSendingMessageException):
raise RiverException(ERROR_HANDSHAKE, "Hand shake failed")

async def _get_handshake_response_msg(
Expand Down
6 changes: 3 additions & 3 deletions replit_river/codegen/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def message_type(
"required": [],
}
# Non-oneof fields.
oneofs: DefaultDict[
int, List[descriptor_pb2.FieldDescriptorProto]
] = collections.defaultdict(list)
oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = (
collections.defaultdict(list)
)
for field in m.field:
if field.HasField("oneof_index"):
oneofs[field.oneof_index].append(field)
Expand Down
12 changes: 6 additions & 6 deletions replit_river/codegen/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def message_decoder(
" return m",
]
# Non-oneof fields.
oneofs: DefaultDict[
int, List[descriptor_pb2.FieldDescriptorProto]
] = collections.defaultdict(list)
oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = (
collections.defaultdict(list)
)
for field in m.field:
if field.HasField("oneof_index"):
oneofs[field.oneof_index].append(field)
Expand Down Expand Up @@ -157,9 +157,9 @@ def message_encoder(
" d: Dict[str, Any] = {}",
]
# Non-oneof fields.
oneofs: DefaultDict[
int, List[descriptor_pb2.FieldDescriptorProto]
] = collections.defaultdict(list)
oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = (
collections.defaultdict(list)
)
for field in m.field:
if field.HasField("oneof_index"):
oneofs[field.oneof_index].append(field)
Expand Down
17 changes: 14 additions & 3 deletions replit_river/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from replit_river.transport_options import TransportOptions


class WebsocketClosedException(Exception):
pass


class FailedSendingMessageException(Exception):
pass

Expand All @@ -43,11 +47,18 @@ async def send_transport_message(
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
)
)
except websockets.exceptions.ConnectionClosed as e:
except websockets.exceptions.ConnectionClosed:
await websocket_closed_callback()
raise e
raise WebsocketClosedException()
except RuntimeError:
# RuntimeError: Unexpected ASGI message 'websocket.send',
# after sending 'websocket.close'
await websocket_closed_callback()
raise WebsocketClosedException()
except Exception as e:
raise FailedSendingMessageException(f"Exception during send message : {e}")
raise FailedSendingMessageException(
f"Exception during send message : {type(e)} {e}"
)


def formatted_bytes(message: bytes) -> str:
Expand Down
2 changes: 1 addition & 1 deletion replit_river/server_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def _send_handshake_response(
procedureName=request_message.procedureName,
)

async def websocket_closed_callback() -> None:
def websocket_closed_callback() -> None:
logging.error("websocket closed before handshake response")

try:
Expand Down
10 changes: 7 additions & 3 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from replit_river.message_buffer import MessageBuffer
from replit_river.messages import (
FailedSendingMessageException,
WebsocketClosedException,
parse_transport_msg,
send_transport_message,
)
Expand Down Expand Up @@ -304,7 +305,7 @@ async def _send_buffered_messages(
websocket,
prefix_bytes=self._transport_options.get_prefix_bytes(),
)
except ConnectionClosed as e:
except WebsocketClosedException as e:
logging.info(f"Connection closed while sending buffered messages : {e}")
break
except FailedSendingMessageException as e:
Expand All @@ -321,7 +322,7 @@ async def _send_transport_message(
await send_transport_message(
msg, ws, self._on_websocket_unexpected_close, prefix_bytes
)
except ConnectionClosed as e:
except WebsocketClosedException as e:
raise e
except FailedSendingMessageException as e:
raise e
Expand All @@ -336,6 +337,9 @@ async def send_message(
procedure_name: str | None = None,
) -> None:
"""Send serialized messages to the websockets."""
# if the session is not active, we should not do anything
if self._state != SessionState.ACTIVE:
return
msg = TransportMessage(
streamId=stream_id,
id=nanoid.generate(),
Expand Down Expand Up @@ -364,7 +368,7 @@ async def send_message(
ws,
prefix_bytes=self._transport_options.get_prefix_bytes(),
)
except ConnectionClosed as e:
except WebsocketClosedException as e:
logging.error(
f"Connection closed while sending message : {e}, waiting for "
"retry from buffer"
Expand Down
13 changes: 12 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import logging
from typing import Any, AsyncGenerator
from unittest.mock import MagicMock, patch

import nanoid # type: ignore
import pytest
Expand Down Expand Up @@ -127,9 +129,15 @@ def server(transport_options: TransportOptions) -> Server:
return server


@pytest.fixture
def no_logging_error() -> MagicMock:
with patch("logging.error") as mock_error:
yield mock_error


@pytest.fixture
async def client(
server: Server, transport_options: TransportOptions
server: Server, transport_options: TransportOptions, no_logging_error: MagicMock
) -> AsyncGenerator[Client, None]:
try:
async with serve(server.serve, "localhost", 8765):
Expand All @@ -145,5 +153,8 @@ async def client(
logging.debug("Start closing test client : %s", "test_client")
await client.close()
finally:
await asyncio.sleep(1)
logging.debug("Start closing test server")
await server.close()
# Server should close normally
no_logging_error.assert_not_called()

0 comments on commit bda0dcb

Please sign in to comment.