From 4aca18bb0a673c751b2c7a81498d020ecf1a1242 Mon Sep 17 00:00:00 2001 From: zhenthebuilder Date: Thu, 18 Apr 2024 16:08:38 -0700 Subject: [PATCH] Update buffer --- pyproject.toml | 2 +- replit_river/message_buffer.py | 4 ++-- replit_river/session.py | 34 ++++++++++++++++++---------------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7a05c1..495b0a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name="replit-river" -version="0.1.6" +version="0.1.7-beta.1" description="Replit river toolkit for Python" authors = ["Replit "] license = "LICENSE" diff --git a/replit_river/message_buffer.py b/replit_river/message_buffer.py index eab1c17..644d098 100644 --- a/replit_river/message_buffer.py +++ b/replit_river/message_buffer.py @@ -11,12 +11,12 @@ def __init__(self, max_size: int = 1000): self.buffer: list[TransportMessage] = [] self._lock = asyncio.Lock() - async def is_empty(self) -> bool: + async def empty(self) -> bool: """Check if the buffer is empty""" async with self._lock: return len(self.buffer) == 0 - async def add(self, message: TransportMessage) -> None: + async def put(self, message: TransportMessage) -> None: """Add a message to the buffer""" async with self._lock: if len(self.buffer) >= self.max_size: diff --git a/replit_river/session.py b/replit_river/session.py index de21021..7927c33 100644 --- a/replit_river/session.py +++ b/replit_river/session.py @@ -9,6 +9,7 @@ from websockets.exceptions import ConnectionClosedError from websockets.server import WebSocketServerProtocol +from replit_river.message_buffer import MessageBuffer from replit_river.messages import ( FailedSendingMessageException, parse_transport_msg, @@ -30,6 +31,11 @@ TransportMessage, ) +SEND_TRANSPORT_MESSAGE_EXCEPTIONS = ( + websockets.exceptions.ConnectionClosed, + FailedSendingMessageException, +) + class Session(object): """A transport object that handles the websocket connection with a client.""" @@ -56,8 +62,7 @@ def __init__( self._transport_options = transport_options self._seq_manager = SeqManager() self._task_manager = BackgroundTaskManager() - self._buffer: asyncio.Queue[TransportMessage] = asyncio.Queue(1000) - self._lock = asyncio.Lock() + self._buffer = MessageBuffer() self.heartbeat_misses = 0 # should disconnect after this time self._disconnect_after_this_time: Optional[float] = None @@ -132,21 +137,13 @@ async def _heartbeat( async def _send_buffered_messages( self, websocket: websockets.WebSocketCommonProtocol ) -> None: - while not self._buffer.empty(): - msg = await self._buffer.get() + while not await self._buffer.empty(): + msg = await self._buffer.peek() + if not msg: + continue try: await send_transport_message(msg, websocket) - except ( - websockets.exceptions.ConnectionClosed, - FailedSendingMessageException, - ) as e: - # Put the message back, they need to be resent - async with self._lock: - msg_not_sent = [msg] - while not self._buffer.empty(): - msg_not_sent.append(await self._buffer.get()) - for msg in msg_not_sent: - self._buffer.put_nowait(msg) + except SEND_TRANSPORT_MESSAGE_EXCEPTIONS as e: raise FailedSendingMessageException( f"Failed to resend message during reconnecting : {e}" ) @@ -173,7 +170,7 @@ async def send_message( serviceName=service_name, procedureName=procedure_name, ) - self._buffer.put_nowait(msg) + await self._buffer.put(msg) try: await send_transport_message( msg, @@ -271,6 +268,9 @@ async def _add_msg_to_stream( # close message is not sent to the stream await stream.put(msg.payload) + async def _update_buffer(self) -> None: + await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) + async def handle_messages_from_ws( self, websocket: WebSocketCommonProtocol, tg: Optional[asyncio.TaskGroup] = None ) -> None: @@ -297,6 +297,8 @@ async def handle_messages_from_ws( continue except InvalidTransportMessageException: return + + await self._update_buffer() if msg.controlFlags & ACK_BIT != 0: self.cancel_disconnect_grace_period() continue