diff --git a/replit_river/client_session.py b/replit_river/client_session.py index 64e765b..f1572ce 100644 --- a/replit_river/client_session.py +++ b/replit_river/client_session.py @@ -8,6 +8,7 @@ from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException from replit_river.session import Session +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE from .rpc import ( STREAM_CLOSED_BIT, @@ -48,12 +49,12 @@ async def send_rpc( try: try: response = await output.get() - except (RuntimeError, ChannelClosed) as e: - # if the stream is closed before we get a response, we will get a - # RuntimeError: RuntimeError: Event loop is closed + except ChannelClosed as e: raise RiverException( ERROR_CODE_STREAM_CLOSED, f"Stream closed before response {e}" ) + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) if not response.get("ok", False): try: error = error_deserializer(response["payload"]) @@ -64,8 +65,6 @@ async def send_rpc( except RiverException as e: raise e except Exception as e: - # Log the error and return an appropriate error response - logging.exception("Error during RPC communication") raise e async def send_upload( @@ -100,7 +99,7 @@ async def send_upload( ) first_message = False # If this request is not closed and the session is killed, we should - # throws exception here + # throw exception here async for item in request: control_flags = 0 if first_message: @@ -123,12 +122,12 @@ async def send_upload( try: try: response = await output.get() - except (RuntimeError, ChannelClosed): - # if the stream is closed before we get a response, we will get a - # RuntimeError: RuntimeError: Event loop is closed + except ChannelClosed: raise RiverException( ERROR_CODE_STREAM_CLOSED, "Stream closed before response" ) + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) if not response.get("ok", False): try: error = error_deserializer(response["payload"]) @@ -140,8 +139,6 @@ async def send_upload( except RiverException as e: raise e except Exception as e: - # Log the error and return an appropriate error response - logging.exception("Error during upload communication") raise e async def send_subscription( @@ -158,7 +155,7 @@ async def send_subscription( Expects the input and output be messages that will be msgpacked. """ stream_id = nanoid.generate() - output: Channel[Any] = Channel(1024) + output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) self._streams[stream_id] = output await self.send_message( ws=self._ws, @@ -188,8 +185,6 @@ async def send_subscription( ERROR_CODE_STREAM_CLOSED, "Stream closed before response" ) except Exception as e: - # Log the error and yield an appropriate error response - logging.exception(f"Error during subscription communication : {item}") raise e async def send_stream( @@ -209,7 +204,7 @@ async def send_stream( """ stream_id = nanoid.generate() - output: Channel[Any] = Channel(1024) + output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) self._streams[stream_id] = output try: if init and init_serializer: @@ -273,8 +268,6 @@ async def _encode_stream() -> None: ERROR_CODE_STREAM_CLOSED, "Stream closed before response" ) except Exception as e: - # Log the error and yield an appropriate error response - logging.exception("Error during stream communication") raise e async def send_close_stream( diff --git a/replit_river/client_transport.py b/replit_river/client_transport.py index b6c55e1..eb57bee 100644 --- a/replit_river/client_transport.py +++ b/replit_river/client_transport.py @@ -226,7 +226,7 @@ async def _send_handshake_request( ) stream_id = self.generate_nanoid() - def websocket_closed_callback(): + def websocket_closed_callback() -> None: raise RiverException(ERROR_SESSION, "Session closed while sending") try: diff --git a/replit_river/codegen/schema.py b/replit_river/codegen/schema.py index c7a924c..6c40989 100644 --- a/replit_river/codegen/schema.py +++ b/replit_river/codegen/schema.py @@ -37,9 +37,9 @@ def message_type( "required": [], } # Non-oneof fields. - oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( - collections.defaultdict(list) - ) + oneofs: DefaultDict[ + int, List[descriptor_pb2.FieldDescriptorProto] + ] = collections.defaultdict(list) for field in m.field: if field.HasField("oneof_index"): oneofs[field.oneof_index].append(field) diff --git a/replit_river/codegen/server.py b/replit_river/codegen/server.py index 8a6e775..02dd197 100644 --- a/replit_river/codegen/server.py +++ b/replit_river/codegen/server.py @@ -46,9 +46,9 @@ def message_decoder( " return m", ] # Non-oneof fields. - oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( - collections.defaultdict(list) - ) + oneofs: DefaultDict[ + int, List[descriptor_pb2.FieldDescriptorProto] + ] = collections.defaultdict(list) for field in m.field: if field.HasField("oneof_index"): oneofs[field.oneof_index].append(field) @@ -157,9 +157,9 @@ def message_encoder( " d: Dict[str, Any] = {}", ] # Non-oneof fields. - oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( - collections.defaultdict(list) - ) + oneofs: DefaultDict[ + int, List[descriptor_pb2.FieldDescriptorProto] + ] = collections.defaultdict(list) for field in m.field: if field.HasField("oneof_index"): oneofs[field.oneof_index].append(field) diff --git a/replit_river/message_buffer.py b/replit_river/message_buffer.py index 92174e4..dd2ce01 100644 --- a/replit_river/message_buffer.py +++ b/replit_river/message_buffer.py @@ -3,12 +3,13 @@ from typing import Optional from replit_river.rpc import TransportMessage +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE class MessageBuffer: """A buffer to store messages and support current updates""" - def __init__(self, max_num_messages: int = 1000): + def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE): self.max_size = max_num_messages self.buffer: list[TransportMessage] = [] self._lock = asyncio.Lock() diff --git a/replit_river/rate_limiter.py b/replit_river/rate_limiter.py index ddaf375..bb2fa52 100644 --- a/replit_river/rate_limiter.py +++ b/replit_river/rate_limiter.py @@ -25,7 +25,7 @@ def __init__(self, options: ConnectionRetryOptions): self.budget_consumed: Dict[str, int] = {} self.tasks: Dict[str, asyncio.Task] = {} - def get_backoff_ms(self, user: str) -> int: + def get_backoff_ms(self, user: str) -> float: """Calculate the backoff time in milliseconds for a user. Args: diff --git a/replit_river/rpc.py b/replit_river/rpc.py index 4af66e4..462399a 100644 --- a/replit_river/rpc.py +++ b/replit_river/rpc.py @@ -28,6 +28,7 @@ RiverException, ) from replit_river.task_manager import BackgroundTaskManager +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE InitType = TypeVar("InitType") RequestType = TypeVar("RequestType") @@ -280,7 +281,7 @@ async def wrapped( task_manager = BackgroundTaskManager() try: context = GrpcContext(peer) - request: Channel[RequestType] = Channel(1024) + request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE) async def _convert_inputs() -> None: try: @@ -341,7 +342,6 @@ def stream_method_handler( request_deserializer: Callable[[Any], RequestType], response_serializer: Callable[[ResponseType], Any], ) -> GenericRpcHandler: - async def wrapped( peer: str, input: Channel[Any], @@ -350,7 +350,7 @@ async def wrapped( task_manager = BackgroundTaskManager() try: context = GrpcContext(peer) - request: Channel[RequestType] = Channel(1024) + request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE) async def _convert_inputs() -> None: try: diff --git a/replit_river/server_transport.py b/replit_river/server_transport.py index ee61cf0..89c6e83 100644 --- a/replit_river/server_transport.py +++ b/replit_river/server_transport.py @@ -30,7 +30,6 @@ class ServerTransport(Transport): - async def handshake_to_get_session( self, websocket: WebSocketServerProtocol, @@ -38,9 +37,11 @@ async def handshake_to_get_session( async for message in websocket: try: msg = parse_transport_msg(message, self._transport_options) - _, handshake_request, handshake_response = ( - await self._establish_handshake(msg, websocket) - ) + ( + _, + handshake_request, + handshake_response, + ) = await self._establish_handshake(msg, websocket) except IgnoreMessageException: continue except InvalidMessageException: diff --git a/replit_river/session.py b/replit_river/session.py index be2cd44..90adb9c 100644 --- a/replit_river/session.py +++ b/replit_river/session.py @@ -21,7 +21,7 @@ SeqManager, ) from replit_river.task_manager import BackgroundTaskManager -from replit_river.transport_options import TransportOptions +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions from .rpc import ( ACK_BIT, @@ -267,6 +267,8 @@ async def _heartbeat( try: await self.send_message( str(nanoid.generate()), + # TODO: make this a message class + # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 { "ack": 0, }, @@ -380,9 +382,10 @@ async def _send_responses_from_output_stream( ) -> None: """Send serialized messages to the websockets.""" try: + # TODO: This blocking is not ideal, we should close this task when websocket + # is closed, and start another one when websocket reconnects. ws = None async for payload in output: - # TODO: what if the websocket changed during this? ws = self._ws while self._ws_state != WsState.OPEN: await asyncio.sleep( @@ -459,8 +462,12 @@ async def _open_stream_and_call_handler( "stream", ) # New channel pair. - input_stream: Channel[Any] = Channel(1024 if is_streaming_input else 1) - output_stream: Channel[Any] = Channel(1024 if is_streaming_output else 1) + input_stream: Channel[Any] = Channel( + MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1 + ) + output_stream: Channel[Any] = Channel( + MAX_MESSAGE_BUFFER_SIZE if is_streaming_output else 1 + ) try: await input_stream.put(msg.payload) except (RuntimeError, ChannelClosed) as e: diff --git a/replit_river/transport.py b/replit_river/transport.py index 3c90c25..2d0b1d7 100644 --- a/replit_river/transport.py +++ b/replit_river/transport.py @@ -41,8 +41,8 @@ async def _close_all_sessions(self) -> None: f"{len(sessions)}" ) sessions_to_close = list(sessions) - for session in sessions_to_close: - await session.close(False) + tasks = [session.close(False) for session in sessions_to_close] + await asyncio.gather(*tasks) logging.info(f"Transport closed {self._transport_id}") async def _delete_session(self, session: Session) -> None: @@ -61,7 +61,7 @@ async def _get_or_create_session_id( self, to_id: str, advertised_session_id: str, - ): + ) -> str: try: async with self._session_lock: if to_id not in self._sessions: diff --git a/replit_river/transport_options.py b/replit_river/transport_options.py index f9228c7..ed020c0 100644 --- a/replit_river/transport_options.py +++ b/replit_river/transport_options.py @@ -4,6 +4,7 @@ CROSIS_PREFIX_BYTES = b"\x00\x00" PID2_PREFIX_BYTES = b"\xff\xff" +MAX_MESSAGE_BUFFER_SIZE = 1024 class ConnectionRetryOptions(BaseModel): diff --git a/tests/conftest.py b/tests/conftest.py index 6547085..c658d6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import logging from typing import Any, AsyncGenerator -import nanoid +import nanoid # type: ignore import pytest from websockets.server import serve diff --git a/tests/test_seq_manager.py b/tests/test_seq_manager.py index b805a7d..e4b562f 100644 --- a/tests/test_seq_manager.py +++ b/tests/test_seq_manager.py @@ -11,14 +11,14 @@ @pytest.mark.asyncio -async def test_initial_sequence_and_ack_numbers(): +async def test_initial_sequence_and_ack_numbers() -> None: manager = SeqManager() assert await manager.get_seq() == 0, "Initial sequence number should be 0" assert await manager.get_ack() == 0, "Initial acknowledgment number should be 0" @pytest.mark.asyncio -async def test_sequence_number_increment(): +async def test_sequence_number_increment() -> None: manager = SeqManager() initial_seq = await manager.get_seq_and_increment() assert initial_seq == 0, "Sequence number should start at 0" @@ -27,7 +27,7 @@ async def test_sequence_number_increment(): @pytest.mark.asyncio -async def test_message_reception(): +async def test_message_reception() -> None: manager = SeqManager() msg = transport_message(seq=0, ack=0, from_="client") await manager.check_seq_and_update( @@ -46,7 +46,7 @@ async def test_message_reception(): @pytest.mark.asyncio -async def test_acknowledgment_setting(): +async def test_acknowledgment_setting() -> None: manager = SeqManager() msg = transport_message(seq=0, ack=0, from_="client") await manager.check_seq_and_update(msg) @@ -54,7 +54,7 @@ async def test_acknowledgment_setting(): @pytest.mark.asyncio -async def test_concurrent_access_to_sequence(): +async def test_concurrent_access_to_sequence() -> None: manager = SeqManager() tasks = [manager.get_seq_and_increment() for _ in range(10)] results = await asyncio.gather(*tasks)