diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..a61740b --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,14 @@ +Why +=== + +_Describe what prompted you to make this change, link relevant resources: Linear issues, Slack discussions, etc._ + +What changed +============ + +_Describe what changed to a level of detail that someone with no context with your PR could be able to review it_ + +Test plan +========= + +_Describe what you did to test this change to a level of detail that allows your reviewer to test it_ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..467618b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,35 @@ +name: Python package + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: cache poetry install + uses: actions/cache@v2 + with: + path: ~/.local + key: poetry-1.6.1-0 + - uses: snok/install-poetry@v1 + with: + version: 1.6.1 + virtualenvs-create: true + virtualenvs-in-project: true + - name: cache deps + id: cache-deps + uses: actions/cache@v2 + with: + path: .venv + key: pydeps-${{ hashFiles('**/poetry.lock') }} + - name: Install dependencies + run: | + poetry install --no-interaction + - name: Test with pytest + run: | + poetry run pytest tests diff --git a/pyproject.toml b/pyproject.toml index 2b07302..7a73959 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name="replit-river" -version="0.1.5" +version="0.1.9" description="Replit river toolkit for Python" authors = ["Replit "] license = "LICENSE" diff --git a/replit_river/client.py b/replit_river/client.py index 9ebe3f0..b5c9daf 100644 --- a/replit_river/client.py +++ b/replit_river/client.py @@ -1,233 +1,39 @@ -import asyncio import logging from 
collections.abc import AsyncIterable, AsyncIterator -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union -import msgpack # type: ignore -import nanoid # type: ignore -from aiochannel import Channel -from pydantic import ValidationError -from websockets import Data -from websockets.client import WebSocketClientProtocol -from websockets.exceptions import ConnectionClosed - -from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException -from replit_river.seq_manager import ( - IgnoreTransportMessageException, - InvalidTransportMessageException, - SeqManager, -) -from replit_river.transport import FailedSendingMessageException +from replit_river.client_transport import ClientTransport +from replit_river.transport_options import TransportOptions from .rpc import ( - ACK_BIT, - STREAM_CLOSED_BIT, - STREAM_OPEN_BIT, - ControlMessageHandshakeRequest, - ControlMessageHandshakeResponse, ErrorType, InitType, RequestType, ResponseType, - TransportMessage, ) -CROSIS_PREFIX_BYTES = b"\x00\x00" -PID2_PREFIX_BYTES = b"\xff\xff" - class Client: def __init__( self, - websockets: WebSocketClientProtocol, - use_prefix_bytes: bool = True, - client_id: Optional[str] = None, - server_id: Optional[str] = None, + websocket_uri: str, + client_id: str, + server_id: str, + transport_options: TransportOptions, ) -> None: - self.ws = websockets - self._tasks = set() - self._from = nanoid.generate() - self._streams: Dict[str, Channel[Dict[str, Any]]] = {} - self._seq_manager = SeqManager() - self._is_handshaked = False - self._use_prefix_bytes = use_prefix_bytes - self._instance_id = client_id or "python-client-" + self.generate_nanoid() - self._server_id = server_id or "SERVER" - - task = asyncio.create_task(self._handle_messages()) - self._tasks.add(task) - - def _handle_messages_callback(task: asyncio.Task) -> None: - self._tasks.remove(task) - if task.exception(): - logging.error( - f"Error in 
river.client._handle_messages: {task.exception()}" - ) - - task.add_done_callback(_handle_messages_callback) - - async def send_close_stream( - self, service_name: str, procedure_name: str, stream_id: str - ) -> None: - # close stream - await self.send_transport_message( - from_=self._from, - to=self._server_id, - serviceName=service_name, - procedureName=procedure_name, - streamId=stream_id, - controlFlags=STREAM_CLOSED_BIT, - payload={ - "type": "CLOSE", - }, + self._client_id = client_id + self._server_id = server_id + self._transport = ClientTransport( + websocket_uri=websocket_uri, + client_id=client_id, + server_id=server_id, + transport_options=transport_options, ) - def to_transport_message(self, message: Data) -> TransportMessage: - unpacked = msgpack.unpackb(message, timestamp=3) - - return TransportMessage(**unpacked) - - async def send_transport_message( - self, - from_: str, - to: str, - serviceName: Optional[str], - procedureName: Optional[str], - streamId: str, - controlFlags: int, - payload: Dict[str, Any], - is_handshake: bool = False, - ) -> None: - current_seq = 0 - if not is_handshake: - while not self._is_handshaked: - await asyncio.sleep(0.01) - if is_handshake: - current_seq = await self._seq_manager.get_seq() - else: - current_seq = await self._seq_manager.get_seq_and_increment() - current_ack = await self._seq_manager.get_ack() - message = TransportMessage( - id=nanoid.generate(), - from_=from_, - to=to, - serviceName=serviceName, - procedureName=procedureName, - streamId=streamId, - controlFlags=controlFlags, - payload=payload, - seq=current_seq, - ack=current_ack, - ) - prefix = PID2_PREFIX_BYTES if self._use_prefix_bytes else b"" - try: - await self.ws.send( - prefix - + msgpack.packb( - message.model_dump(by_alias=True, exclude_none=True), - datetime=True, - ) - ) - except ConnectionClosed: - raise FailedSendingMessageException( - "Connection closed while sending message" - ) - - def generate_nanoid(self) -> str: - return 
str(nanoid.generate()) - - async def _receive_pid2_message(self) -> Data: - data = await self.ws.recv() - if self._use_prefix_bytes: - while data[:2] == CROSIS_PREFIX_BYTES: - data = await self.ws.recv() - return data[2:] - return data - - async def _handle_messages(self) -> None: - handshake_request = ControlMessageHandshakeRequest( - type="HANDSHAKE_REQ", - protocolVersion="v1", - instanceId=self._instance_id, - ) - try: - await self.send_transport_message( - from_=self._from, - to=self._server_id, - serviceName=None, - procedureName=None, - streamId=self.generate_nanoid(), - controlFlags=0, - payload=handshake_request.model_dump(), - is_handshake=True, - ) - except FailedSendingMessageException: - raise RiverException( - ERROR_CODE_STREAM_CLOSED, "Stream closed before response" - ) - data = await self._receive_pid2_message() - first_message = self.to_transport_message(data) - try: - handshake_response = ControlMessageHandshakeResponse( - **first_message.payload - ) - except ValidationError: - logging.error("Failed to parse handshake response") - # TODO: close the connection here - return - if not handshake_response.status.ok: - logging.error(f"Handshake failed: {handshake_response.status.reason}") - # TODO: close the connection here - return - self._is_handshaked = True - - async for message in self.ws: - if isinstance(message, str): - # Not something we will try to handle. 
- logging.debug( - "ignored a message beacuse it was a text frame: %r", - message, - ) - continue - if self._use_prefix_bytes: - if message[:2] == CROSIS_PREFIX_BYTES: - logging.debug("ignored a crosis message") - continue - message = message[2:] - - try: - unpacked = msgpack.unpackb(message, timestamp=3) - msg = TransportMessage(**unpacked) - try: - await self._seq_manager.check_seq_and_update(msg) - except IgnoreTransportMessageException: - continue - except InvalidTransportMessageException: - return - if msg.controlFlags == ACK_BIT: - continue - - except ConnectionClosed: - logging.info("Connection closed") - break - - except ( - ValidationError, - ValueError, - msgpack.UnpackException, - msgpack.exceptions.ExtraData, - ): - logging.exception("failed to parse message") - return - previous_output = self._streams.get(msg.streamId, None) - if not previous_output: - logging.warning("no stream for %s", msg.streamId) - continue - await previous_output.put(msg.payload) - if msg.controlFlags & STREAM_CLOSED_BIT != 0: - logging.info("Closing stream %s", msg.streamId) - previous_output.close() - del self._streams[msg.streamId] + async def close(self) -> None: + logging.info(f"river client {self._client_id} start closing") + await self._transport.close() + logging.info(f"river client {self._client_id} closed") async def send_rpc( self, @@ -238,52 +44,15 @@ async def send_rpc( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> ResponseType: - """Sends a single RPC request to the server. - - Expects the input and output be messages that will be msgpacked. 
- """ - - stream_id = nanoid.generate() - output: Channel[Any] = Channel(1) - self._streams[stream_id] = output - try: - await self.send_transport_message( - from_=self._from, - to=self._server_id, - serviceName=service_name, - procedureName=procedure_name, - streamId=stream_id, - controlFlags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, - payload=request_serializer(request), - ) - except FailedSendingMessageException: - raise RiverException( - ERROR_CODE_STREAM_CLOSED, "Stream closed before response" - ) - - # Handle potential errors during communication - try: - try: - response = await output.get() - except RuntimeError: - # if the stream is closed before we get a response, we will get a - # RuntimeError: RuntimeError: Event loop is closed - raise RiverException( - ERROR_CODE_STREAM_CLOSED, "Stream closed before response" - ) - if not response.get("ok", False): - try: - error = error_deserializer(response["payload"]) - except Exception as e: - raise RiverException("error_deserializer", str(e)) - raise RiverException(error.code, error.message) - return response_deserializer(response["payload"]) - except RiverException as e: - raise e - except Exception as e: - # Log the error and return an appropriate error response - logging.exception("Error during RPC communication") - raise e + session = await self._transport._get_or_create_session() + return await session.send_rpc( + service_name, + procedure_name, + request, + request_serializer, + response_deserializer, + error_deserializer, + ) async def send_upload( self, @@ -296,72 +65,17 @@ async def send_upload( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> ResponseType: - """Sends an upload request to the server. - - Expects the input and output be messages that will be msgpacked. 
- """ - - stream_id = nanoid.generate() - output: Channel[Any] = Channel(1024) - self._streams[stream_id] = output - first_message = True - try: - if init and init_serializer: - await self.send_transport_message( - from_=self._from, - to=self._server_id, - serviceName=service_name, - procedureName=procedure_name, - streamId=stream_id, - controlFlags=STREAM_OPEN_BIT, - payload=init_serializer(init), - ) - first_message = False - - async for item in request: - control_flags = 0 - if first_message: - control_flags = STREAM_OPEN_BIT - first_message = False - await self.send_transport_message( - from_=self._from, - to=self._server_id, - serviceName=service_name, - procedureName=procedure_name, - streamId=stream_id, - controlFlags=control_flags, - payload=request_serializer(item), - ) - except FailedSendingMessageException: - raise RiverException( - ERROR_CODE_STREAM_CLOSED, "Stream closed before response" - ) - await self.send_close_stream(service_name, procedure_name, stream_id) - - # Handle potential errors during communication - try: - try: - response = await output.get() - except RuntimeError: - # if the stream is closed before we get a response, we will get a - # RuntimeError: RuntimeError: Event loop is closed - raise RiverException( - ERROR_CODE_STREAM_CLOSED, "Stream closed before response" - ) - if not response.get("ok", False): - try: - error = error_deserializer(response["payload"]) - except Exception as e: - raise RiverException("error_deserializer", str(e)) - raise RiverException(error.code, error.message) - - return response_deserializer(response["payload"]) - except RiverException as e: - raise e - except Exception as e: - # Log the error and return an appropriate error response - logging.exception("Error during upload communication") - raise e + session = await self._transport._get_or_create_session() + return await session.send_upload( + service_name, + procedure_name, + init, + request, + init_serializer, + request_serializer, + response_deserializer, 
+ error_deserializer, + ) async def send_subscription( self, @@ -372,46 +86,15 @@ async def send_subscription( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> AsyncIterator[Union[ResponseType, ErrorType]]: - """Sends a subscription request to the server. - - Expects the input and output be messages that will be msgpacked. - """ - stream_id = nanoid.generate() - output: Channel[Any] = Channel(1024) - self._streams[stream_id] = output - try: - await self.send_transport_message( - from_=self._from, - to=self._server_id, - serviceName=service_name, - procedureName=procedure_name, - streamId=stream_id, - controlFlags=STREAM_OPEN_BIT, - payload=request_serializer(request), - ) - except FailedSendingMessageException: - raise RiverException( - ERROR_CODE_STREAM_CLOSED, "Stream closed before response" - ) - - # Handle potential errors during communication - try: - async for item in output: - if item.get("type", None) == "CLOSE": - break - if not item.get("ok", False): - try: - yield error_deserializer(item["payload"]) - except Exception: - logging.exception( - f"Error during subscription error deserialization: {item}" - ) - continue - yield response_deserializer(item["payload"]) - except Exception as e: - # Log the error and yield an appropriate error response - logging.exception(f"Error during subscription communication : {item}") - raise e + session = await self._transport._get_or_create_session() + return session.send_subscription( + service_name, + procedure_name, + request, + request_serializer, + response_deserializer, + error_deserializer, + ) async def send_stream( self, @@ -424,79 +107,14 @@ async def send_stream( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> AsyncIterator[Union[ResponseType, ErrorType]]: - """Sends a subscription request to the server. - - Expects the input and output be messages that will be msgpacked. 
- """ - - stream_id = nanoid.generate() - output: Channel[Any] = Channel(1024) - self._streams[stream_id] = output - try: - if init and init_serializer: - await self.send_transport_message( - from_=self._from, - to=self._server_id, - serviceName=service_name, - procedureName=procedure_name, - streamId=stream_id, - controlFlags=STREAM_OPEN_BIT, - payload=init_serializer(init), - ) - else: - # Get the very first message to open the stream - request_iter = aiter(request) - first = await anext(request_iter) - await self.send_transport_message( - from_=self._from, - to=self._server_id, - serviceName=service_name, - procedureName=procedure_name, - streamId=stream_id, - controlFlags=STREAM_OPEN_BIT, - payload=request_serializer(first), - ) - - except FailedSendingMessageException: - raise RiverException( - ERROR_CODE_STREAM_CLOSED, "Stream closed before response" - ) - - # Create the encoder task - async def _encode_stream() -> None: - async for item in request: - if item is None: - continue - await self.send_transport_message( - from_=self._from, - to=self._server_id, - serviceName=service_name, - procedureName=procedure_name, - streamId=stream_id, - controlFlags=0, - payload=request_serializer(item), - ) - await self.send_close_stream(service_name, procedure_name, stream_id) - - task = asyncio.create_task(_encode_stream()) - self._tasks.add(task) - task.add_done_callback(lambda _: self._tasks.remove(task)) - - # Handle potential errors during communication - try: - async for item in output: - if "type" in item and item["type"] == "CLOSE": - break - if not item.get("ok", False): - try: - yield error_deserializer(item["payload"]) - except Exception: - logging.exception( - f"Error during subscription error deserialization: {item}" - ) - continue - yield response_deserializer(item["payload"]) - except Exception as e: - # Log the error and yield an appropriate error response - logging.exception("Error during stream communication") - raise e + session = await 
self._transport._get_or_create_session() + return session.send_stream( + service_name, + procedure_name, + init, + request, + init_serializer, + request_serializer, + response_deserializer, + error_deserializer, + ) diff --git a/replit_river/client_session.py b/replit_river/client_session.py new file mode 100644 index 0000000..f1572ce --- /dev/null +++ b/replit_river/client_session.py @@ -0,0 +1,286 @@ +import logging +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any, Callable, Optional, Union + +import nanoid # type: ignore +from aiochannel import Channel +from aiochannel.errors import ChannelClosed + +from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException +from replit_river.session import Session +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE + +from .rpc import ( + STREAM_CLOSED_BIT, + STREAM_OPEN_BIT, + ErrorType, + InitType, + RequestType, + ResponseType, +) + + +class ClientSession(Session): + async def send_rpc( + self, + service_name: str, + procedure_name: str, + request: RequestType, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + ) -> ResponseType: + """Sends a single RPC request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + stream_id = nanoid.generate() + output: Channel[Any] = Channel(1) + self._streams[stream_id] = output + await self.send_message( + ws=self._ws, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, + payload=request_serializer(request), + service_name=service_name, + procedure_name=procedure_name, + ) + # Handle potential errors during communication + try: + try: + response = await output.get() + except ChannelClosed as e: + raise RiverException( + ERROR_CODE_STREAM_CLOSED, f"Stream closed before response {e}" + ) + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) + if not response.get("ok", False): + try: + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) + raise RiverException(error.code, error.message) + return response_deserializer(response["payload"]) + except RiverException as e: + raise e + except Exception as e: + raise e + + async def send_upload( + self, + service_name: str, + procedure_name: str, + init: Optional[InitType], + request: AsyncIterable[RequestType], + init_serializer: Optional[Callable[[InitType], Any]], + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + ) -> ResponseType: + """Sends an upload request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + + stream_id = nanoid.generate() + output: Channel[Any] = Channel(1) + self._streams[stream_id] = output + first_message = True + try: + if init and init_serializer: + await self.send_message( + stream_id=stream_id, + ws=self._ws, + control_flags=STREAM_OPEN_BIT, + service_name=service_name, + procedure_name=procedure_name, + payload=init_serializer(init), + ) + first_message = False + # If this request is not closed and the session is killed, we should + # throw exception here + async for item in request: + control_flags = 0 + if first_message: + control_flags = STREAM_OPEN_BIT + first_message = False + await self.send_message( + stream_id=stream_id, + ws=self._ws, + service_name=service_name, + procedure_name=procedure_name, + control_flags=control_flags, + payload=request_serializer(item), + ) + except Exception as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) + await self.send_close_stream(service_name, procedure_name, stream_id) + + # Handle potential errors during communication + # TODO: throw a error when the transport is hard closed + try: + try: + response = await output.get() + except ChannelClosed: + raise RiverException( + ERROR_CODE_STREAM_CLOSED, "Stream closed before response" + ) + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) + if not response.get("ok", False): + try: + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) + raise RiverException(error.code, error.message) + + return response_deserializer(response["payload"]) + except RiverException as e: + raise e + except Exception as e: + raise e + + async def send_subscription( + self, + service_name: str, + procedure_name: str, + request: RequestType, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + ) -> AsyncIterator[Union[ResponseType, ErrorType]]: + 
"""Sends a subscription request to the server. + + Expects the input and output be messages that will be msgpacked. + """ + stream_id = nanoid.generate() + output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) + self._streams[stream_id] = output + await self.send_message( + ws=self._ws, + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=request_serializer(request), + ) + + # Handle potential errors during communication + try: + async for item in output: + if item.get("type", None) == "CLOSE": + break + if not item.get("ok", False): + try: + yield error_deserializer(item["payload"]) + except Exception: + logging.exception( + f"Error during subscription error deserialization: {item}" + ) + continue + yield response_deserializer(item["payload"]) + except (RuntimeError, ChannelClosed): + raise RiverException( + ERROR_CODE_STREAM_CLOSED, "Stream closed before response" + ) + except Exception as e: + raise e + + async def send_stream( + self, + service_name: str, + procedure_name: str, + init: Optional[InitType], + request: AsyncIterable[RequestType], + init_serializer: Optional[Callable[[InitType], Any]], + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + ) -> AsyncIterator[Union[ResponseType, ErrorType]]: + """Sends a subscription request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + + stream_id = nanoid.generate() + output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) + self._streams[stream_id] = output + try: + if init and init_serializer: + await self.send_message( + ws=self._ws, + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=init_serializer(init), + ) + else: + # Get the very first message to open the stream + request_iter = aiter(request) + first = await anext(request_iter) + await self.send_message( + ws=self._ws, + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=request_serializer(first), + ) + + except Exception as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) + + # Create the encoder task + async def _encode_stream() -> None: + async for item in request: + if item is None: + continue + await self.send_message( + ws=self._ws, + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=0, + payload=request_serializer(item), + ) + await self.send_close_stream(service_name, procedure_name, stream_id) + + await self._task_manager.create_task(_encode_stream()) + + # Handle potential errors during communication + try: + async for item in output: + if "type" in item and item["type"] == "CLOSE": + break + if not item.get("ok", False): + try: + yield error_deserializer(item["payload"]) + except Exception: + logging.exception( + f"Error during subscription error deserialization: {item}" + ) + continue + yield response_deserializer(item["payload"]) + except (RuntimeError, ChannelClosed): + raise RiverException( + ERROR_CODE_STREAM_CLOSED, "Stream closed before response" + ) + except Exception as e: + raise e + + async def send_close_stream( + self, service_name: str, procedure_name: str, stream_id: str + ) -> None: + # close stream + await self.send_message( + ws=self._ws, + service_name=service_name, + 
procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_CLOSED_BIT, + payload={ + "type": "CLOSE", + }, + ) diff --git a/replit_river/client_transport.py b/replit_river/client_transport.py new file mode 100644 index 0000000..665e103 --- /dev/null +++ b/replit_river/client_transport.py @@ -0,0 +1,314 @@ +import asyncio +import logging +from typing import Optional, Tuple + +import websockets +from pydantic import ValidationError +from websockets import ( + WebSocketCommonProtocol, +) +from websockets.exceptions import ConnectionClosed + +from replit_river.client_session import ClientSession +from replit_river.error_schema import ( + ERROR_CODE_STREAM_CLOSED, + ERROR_HANDSHAKE, + ERROR_SESSION, + RiverException, +) +from replit_river.messages import ( + PROTOCOL_VERSION, + FailedSendingMessageException, + parse_transport_msg, + send_transport_message, +) +from replit_river.rate_limiter import LeakyBucketRateLimit +from replit_river.rpc import ( + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + TransportMessage, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + InvalidMessageException, +) +from replit_river.session import Session +from replit_river.transport import Transport +from replit_river.transport_options import TransportOptions + + +class ClientTransport(Transport): + def __init__( + self, + websocket_uri: str, + client_id: str, + server_id: str, + transport_options: TransportOptions, + ): + super().__init__( + transport_id=client_id, + transport_options=transport_options, + is_server=False, + ) + self._websocket_uri = websocket_uri + self._client_id = client_id + self._server_id = server_id + self._rate_limiter = LeakyBucketRateLimit( + transport_options.connection_retry_options + ) + # We want to make sure there's only one session creation at a time + self._create_session_lock = asyncio.Lock() + # Only one retry should happen at a time + self._retry_ws_lock = asyncio.Lock() + + async def 
_on_session_closed(self, session: Session) -> None: + logging.info(f"Client session {session.advertised_session_id} closed") + await self._delete_session(session) + + async def close(self) -> None: + self._rate_limiter.close() + await self._close_all_sessions() + + async def _get_existing_session(self) -> Optional[ClientSession]: + async with self._session_lock: + if not self._sessions: + return None + if len(self._sessions) > 1: + raise RiverException( + "session_error", + "More than one session found in client, should only be one", + ) + session = list(self._sessions.values())[0] + if isinstance(session, ClientSession): + return session + else: + raise RiverException( + "session_error", f"Client session type wrong, got {type(session)}" + ) + + async def _establish_new_connection( + self, + ) -> Tuple[ + WebSocketCommonProtocol, + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + ]: + """Build a new websocket connection with retry logic.""" + rate_limit = self._rate_limiter + max_retry = self._transport_options.connection_retry_options.max_retry + client_id = self._client_id + for i in range(max_retry): + if i > 0: + logging.info(f"Retrying build handshake number {i} times") + if not rate_limit.has_budget(client_id): + logging.debug("No retry budget for %s.", client_id) + break + try: + ws = await websockets.connect(self._websocket_uri) + existing_session = await self._get_existing_session() + session_id = ( + self.generate_session_id() + if not existing_session + else existing_session.session_id + ) + rate_limit.consume_budget(client_id) + handshake_request, handshake_response = await self._establish_handshake( + self._transport_id, self._server_id, session_id, ws + ) + rate_limit.start_restoring_budget(client_id) + return ws, handshake_request, handshake_response + except Exception as e: + backoff_time = rate_limit.get_backoff_ms(client_id) + logging.error( + f"Error creating session: {e}, start backoff {backoff_time} ms" + ) + await 
asyncio.sleep(backoff_time / 1000) + raise RiverException( + ERROR_HANDSHAKE, + "Failed to create session after retrying max number of times", + ) + + async def _retry_session_connection( + self, session_to_replace_ws: Session + ) -> Session: + async with self._retry_ws_lock: + if await session_to_replace_ws.is_websocket_open(): + # other retry successfully replaced the websocket, + return session_to_replace_ws + if not await session_to_replace_ws.is_session_open(): + # If the session is already closing we don't retry connection + return session_to_replace_ws + new_ws, hs_request, hs_response = await self._establish_new_connection() + # If the server session id different, we create a new session. + if ( + hs_response.status.sessionId + != session_to_replace_ws.advertised_session_id + ): + server_session_id = hs_response.status.sessionId + if not server_session_id: + raise RiverException( + ERROR_SESSION, + "Server did not return a sessionId in successful handshake", + ) + new_session = ClientSession( + transport_id=self._transport_id, + to_id=self._server_id, + session_id=hs_request.sessionId, + advertised_session_id=server_session_id, + websocket=new_ws, + transport_options=self._transport_options, + is_server=False, + handlers={}, + close_session_callback=self._on_session_closed, + retry_connection_callback=lambda x: self._retry_session_connection( + x + ), + ) + return new_session + else: + # If the session is still active and aligns with the server session + # we replace the websocket in it. 
+ await session_to_replace_ws.replace_with_new_websocket(new_ws) + return session_to_replace_ws + + async def _get_or_create_session(self) -> ClientSession: + async with self._create_session_lock: + existing_session = await self._get_existing_session() + if existing_session: + if await existing_session.is_websocket_open(): + return existing_session + else: + session = await self._retry_session_connection(existing_session) + # This should never happen, adding here to make mypy happy + if not isinstance(session, ClientSession): + raise RiverException( + ERROR_SESSION, + f"Session type is not ClientSession, got {type(session)}", + ) + return session + else: + new_ws, hs_request, hs_response = await self._establish_new_connection() + advertised_session_id = hs_response.status.sessionId + if not advertised_session_id: + raise RiverException( + ERROR_SESSION, + "Server did not return a sessionId in successful handshake", + ) + new_session = ClientSession( + transport_id=self._transport_id, + to_id=self._server_id, + session_id=hs_request.sessionId, + advertised_session_id=advertised_session_id, + websocket=new_ws, + transport_options=self._transport_options, + is_server=False, + handlers={}, + close_session_callback=self._on_session_closed, + retry_connection_callback=lambda x: self._retry_session_connection( + x + ), + ) + await self._set_session(new_session) + await new_session.start_serve_responses() + return new_session + + async def _send_handshake_request( + self, + transport_id: str, + to_id: str, + session_id: str, + websocket: WebSocketCommonProtocol, + ) -> ControlMessageHandshakeRequest: + handshake_request = ControlMessageHandshakeRequest( + type="HANDSHAKE_REQ", + protocolVersion=PROTOCOL_VERSION, + sessionId=session_id, + ) + stream_id = self.generate_nanoid() + + def websocket_closed_callback() -> None: + raise RiverException(ERROR_SESSION, "Session closed while sending") + + try: + await send_transport_message( + TransportMessage( + from_=transport_id, + 
to=to_id, + streamId=stream_id, + controlFlags=0, + id=self.generate_nanoid(), + seq=0, + ack=0, + payload=handshake_request.model_dump(), + ), + ws=websocket, + prefix_bytes=self._transport_options.get_prefix_bytes(), + websocket_closed_callback=websocket_closed_callback, + ) + return handshake_request + except ConnectionClosed: + raise RiverException(ERROR_HANDSHAKE, "Hand shake failed") + + async def _get_handshake_response_msg( + self, websocket: WebSocketCommonProtocol + ) -> TransportMessage: + while True: + try: + data = await websocket.recv() + except ConnectionClosed as e: + logging.debug( + "Connection closed during waiting for handshake response : %r", e + ) + raise RiverException(ERROR_HANDSHAKE, "Hand shake failed") + try: + return parse_transport_msg(data, self._transport_options) + except IgnoreMessageException as e: + logging.debug("Ignoring transport message : %r", e) + continue + except InvalidMessageException as e: + raise RiverException( + ERROR_HANDSHAKE, + f"Got invalid transport message, closing connection : {e}", + ) + + async def _establish_handshake( + self, + transport_id: str, + to_id: str, + session_id: str, + websocket: WebSocketCommonProtocol, + ) -> Tuple[ControlMessageHandshakeRequest, ControlMessageHandshakeResponse]: + try: + handshake_request = await self._send_handshake_request( + transport_id=transport_id, + to_id=to_id, + session_id=session_id, + websocket=websocket, + ) + except FailedSendingMessageException: + raise RiverException( + ERROR_CODE_STREAM_CLOSED, "Stream closed before response" + ) + logging.debug("river client waiting for handshake response") + try: + response_msg = await asyncio.wait_for( + self._get_handshake_response_msg(websocket), + timeout=self._transport_options.session_disconnect_grace_ms / 1000, + ) + handshake_response = ControlMessageHandshakeResponse(**response_msg.payload) + logging.debug( + "river client get handshake response : %r", handshake_response + ) + except ValidationError as e: + raise 
RiverException( + ERROR_HANDSHAKE, f"Failed to parse handshake response : {e}" + ) + except asyncio.TimeoutError: + raise RiverException( + ERROR_HANDSHAKE, "Handshake response timeout, closing connection" + ) + if not handshake_response.status.ok: + raise RiverException( + ERROR_HANDSHAKE, f"Handshake failed: {handshake_response.status.reason}" + ) + return handshake_request, handshake_response diff --git a/replit_river/codegen/server.py b/replit_river/codegen/server.py index 8a6e775..02dd197 100644 --- a/replit_river/codegen/server.py +++ b/replit_river/codegen/server.py @@ -46,9 +46,9 @@ def message_decoder( " return m", ] # Non-oneof fields. - oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( - collections.defaultdict(list) - ) + oneofs: DefaultDict[ + int, List[descriptor_pb2.FieldDescriptorProto] + ] = collections.defaultdict(list) for field in m.field: if field.HasField("oneof_index"): oneofs[field.oneof_index].append(field) @@ -157,9 +157,9 @@ def message_encoder( " d: Dict[str, Any] = {}", ] # Non-oneof fields. 
- oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( - collections.defaultdict(list) - ) + oneofs: DefaultDict[ + int, List[descriptor_pb2.FieldDescriptorProto] + ] = collections.defaultdict(list) for field in m.field: if field.HasField("oneof_index"): oneofs[field.oneof_index].append(field) diff --git a/replit_river/error_schema.py b/replit_river/error_schema.py index eb5dd5d..c8d87a0 100644 --- a/replit_river/error_schema.py +++ b/replit_river/error_schema.py @@ -3,6 +3,8 @@ from pydantic import BaseModel ERROR_CODE_STREAM_CLOSED = "stream_closed" +ERROR_HANDSHAKE = "handshake_failed" +ERROR_SESSION = "session_error" class RiverError(BaseModel): diff --git a/replit_river/message_buffer.py b/replit_river/message_buffer.py new file mode 100644 index 0000000..dd2ce01 --- /dev/null +++ b/replit_river/message_buffer.py @@ -0,0 +1,40 @@ +import asyncio +import logging +from typing import Optional + +from replit_river.rpc import TransportMessage +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE + + +class MessageBuffer: + """A buffer to store messages and support current updates""" + + def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE): + self.max_size = max_num_messages + self.buffer: list[TransportMessage] = [] + self._lock = asyncio.Lock() + + async def empty(self) -> bool: + """Check if the buffer is empty""" + async with self._lock: + return len(self.buffer) == 0 + + async def put(self, message: TransportMessage) -> None: + """Add a message to the buffer""" + async with self._lock: + if len(self.buffer) >= self.max_size: + logging.error("Buffer is full, dropping message") + raise ValueError("Buffer is full") + self.buffer.append(message) + + async def peek(self) -> Optional[TransportMessage]: + """Peek the first message in the buffer, returns None if the buffer is empty.""" + async with self._lock: + if len(self.buffer) == 0: + return None + return self.buffer[0] + + async def remove_old_messages(self, 
min_seq: int) -> None: + """Remove messages in the buffer with a seq number less than min_seq.""" + async with self._lock: + self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq] diff --git a/replit_river/messages.py b/replit_river/messages.py new file mode 100644 index 0000000..2c711e1 --- /dev/null +++ b/replit_river/messages.py @@ -0,0 +1,86 @@ +import logging +from typing import Any, Callable, Coroutine + +import msgpack # type: ignore +import websockets +from pydantic import ValidationError +from pydantic_core import ValidationError as PydanticCoreValidationError +from websockets import ( + WebSocketCommonProtocol, +) + +from replit_river.rpc import ( + TransportMessage, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + InvalidMessageException, +) +from replit_river.transport_options import TransportOptions + + +class FailedSendingMessageException(Exception): + pass + + +PROTOCOL_VERSION = "v1.1" + +CROSIS_PREFIX_BYTES = b"\x00\x00" +PID2_PREFIX_BYTES = b"\xff\xff" + + +async def send_transport_message( + msg: TransportMessage, + ws: WebSocketCommonProtocol, + websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], + prefix_bytes: bytes = b"", +) -> None: + logging.debug("sending a message %r to ws %s", msg, ws) + try: + await ws.send( + prefix_bytes + + msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + ) + except websockets.exceptions.ConnectionClosed as e: + await websocket_closed_callback() + raise e + except Exception as e: + raise FailedSendingMessageException(f"Exception during send message : {e}") + + +def formatted_bytes(message: bytes) -> str: + return " ".join(f"{b:02x}" for b in message) + + +def parse_transport_msg( + message: str | bytes, transport_options: TransportOptions +) -> TransportMessage: + if isinstance(message, str): + raise IgnoreMessageException( + "ignored a message beacuse it was a text frame: %r", message + ) + if transport_options.use_prefix_bytes: + 
if message.startswith(CROSIS_PREFIX_BYTES): + raise IgnoreMessageException("Skip crosis message") + elif message.startswith(PID2_PREFIX_BYTES): + message = message[len(PID2_PREFIX_BYTES) :] + else: + raise InvalidMessageException( + f"Got message without prefix bytes: {formatted_bytes(message)[:5]}" + ) + try: + unpacked_message = msgpack.unpackb(message) + except (msgpack.UnpackException, msgpack.exceptions.ExtraData): + raise InvalidMessageException("received non-msgpack message") + try: + msg = TransportMessage(**unpacked_message) + except ( + ValidationError, + ValueError, + msgpack.UnpackException, + PydanticCoreValidationError, + ): + raise InvalidMessageException(f"failed to parse message: {unpacked_message}") + return msg diff --git a/replit_river/rate_limiter.py b/replit_river/rate_limiter.py new file mode 100644 index 0000000..bb2fa52 --- /dev/null +++ b/replit_river/rate_limiter.py @@ -0,0 +1,103 @@ +import asyncio +import random +from typing import Dict + +from replit_river.transport_options import ConnectionRetryOptions + + +class LeakyBucketRateLimit: + """Asynchronous leaky bucket rate limiter. + + This class implements a rate limiting strategy using a leaky bucket algorithm, + utilizing asyncio + to handle periodic budget restoration in an asynchronous context. + + Attributes: + options (ConnectionRetryOptions): Configuration options for retry behavior. + budget_consumed (Dict[str, int]): Dictionary tracking the number of retries + (or budget) consumed per user. + tasks (Dict[str, asyncio.Task]): Dictionary holding asyncio tasks for budget + restoration. + """ + + def __init__(self, options: ConnectionRetryOptions): + self.options = options + self.budget_consumed: Dict[str, int] = {} + self.tasks: Dict[str, asyncio.Task] = {} + + def get_backoff_ms(self, user: str) -> float: + """Calculate the backoff time in milliseconds for a user. + + Args: + user (str): The identifier for the user. 
+ + Returns: + int: The backoff time in milliseconds, including a random jitter. + """ + exponent = max(0, self.get_budget_consumed(user) - 1) + jitter = random.randint(0, self.options.max_jitter_ms) + backoff_ms = min( + self.options.base_interval_ms * (2**exponent), self.options.max_backoff_ms + ) + return backoff_ms + jitter + + def get_budget_consumed(self, user: str) -> int: + """Retrieve the amount of budget consumed for the specified user. + + Args: + user (str): The identifier for the user. + + Returns: + int: The number of times the budget has been consumed. + """ + return self.budget_consumed.get(user, 0) + + def has_budget(self, user: str) -> bool: + """Check if the user has remaining budget to make a retry. + + Args: + user (str): The identifier for the user. + + Returns: + bool: True if budget is available, False otherwise. + """ + return self.get_budget_consumed(user) < self.options.attempt_budget_capacity + + def consume_budget(self, user: str) -> None: + """Increment the budget consumed for the user by 1, indicating a retry attempt. + + Args: + user (str): The identifier for the user. + """ + if user in self.tasks: + self.tasks[user].cancel() + current_budget = self.get_budget_consumed(user) + self.budget_consumed[user] = current_budget + 1 + + def start_restoring_budget(self, user: str) -> None: + """Start or reset an asynchronous task to restore budget periodically for the + user. + + Args: + user (str): The identifier for the user. + """ + self.tasks[user] = asyncio.create_task(self.restore_budget(user)) + + async def restore_budget(self, user: str) -> None: + """Asynchronously wait for the interval and then restore the budget for the + user. + + Args: + user (str): The identifier for the user. 
+ """ + while self.budget_consumed.get(user, 0) > 0: + await asyncio.sleep(self.options.budget_restore_interval_ms / 1000.0) + if self.budget_consumed[user] == 0: + break + self.budget_consumed[user] -= 1 + + def close(self) -> None: + """Cancel all asynchronous tasks when closing the limiter.""" + for task in self.tasks.values(): + task.cancel() + self.tasks.clear() diff --git a/replit_river/rpc.py b/replit_river/rpc.py index bc9dfda..462399a 100644 --- a/replit_river/rpc.py +++ b/replit_river/rpc.py @@ -19,10 +19,16 @@ ) import grpc -from aiochannel import Channel +from aiochannel import Channel, ChannelClosed from pydantic import BaseModel, ConfigDict, Field -from replit_river.error_schema import RiverError +from replit_river.error_schema import ( + ERROR_CODE_STREAM_CLOSED, + RiverError, + RiverException, +) +from replit_river.task_manager import BackgroundTaskManager +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE InitType = TypeVar("InitType") RequestType = TypeVar("RequestType") @@ -45,13 +51,12 @@ class ControlMessageHandshakeRequest(BaseModel): type: Literal["HANDSHAKE_REQ"] = "HANDSHAKE_REQ" protocolVersion: str - instanceId: str + sessionId: str class HandShakeStatus(BaseModel): ok: bool - # Instance id should be server level id, each server have one - instanceId: Optional[str] = None + sessionId: Optional[str] = None # Reason for failure reason: Optional[str] = None @@ -63,6 +68,7 @@ class ControlMessageHandshakeResponse(BaseModel): class TransportMessage(BaseModel): id: str + # from_ is used instead of from because from is a reserved keyword in Python from_: str = Field(..., alias="from") to: str seq: int @@ -196,7 +202,7 @@ async def wrapped( } ) except Exception as e: - logging.exception("Uncaught exception") + logging.exception("Uncaught exception during river rpc") await output.put( { "ok": False, @@ -243,7 +249,7 @@ async def wrapped( } ) except Exception as e: - logging.exception("Uncaught exception in subscription") + 
logging.exception("Uncaught exception in river server subscription") await output.put( { "ok": False, @@ -272,9 +278,10 @@ async def wrapped( input: Channel[Any], output: Channel[Any], ) -> None: + task_manager = BackgroundTaskManager() try: context = GrpcContext(peer) - request: Channel[RequestType] = Channel(1024) + request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE) async def _convert_inputs() -> None: try: @@ -289,8 +296,10 @@ async def _convert_outputs() -> None: await output.put( get_response_or_error_payload(response, response_serializer) ) + except ChannelClosed: + raise RiverException(ERROR_CODE_STREAM_CLOSED, "Channel closed") except Exception as e: - print("upload caught exception", e) + logging.error("Uncaught exception in river server upload") await output.put( { "ok": False, @@ -303,9 +312,10 @@ async def _convert_outputs() -> None: finally: output.close() - convert_inputs_task = asyncio.create_task(_convert_inputs()) - convert_outputs_task = asyncio.create_task(_convert_outputs()) + convert_inputs_task = await task_manager.create_task(_convert_inputs()) + convert_outputs_task = await task_manager.create_task(_convert_outputs()) await asyncio.wait((convert_inputs_task, convert_outputs_task)) + except Exception as e: logging.exception("Uncaught exception in upload") await output.put( @@ -318,6 +328,7 @@ async def _convert_outputs() -> None: } ) finally: + await task_manager.cancel_all_tasks() output.close() return wrapped @@ -336,9 +347,10 @@ async def wrapped( input: Channel[Any], output: Channel[Any], ) -> None: + task_manager = BackgroundTaskManager() try: context = GrpcContext(peer) - request: Channel[RequestType] = Channel(1024) + request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE) async def _convert_inputs() -> None: try: @@ -358,11 +370,11 @@ async def _convert_outputs() -> None: finally: output.close() - convert_inputs_task = asyncio.create_task(_convert_inputs()) - convert_outputs_task = 
asyncio.create_task(_convert_outputs()) + convert_inputs_task = await task_manager.create_task(_convert_inputs()) + convert_outputs_task = await task_manager.create_task(_convert_outputs()) await asyncio.wait((convert_inputs_task, convert_outputs_task)) except grpc.RpcError: - logging.exception("Uncaught exception in stream") + logging.exception("RPC exception in stream") await output.put( { "ok": False, @@ -385,6 +397,7 @@ async def _convert_outputs() -> None: } ) finally: + await task_manager.cancel_all_tasks() output.close() return wrapped diff --git a/replit_river/seq_manager.py b/replit_river/seq_manager.py index 6c9deab..7298d22 100644 --- a/replit_river/seq_manager.py +++ b/replit_river/seq_manager.py @@ -4,13 +4,13 @@ from replit_river.rpc import TransportMessage -class IgnoreTransportMessageException(Exception): +class IgnoreMessageException(Exception): """Exception to ignore a transport message, but good to continue.""" pass -class InvalidTransportMessageException(Exception): +class InvalidMessageException(Exception): """Error processing a transport message, should raise a exception.""" pass @@ -26,6 +26,7 @@ def __init__( self.seq = 0 self._ack_lock = asyncio.Lock() self.ack = 0 + self.receiver_ack = 0 async def get_seq_and_increment(self) -> int: """Get the current sequence number and increment it. 
@@ -53,17 +54,20 @@ async def check_seq_and_update(self, msg: TransportMessage) -> None: async with self._ack_lock: if msg.seq != self.ack: if msg.seq < self.ack: - logging.debug( + raise IgnoreMessageException( f"{msg.from_} received duplicate msg, got {msg.seq}" f" expected {self.ack}" ) - raise IgnoreTransportMessageException else: logging.error( - f"{msg.from_} received duplicate msg, got {msg.seq}" + f"Out of order message received got {msg.seq} expected " + f"{self.ack}" + ) + raise InvalidMessageException( + f"{msg.from_} received out of order, got {msg.seq}" f" expected {self.ack}" ) - raise InvalidTransportMessageException + self.receiver_ack = msg.ack await self._set_ack(msg.seq + 1) async def _set_ack(self, new_ack: int) -> int: diff --git a/replit_river/server.py b/replit_river/server.py index d6ee585..e3f0f5c 100644 --- a/replit_river/server.py +++ b/replit_river/server.py @@ -1,11 +1,12 @@ +import asyncio import logging -from typing import Dict, Mapping, Tuple +from typing import Mapping, Tuple -import nanoid # type: ignore -from websockets.exceptions import ConnectionClosedError +from websockets.exceptions import ConnectionClosed from websockets.server import WebSocketServerProtocol -from replit_river.transport import Transport, TransportManager +from replit_river.server_transport import ServerTransport +from replit_river.transport import TransportOptions from .rpc import ( GenericRpcHandler, @@ -13,28 +14,53 @@ class Server(object): - def __init__(self, server_id: str) -> None: - self._handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {} - self._server_id = server_id or nanoid.generate() - self._transport_manager = TransportManager() + def __init__(self, server_id: str, transport_options: TransportOptions) -> None: + self._server_id = server_id or "SERVER" + self._transport_options = transport_options + self._transport = ServerTransport( + transport_id=self._server_id, + transport_options=transport_options, + is_server=True, + ) + + 
async def close(self) -> None: + logging.info(f"river server {self._server_id} start closing") + await self._transport.close() + logging.info(f"river server {self._server_id} closed") def add_rpc_handlers( self, rpc_handlers: Mapping[Tuple[str, str], Tuple[str, GenericRpcHandler]], ) -> None: - self._handlers.update(rpc_handlers) + self._transport._handlers.update(rpc_handlers) async def serve(self, websocket: WebSocketServerProtocol) -> None: - logging.debug("got a client") - transport = Transport( - self._server_id, self._handlers, websocket, self._transport_manager + logging.debug( + "River server started establishing session with ws: %s", websocket.id ) try: - await transport.serve() - except ConnectionClosedError as e: - logging.debug(f"ConnectionClosedError while serving {e}") + session = await asyncio.wait_for( + self._transport.handshake_to_get_session(websocket), + self._transport_options.session_disconnect_grace_ms / 1000, + ) + except Exception as e: + logging.error( + f"Error establishing handshake, closing websocket: {e}", exc_info=True + ) + await websocket.close() + return + logging.debug("River server session established, start serving messages") + + try: + # Session serve will be closed in two cases + # 1. websocket is closed + # 2. exception thrown + # session should be kept in order to be reused by the reconnect within the + # grace period. + await session.serve() + except ConnectionClosed as e: + logging.debug("ConnectionClosed while serving %r", e) + # We don't have to close the websocket here, it is already closed. 
except Exception as e: logging.error(f"River transport error in server {self._server_id}: {e}") - finally: - if transport: - await transport.close() + await websocket.close() diff --git a/replit_river/server_transport.py b/replit_river/server_transport.py new file mode 100644 index 0000000..89c6e83 --- /dev/null +++ b/replit_river/server_transport.py @@ -0,0 +1,158 @@ +import logging +from typing import Tuple + +import nanoid # type: ignore # type: ignore +from pydantic import ValidationError +from websockets import ( + WebSocketCommonProtocol, + WebSocketServerProtocol, +) +from websockets.exceptions import ConnectionClosed + +from replit_river.messages import ( + PROTOCOL_VERSION, + FailedSendingMessageException, + parse_transport_msg, + send_transport_message, +) +from replit_river.rpc import ( + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + HandShakeStatus, + TransportMessage, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + InvalidMessageException, +) +from replit_river.session import Session +from replit_river.transport import Transport + + +class ServerTransport(Transport): + async def handshake_to_get_session( + self, + websocket: WebSocketServerProtocol, + ) -> Session: + async for message in websocket: + try: + msg = parse_transport_msg(message, self._transport_options) + ( + _, + handshake_request, + handshake_response, + ) = await self._establish_handshake(msg, websocket) + except IgnoreMessageException: + continue + except InvalidMessageException: + error_msg = "Got invalid transport message, closing connection" + raise InvalidMessageException(error_msg) + except FailedSendingMessageException as e: + raise e + logging.debug("handshake success on server: %r", handshake_request) + transport_id = msg.to + to_id = msg.from_ + session_id = handshake_response.status.sessionId + if not session_id: + raise InvalidMessageException("No session id in handshake response") + advertised_session_id = 
handshake_request.sessionId + try: + session = await self.get_or_create_session( + transport_id, + to_id, + session_id, + advertised_session_id, + websocket, + ) + except Exception as e: + error_msg = ( + "Error building sessions from handshake request : " + f"client_id: {transport_id}, session_id: {advertised_session_id}," + f" error: {e}" + ) + raise InvalidMessageException(error_msg) + return session + raise InvalidMessageException("No handshake message received") + + async def _send_handshake_response( + self, + request_message: TransportMessage, + handshake_status: HandShakeStatus, + websocket: WebSocketCommonProtocol, + ) -> ControlMessageHandshakeResponse: + response = ControlMessageHandshakeResponse( + status=handshake_status, + ) + response_message = TransportMessage( + streamId=request_message.streamId, + id=nanoid.generate(), + from_=request_message.to, + to=request_message.from_, + seq=0, + ack=0, + controlFlags=0, + payload=response.model_dump(by_alias=True, exclude_none=True), + serviceName=request_message.serviceName, + procedureName=request_message.procedureName, + ) + + async def websocket_closed_callback() -> None: + logging.error("websocket closed before handshake response") + + try: + await send_transport_message( + response_message, websocket, websocket_closed_callback + ) + except (FailedSendingMessageException, ConnectionClosed) as e: + raise FailedSendingMessageException( + f"Failed sending handshake response: {e}" + ) + return response + + async def _establish_handshake( + self, request_message: TransportMessage, websocket: WebSocketCommonProtocol + ) -> Tuple[ + WebSocketCommonProtocol, + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + ]: + try: + handshake_request = ControlMessageHandshakeRequest( + **request_message.payload + ) + logging.debug('Got handshake request "%r"', handshake_request) + except (ValidationError, ValueError): + await self._send_handshake_response( + request_message, + HandShakeStatus(ok=False, 
reason="failed validate handshake request"), + websocket, + ) + raise InvalidMessageException("failed validate handshake request") + + if handshake_request.protocolVersion != PROTOCOL_VERSION: + await self._send_handshake_response( + request_message, + HandShakeStatus(ok=False, reason="protocol version mismatch"), + websocket, + ) + error_str = ( + "protocol version mismatch: " + + f"{handshake_request.protocolVersion} != {PROTOCOL_VERSION}" + ) + raise InvalidMessageException(error_str) + if request_message.to != self._transport_id: + await self._send_handshake_response( + request_message, + HandShakeStatus(ok=False, reason="handshake request to wrong server"), + websocket, + ) + raise InvalidMessageException("handshake request to wrong server") + my_session_id = await self._get_or_create_session_id( + request_message.from_, handshake_request.sessionId + ) + handshake_response = await self._send_handshake_response( + request_message, + HandShakeStatus(ok=True, sessionId=my_session_id), + websocket, + ) + return websocket, handshake_request, handshake_response diff --git a/replit_river/session.py b/replit_river/session.py new file mode 100644 index 0000000..90adb9c --- /dev/null +++ b/replit_river/session.py @@ -0,0 +1,534 @@ +import asyncio +import enum +import logging +from typing import Any, Callable, Coroutine, Dict, Optional, Tuple + +import nanoid # type: ignore +import websockets +from aiochannel import Channel, ChannelClosed +from websockets import WebSocketCommonProtocol +from websockets.exceptions import ConnectionClosed + +from replit_river.message_buffer import MessageBuffer +from replit_river.messages import ( + FailedSendingMessageException, + parse_transport_msg, + send_transport_message, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + InvalidMessageException, + SeqManager, +) +from replit_river.task_manager import BackgroundTaskManager +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions + 
+from .rpc import ( + ACK_BIT, + STREAM_CLOSED_BIT, + STREAM_OPEN_BIT, + GenericRpcHandler, + TransportMessage, +) + + +class SessionState(enum.Enum): + ACTIVE = 0 + CLOSING = 1 + CLOSED = 2 + + +class WsState(enum.Enum): + OPEN = 0 + CLOSING = 1 + CLOSED = 2 + + +class Session(object): + """A transport object that handles the websocket connection with a client.""" + + def __init__( + self, + transport_id: str, + to_id: str, + session_id: str, + advertised_session_id: str, + websocket: websockets.WebSocketCommonProtocol, + transport_options: TransportOptions, + is_server: bool, + handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]], + close_session_callback: Callable[["Session"], Coroutine[Any, Any, None]], + retry_connection_callback: Optional[ + Callable[ + ["Session"], + Coroutine[Any, Any, Any], + ] + ] = None, + ) -> None: + self._transport_id = transport_id + self._to_id = to_id + self.session_id = session_id + self.advertised_session_id = advertised_session_id + self._handlers = handlers + self._is_server = is_server + self._transport_options = transport_options + + # session state, only modified during closing + self._state = SessionState.ACTIVE + self._state_lock = asyncio.Lock() + self._close_session_callback = close_session_callback + self._close_session_after_time_secs: Optional[float] = None + + # ws state + self._ws_lock = asyncio.Lock() + self._ws_state = WsState.OPEN + self._ws = websocket + self._heartbeat_misses = 0 + self._retry_connection_callback = retry_connection_callback + + # stream for tasks + self._stream_lock = asyncio.Lock() + self._streams: Dict[str, Channel[Any]] = {} + + # book keeping + self._seq_manager = SeqManager() + self._msg_lock = asyncio.Lock() + self._buffer = MessageBuffer(self._transport_options.buffer_size) + self._task_manager = BackgroundTaskManager() + + asyncio.create_task(self._setup_heartbeats_task()) + + async def _setup_heartbeats_task(self) -> None: + await 
self._task_manager.create_task(self._heartbeat()) + await self._task_manager.create_task(self._check_to_close_session()) + + async def is_session_open(self) -> bool: + async with self._state_lock: + return self._state == SessionState.ACTIVE + + async def is_websocket_open(self) -> bool: + async with self._ws_lock: + return self._ws_state == WsState.OPEN + + async def _on_websocket_unexpected_close(self) -> None: + """Handle unexpected websocket close.""" + logging.info( + f"Unexpected websocket close from {self._transport_id} to {self._to_id}" + ) + await self._begin_close_session_countdown() + + async def _begin_close_session_countdown(self) -> None: + """Begin the countdown to close session, this should be called when + websocket is closed. + """ + logging.debug("begin_close_session_countdown") + if self._close_session_after_time_secs is not None: + # already in grace period, no need to set again + return + logging.debug( + "websocket closed from %s to %s begin grace period", + self._transport_id, + self._to_id, + ) + grace_period_ms = self._transport_options.session_disconnect_grace_ms + self._close_session_after_time_secs = ( + await self._get_current_time() + grace_period_ms / 1000 + ) + + async def serve(self) -> None: + """Serve messages from the websocket.""" + try: + async with asyncio.TaskGroup() as tg: + try: + await self._handle_messages_from_ws(self._ws, tg) + except ConnectionClosed as e: + await self._on_websocket_unexpected_close() + logging.debug("ConnectionClosed while serving: %r", e) + except FailedSendingMessageException as e: + # Expected error if the connection is closed. 
+ logging.debug("FailedSendingMessageException while serving: %r", e) + except Exception: + logging.exception("caught exception at message iterator") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + raise ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + + async def _update_book_keeping(self, msg: TransportMessage) -> None: + await self._seq_manager.check_seq_and_update(msg) + await self._remove_acked_messages_in_buffer() + self._reset_session_close_countdown() + + async def _handle_messages_from_ws( + self, websocket: WebSocketCommonProtocol, tg: Optional[asyncio.TaskGroup] = None + ) -> None: + logging.debug( + "%s start handling messages from ws %s", + "server" if self._is_server else "client", + websocket.id, + ) + try: + async for message in websocket: + try: + msg = parse_transport_msg(message, self._transport_options) + + logging.debug(f"{self._transport_id} got a message %r", msg) + + await self._update_book_keeping(msg) + if msg.controlFlags & ACK_BIT != 0: + continue + async with self._stream_lock: + stream = self._streams.get(msg.streamId, None) + if msg.controlFlags & STREAM_OPEN_BIT == 0: + if not stream: + logging.warning("no stream for %s", msg.streamId) + raise IgnoreMessageException( + "no stream for message, ignoring" + ) + await self._add_msg_to_stream(msg, stream) + else: + stream = await self._open_stream_and_call_handler( + msg, stream, tg + ) + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + async with self._stream_lock: + del self._streams[msg.streamId] + except IgnoreMessageException as e: + logging.debug("Ignoring transport message : %r", e) + continue + except InvalidMessageException as e: + logging.error( + f"Got invalid transport message, closing session : {e}" + ) + await self.close(True) + return + except ConnectionClosed as e: + raise e + + async def replace_with_new_websocket( + self, new_ws: 
websockets.WebSocketCommonProtocol + ) -> None: + async with self._ws_lock: + old_ws = self._ws + self._ws_state = WsState.CLOSING + if new_ws.id != old_ws.id: + self._reset_session_close_countdown() + await self.close_websocket(old_ws, should_retry=False) + async with self._ws_lock: + self._ws = new_ws + self._ws_state = WsState.OPEN + await self._send_buffered_messages(new_ws) + # Server will call serve itself. + if not self._is_server: + await self.start_serve_responses() + + async def _get_current_time(self) -> float: + return asyncio.get_event_loop().time() + + def _reset_session_close_countdown(self) -> None: + self._heartbeat_misses = 0 + self._close_session_after_time_secs = None + + async def _check_to_close_session(self) -> None: + while True: + await asyncio.sleep( + self._transport_options.close_session_check_interval_ms / 1000 + ) + if not self._close_session_after_time_secs: + continue + current_time = await self._get_current_time() + if current_time > self._close_session_after_time_secs: + logging.info( + "Grace period ended for :" f" {self._transport_id}, closing session" + ) + await self.close(False) + return + + async def _heartbeat( + self, + ) -> None: + logging.debug("Start heartbeat") + while True: + await asyncio.sleep(self._transport_options.heartbeat_ms / 1000) + if self._state != SessionState.ACTIVE: + logging.debug( + "Session is closed, no need to send heartbeat, state : " + "%r close_session_after_this: %r", + {self._state}, + {self._close_session_after_time_secs}, + ) + # session is closing, no need to send heartbeat + continue + try: + await self.send_message( + str(nanoid.generate()), + # TODO: make this a message class + # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 + { + "ack": 0, + }, + self._ws, + ACK_BIT, + ) + self._heartbeat_misses += 1 + if ( + self._heartbeat_misses + >= self._transport_options.heartbeats_until_dead + ): + logging.debug( + "%r closing websocket 
because of heartbeat misses", + self.session_id, + ) + await self._on_websocket_unexpected_close() + await self.close_websocket( + self._ws, should_retry=not self._is_server + ) + continue + except FailedSendingMessageException: + # this is expected during websocket closed period + continue + + async def _send_buffered_messages( + self, websocket: websockets.WebSocketCommonProtocol + ) -> None: + buffered_messages = list(self._buffer.buffer) + for msg in buffered_messages: + try: + await self._send_transport_message( + msg, + websocket, + prefix_bytes=self._transport_options.get_prefix_bytes(), + ) + except ConnectionClosed as e: + logging.info(f"Connection closed while sending buffered messages : {e}") + break + except FailedSendingMessageException as e: + logging.error(f"Error while sending buffered messages : {e}") + break + + async def _send_transport_message( + self, + msg: TransportMessage, + ws: WebSocketCommonProtocol, + prefix_bytes: bytes = b"", + ) -> None: + try: + await send_transport_message( + msg, ws, self._on_websocket_unexpected_close, prefix_bytes + ) + except ConnectionClosed as e: + raise e + except FailedSendingMessageException as e: + raise e + + async def send_message( + self, + stream_id: str, + payload: Dict | str, + ws: WebSocketCommonProtocol, + control_flags: int = 0, + service_name: str | None = None, + procedure_name: str | None = None, + ) -> None: + """Send serialized messages to the websockets.""" + msg = TransportMessage( + streamId=stream_id, + id=nanoid.generate(), + from_=self._transport_id, + to=self._to_id, + seq=await self._seq_manager.get_seq_and_increment(), + ack=await self._seq_manager.get_ack(), + controlFlags=control_flags, + payload=payload, + serviceName=service_name, + procedureName=procedure_name, + ) + try: + # We need this lock to ensure the buffer order and message sending order + # are the same. 
+ async with self._msg_lock: + try: + await self._buffer.put(msg) + except Exception: + # We should close the session when there are too many messages in + # buffer + await self.close(True) + return + await self._send_transport_message( + msg, + ws, + prefix_bytes=self._transport_options.get_prefix_bytes(), + ) + except ConnectionClosed as e: + logging.error( + f"Connection closed while sending message : {e}, waiting for " + "retry from buffer" + ) + except FailedSendingMessageException as e: + logging.error( + f"Failed sending message : {e}, waiting for retry from buffer" + ) + + async def _send_responses_from_output_stream( + self, + stream_id: str, + output: Channel[Any], + is_streaming_output: bool, + ) -> None: + """Send serialized messages to the websockets.""" + try: + # TODO: This blocking is not ideal, we should close this task when websocket + # is closed, and start another one when websocket reconnects. + ws = None + async for payload in output: + ws = self._ws + while self._ws_state != WsState.OPEN: + await asyncio.sleep( + self._transport_options.close_session_check_interval_ms / 1000 + ) + ws = self._ws + if not is_streaming_output: + await self.send_message(stream_id, payload, ws, STREAM_CLOSED_BIT) + return + await self.send_message(stream_id, payload, ws) + if ws: + logging.debug("sent an end of stream %r", stream_id) + await self.send_message( + stream_id, {"type": "CLOSE"}, ws, STREAM_CLOSED_BIT + ) + except FailedSendingMessageException as e: + logging.error(f"Error while sending responses, {type(e)} : {e}") + except (RuntimeError, ChannelClosed) as e: + logging.error(f"Error while sending responses, {type(e)} : {e}") + except Exception as e: + logging.error(f"Unknown error while river sending responses back : {e}") + + async def close_websocket( + self, ws: WebSocketCommonProtocol, should_retry: bool + ) -> None: + """Mark the websocket as closed, close the websocket, and retry if needed.""" + async with self._ws_lock: + if self._ws.id != 
ws.id: + # already replaced with new ws + return + if self._ws_state != WsState.OPEN: + # Already closed + return + logging.info( + f"River session from {self._transport_id} to {self._to_id} " + f"closing websocket {ws.id}" + ) + self._ws_state = WsState.CLOSING + if ws: + # TODO: should we wait here? + task = asyncio.create_task(ws.close()) + task.add_done_callback( + lambda _: logging.debug("old websocket %s closed.", ws.id) + ) + self._ws_state = WsState.CLOSED + if should_retry and self._retry_connection_callback: + await self._retry_connection_callback(self) + + async def _open_stream_and_call_handler( + self, + msg: TransportMessage, + stream: Optional[Channel], + tg: Optional[asyncio.TaskGroup], + ) -> Channel: + if not self._is_server: + raise InvalidMessageException("Client should not receive stream open bit") + if not msg.serviceName or not msg.procedureName: + raise IgnoreMessageException( + f"Service name or procedure name is missing in the message {msg}" + ) + key = (msg.serviceName, msg.procedureName) + handler = self._handlers.get(key, None) + if not handler: + raise IgnoreMessageException( + f"No handler for {key} handlers : " f"{self._handlers.keys()}" + ) + method_type, handler_func = handler + is_streaming_output = method_type in ( + "subscription-stream", # subscription + "stream", + ) + is_streaming_input = method_type in ( + "upload-stream", # subscription + "stream", + ) + # New channel pair. + input_stream: Channel[Any] = Channel( + MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1 + ) + output_stream: Channel[Any] = Channel( + MAX_MESSAGE_BUFFER_SIZE if is_streaming_output else 1 + ) + try: + await input_stream.put(msg.payload) + except (RuntimeError, ChannelClosed) as e: + raise InvalidMessageException(e) + if not stream: + async with self._stream_lock: + self._streams[msg.streamId] = input_stream + # Start the handler. 
+ await self._task_manager.create_task( + handler_func(msg.from_, input_stream, output_stream), tg + ) + await self._task_manager.create_task( + self._send_responses_from_output_stream( + msg.streamId, output_stream, is_streaming_output + ), + tg, + ) + return input_stream + + async def _add_msg_to_stream( + self, + msg: TransportMessage, + stream: Channel, + ) -> None: + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + return + try: + await stream.put(msg.payload) + except (RuntimeError, ChannelClosed) as e: + raise InvalidMessageException(e) + + async def _remove_acked_messages_in_buffer(self) -> None: + await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) + + async def start_serve_responses(self) -> None: + await self._task_manager.create_task(self.serve()) + + async def close(self, is_unexpected_close: bool) -> None: + """Close the session and all associated streams.""" + logging.info( + f"{self._transport_id} closing session " + f"to {self._to_id}, ws: {self._ws.id}, current_state : {self._ws_state}" + ) + async with self._state_lock: + if self._state != SessionState.ACTIVE: + # already closing + return + self._state = SessionState.CLOSING + self._reset_session_close_countdown() + await self.close_websocket(self._ws, should_retry=False) + # Clear the session in transports + await self._close_session_callback(self) + await self._task_manager.cancel_all_tasks() + # TODO: unexpected_close should close stream differently here to + # throw exception correctly. 
+ for stream in self._streams.values(): + stream.close() + async with self._stream_lock: + self._streams.clear() + self._state = SessionState.CLOSED diff --git a/replit_river/task_manager.py b/replit_river/task_manager.py new file mode 100644 index 0000000..71b0863 --- /dev/null +++ b/replit_river/task_manager.py @@ -0,0 +1,104 @@ +import asyncio +import logging +from typing import Any, Optional, Set + +from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException + + +class BackgroundTaskManager: + """Manages background tasks and logs exceptions.""" + + def __init__(self) -> None: + self.background_tasks: Set[asyncio.Task] = set() + + async def cancel_all_tasks(self) -> None: + """Asynchronously cancels all tasks managed by this instance.""" + # Convert it to a list to avoid RuntimeError: Set changed size during iteration + for task in list(self.background_tasks): + await self.cancel_task(task, self.background_tasks) + + @staticmethod + async def cancel_task( + task_to_remove: asyncio.Task[Any], + background_tasks: Set[asyncio.Task], + ) -> None: + """Cancels a given task and ensures it is removed from the set of managed tasks. + + Args: + task_to_remove: The asyncio.Task instance to cancel. + background_tasks: Set of all tasks being tracked. + """ + task_to_remove.cancel() + try: + await task_to_remove + except asyncio.CancelledError: + # If we cancel the task manager we will get called here as well, + # if we want to handle the cancellation differently we can do it here. 
+ logging.debug("Task was cancelled %r", task_to_remove) + except RiverException as e: + if e.code == ERROR_CODE_STREAM_CLOSED: + # Task is cancelled + pass + logging.error("Exception on cancelling task: %r", e, exc_info=True) + except Exception as e: + logging.error("Exception on cancelling task: %r", e, exc_info=True) + finally: + # Remove the task from the set regardless of the outcome + background_tasks.discard(task_to_remove) + + def _task_done_callback( + self, + task_to_remove: asyncio.Task[Any], + background_tasks: Set[asyncio.Task], + ) -> None: + """Callback to be executed when a task is done. It removes the task from the set + and logs any exceptions. + + Args: + task_to_remove: The asyncio.Task that has completed. + background_tasks: Set of all tasks being tracked. + """ + if task_to_remove in background_tasks: + background_tasks.remove(task_to_remove) + try: + exception = task_to_remove.exception() + except asyncio.CancelledError: + return + except Exception: + logging.error("Error retrieving task exception", exc_info=True) + return + if exception: + if ( + isinstance(exception, RiverException) + and exception.code == ERROR_CODE_STREAM_CLOSED + ): + # Task is cancelled + pass + else: + logging.error( + "Exception on cancelling task: %r", exception, exc_info=True + ) + + async def create_task( + self, fn: Any, tg: Optional[asyncio.TaskGroup] = None + ) -> asyncio.Task: + """Creates a task from a callable and adds it to the background tasks set. + + Args: + fn: A callable to be executed in the task. + tg: Optional asyncio.TaskGroup for managing the task lifecycle. + TODO: tg is hard to understand when passed all the way here, we should + refactor to make this easier to understand. + + Returns: + The created asyncio.Task. 
+ """ + if tg: + task = tg.create_task(fn) + else: + task = asyncio.create_task(fn) + self.background_tasks.add(task) + task.add_done_callback( + lambda x: self._task_done_callback(x, self.background_tasks) + ) + return task diff --git a/replit_river/transport.py b/replit_river/transport.py index 1baf078..2d0b1d7 100644 --- a/replit_river/transport.py +++ b/replit_river/transport.py @@ -1,413 +1,150 @@ import asyncio import logging -from typing import Any, Dict, Optional, Set, Tuple +from typing import Dict, Optional, Tuple -import msgpack # type: ignore import nanoid # type: ignore -import websockets -from aiochannel import Channel -from pydantic import ValidationError -from pydantic_core import ValidationError as PydanticCoreValidationError -from websockets.exceptions import ConnectionClosedError -from websockets.server import WebSocketServerProtocol +from websockets import WebSocketCommonProtocol -from replit_river.seq_manager import ( - IgnoreTransportMessageException, - InvalidTransportMessageException, - SeqManager, -) - -from .rpc import ( - ACK_BIT, - STREAM_CLOSED_BIT, - STREAM_OPEN_BIT, - ControlMessageHandshakeRequest, - ControlMessageHandshakeResponse, +from replit_river.messages import FailedSendingMessageException +from replit_river.rpc import ( GenericRpcHandler, - HandShakeStatus, - TransportMessage, ) - -PROTOCOL_VERSION = "v1" -HEART_BEAT_INTERVAL_SECS = 2 - - -class FailedSendingMessageException(Exception): - pass - - -class TransportManager: - def __init__(self) -> None: - self._transports_by_id: Dict[str, "Transport"] = {} - self._lock = asyncio.Lock() - - async def add_transport(self, transport_id: str, transport: "Transport") -> None: - transport_to_close = None - async with self._lock: - if transport_id in self._transports_by_id: - if ( - self._transports_by_id[transport_id]._client_instance_id - != transport._client_instance_id - ): - transport_to_close = self._transports_by_id[transport_id] - self._transports_by_id[transport_id] = 
transport - if transport_to_close: - await transport_to_close.close() - - async def remove_transport(self, transport_id: str) -> None: - transport_to_stop = None - async with self._lock: - if transport_id in self._transports_by_id: - transport_to_stop = self._transports_by_id.pop(transport_id) - - if transport_to_stop: - logging.debug("Stopping transport websocket") - await transport_to_stop.close() +from replit_river.session import Session +from replit_river.transport_options import TransportOptions -class Transport(object): - """A transport object that handles the websocket connection with a client.""" - +class Transport: def __init__( self, - server_instance_id: str, - handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]], - websocket: WebSocketServerProtocol, - transports_manager: TransportManager, + transport_id: str, + transport_options: TransportOptions, + is_server: bool, + handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {}, ) -> None: - self._server_instance_id = server_instance_id - self._client_instance_id: Optional[str] = None + self._transport_id = transport_id + self._transport_options = transport_options + self._is_server = is_server + self._sessions: Dict[str, Session] = {} self._handlers = handlers - self.websocket = websocket - self.streams: Dict[str, Channel[Any]] = {} - self.background_tasks: Set[asyncio.Task] = set() - self.is_handshake_success = False - self._transports_manager = transports_manager - self._seq_manager = SeqManager() + self._session_lock = asyncio.Lock() - async def send_message( - self, - initial_message: TransportMessage, - ws: WebSocketServerProtocol, - control_flags: int, - payload: Dict, - is_hand_shake: bool = False, - ) -> None: - """Send serialized messages to the websockets.""" - msg = TransportMessage( - streamId=initial_message.streamId, - id=nanoid.generate(), - from_=initial_message.to, - to=initial_message.from_, - seq=0 if is_hand_shake else await 
self._seq_manager.get_seq_and_increment(), - ack=await self._seq_manager.get_ack(), - controlFlags=control_flags, - payload=payload, - serviceName=initial_message.serviceName, - procedureName=initial_message.procedureName, - ) - logging.debug("sent a message %r", msg) - try: - await ws.send( - msgpack.packb( - msg.model_dump(by_alias=True, exclude_none=True), datetime=True - ) - ) - except websockets.exceptions.ConnectionClosedOK: - logging.warning( - "Trying to send message while connection closed " - f"for between server : {self._server_instance_id} and " - f"client : {self._client_instance_id}" - ) - raise FailedSendingMessageException() + def generate_session_id(self) -> str: + return self.generate_nanoid() - async def send_responses( - self, - initial_message: TransportMessage, - ws: WebSocketServerProtocol, - output: Channel[Any], - is_stream: bool, - ) -> None: - """Send serialized messages to the websockets.""" - logging.debug("sent response of stream %r", initial_message.streamId) - async for payload in output: - if not is_stream: - await self.send_message(initial_message, ws, STREAM_CLOSED_BIT, payload) - return - await self.send_message(initial_message, ws, 0, payload) - logging.debug("sent an end of stream %r", initial_message.streamId) - await self.send_message( - initial_message, ws, STREAM_CLOSED_BIT, {"type": "CLOSE"} - ) - - async def _process_handshake_request_message( - self, transport_message: TransportMessage, websocket: WebSocketServerProtocol - ) -> ControlMessageHandshakeRequest: - """Returns the instance id instance id.""" - try: - handshake_request = ControlMessageHandshakeRequest( - **transport_message.payload - ) - except (ValidationError, ValueError): - response_message = ControlMessageHandshakeResponse( - status=HandShakeStatus( - ok=False, reason="failed validate handshake request" - ) - ) - await self.send_message( - transport_message, - websocket, - 0, - response_message.model_dump(by_alias=True, exclude_none=True), - 
is_hand_shake=True, - ) - logging.exception("failed to parse handshake request") - raise InvalidTransportMessageException("failed validate handshake request") - - if handshake_request.protocolVersion != PROTOCOL_VERSION: - response_message = ControlMessageHandshakeResponse( - status=HandShakeStatus(ok=False, reason="protocol version mismatch") - ) - await self.send_message( - transport_message, - websocket, - 0, - response_message.model_dump(by_alias=True, exclude_none=True), - is_hand_shake=True, - ) - error_str = ( - "protocol version mismatch: " - + f"{handshake_request.protocolVersion} != {PROTOCOL_VERSION}" - ) - logging.error(error_str) - raise InvalidTransportMessageException(error_str) + async def close(self) -> None: + await self._close_all_sessions() - response_message = ControlMessageHandshakeResponse( - status=HandShakeStatus(ok=True, instanceId=self._server_instance_id) - ) - await self.send_message( - transport_message, - websocket, - 0, - response_message.model_dump(by_alias=True, exclude_none=True), - is_hand_shake=True, + async def _close_all_sessions(self) -> None: + sessions = self._sessions.values() + logging.info( + f"start closing transport {self._transport_id}, number sessions : " + f"{len(sessions)}" ) - return handshake_request + sessions_to_close = list(sessions) + tasks = [session.close(False) for session in sessions_to_close] + await asyncio.gather(*tasks) + logging.info(f"Transport closed {self._transport_id}") - def _formatted_bytes(self, message: bytes) -> str: - return " ".join(f"{b:02x}" for b in message) + async def _delete_session(self, session: Session) -> None: + async with self._session_lock: + if session._to_id in self._sessions: + del self._sessions[session._to_id] - def _parse_transport_msg(self, message: str | bytes) -> TransportMessage: - if isinstance(message, str): - logging.debug( - "ignored a message beacuse it was a text frame: %r", - message, - ) - raise IgnoreTransportMessageException() - try: - unpacked_message = 
msgpack.unpackb(message, timestamp=3) - except (msgpack.UnpackException, msgpack.exceptions.ExtraData): - logging.exception("received non-msgpack message") - raise InvalidTransportMessageException() - try: - msg = TransportMessage(**unpacked_message) - except ( - ValidationError, - ValueError, - msgpack.UnpackException, - PydanticCoreValidationError, - ): - logging.exception(f"failed to parse message:{message.decode()}") - raise InvalidTransportMessageException() - return msg + async def _set_session(self, session: Session) -> None: + async with self._session_lock: + self._sessions[session._to_id] = session - async def _establish_handshake( - self, msg: TransportMessage, websocket: WebSocketServerProtocol - ) -> None: - try: - handshake_request = await self._process_handshake_request_message( - msg, websocket - ) - self._client_instance_id = handshake_request.instanceId - except InvalidTransportMessageException: - raise - transport_id = msg.from_ - await self._transports_manager.add_transport(transport_id, self) + def generate_nanoid(self) -> str: + return str(nanoid.generate()) - async def _heartbeat( + async def _get_or_create_session_id( self, - msg: TransportMessage, - websocket: WebSocketServerProtocol, - ) -> None: - logging.debug("Start heartbeat") - while True: - await asyncio.sleep(HEART_BEAT_INTERVAL_SECS) - try: - await self.send_message( - msg, + to_id: str, + advertised_session_id: str, + ) -> str: + try: + async with self._session_lock: + if to_id not in self._sessions: + return self.generate_session_id() + else: + old_session = self._sessions[to_id] + + if old_session.advertised_session_id != advertised_session_id: + return self.generate_session_id() + else: + return old_session.session_id + except Exception as e: + logging.error(f"Error getting or creating session id {e}") + raise e + + async def get_or_create_session( + self, + transport_id: str, + to_id: str, + session_id: str, + advertised_session_id: str, + websocket: WebSocketCommonProtocol, + 
) -> Session: + session_to_close: Optional[Session] = None + new_session: Optional[Session] = None + async with self._session_lock: + if to_id not in self._sessions: + logging.debug( + 'Creating new session with "%s" using ws: %s', to_id, websocket.id + ) + new_session = Session( + transport_id, + to_id, + session_id, + advertised_session_id, websocket, - ACK_BIT, - { - "ack": msg.id, - }, + self._transport_options, + self._is_server, + self._handlers, + close_session_callback=self._delete_session, ) - except ConnectionClosedError: - logging.debug("heartbeat failed") - return - - def remove_task( - self, - task_to_remove: asyncio.Task[Any], - background_tasks: Set[asyncio.Task], - ) -> None: - if task_to_remove in background_tasks: - background_tasks.remove(task_to_remove) - try: - exception = task_to_remove.exception() - except asyncio.CancelledError: - logging.debug("Task was cancelled", exc_info=False) - return - except Exception: - logging.error("Error retrieving task exception", exc_info=True) - return - if exception: - logging.error( - "Task resulted in an exception", - exc_info=exception, - ) - - def _create_task(self, fn: Any, tg: asyncio.TaskGroup) -> None: - task = tg.create_task(fn) - self.background_tasks.add(task) - task.add_done_callback(lambda x: self.remove_task(x, self.background_tasks)) - - async def handle_messages_from_ws( - self, websocket: WebSocketServerProtocol, tg: asyncio.TaskGroup - ) -> None: - async for message in websocket: - try: - msg = self._parse_transport_msg(message) - except IgnoreTransportMessageException: - continue - except InvalidTransportMessageException: - logging.error("Got invalid transport message, closing connection") - return - - logging.debug("got a message %r", msg) - - if not self.is_handshake_success: - try: - await self._establish_handshake(msg, websocket) - self.is_handshake_success = True - self._create_task(self._heartbeat(msg, websocket), tg) + else: + old_session = self._sessions[to_id] + if 
old_session.advertised_session_id != advertised_session_id: logging.debug( - "handshake success for client_instance_id :" - f" {self._client_instance_id}" + 'Create new session with "%s" for session id %s' + " and close old session %s", + to_id, + advertised_session_id, + old_session.advertised_session_id, ) - - continue - except InvalidTransportMessageException: - logging.error("Got invalid transport message, closing connection") - return - - try: - await self._seq_manager.check_seq_and_update(msg) - except IgnoreTransportMessageException: - continue - except InvalidTransportMessageException: - return - if msg.controlFlags & ACK_BIT != 0: - # Ignore ack messages. - continue - - stream = self.streams.get(msg.streamId, None) - if msg.controlFlags & STREAM_OPEN_BIT != 0: - if not msg.serviceName or not msg.procedureName: - logging.warning("no service or procedure name in %r", msg) - return - key = (msg.serviceName, msg.procedureName) - handler = self._handlers.get(key, None) - if not handler: - logging.exception( - "No handler for %s handlers : " f"{self._handlers.keys()}", - key, + session_to_close = old_session + new_session = Session( + transport_id, + to_id, + session_id, + advertised_session_id, + websocket, + self._transport_options, + self._is_server, + self._handlers, + close_session_callback=self._delete_session, ) - return - method_type, handler_func = handler - is_streaming_output = method_type in ( - "subscription-stream", # subscription - "stream", - ) - is_streaming_input = method_type in ( - "upload-stream", # subscription - "stream", - ) - # New channel pair. - input_stream: Channel[Any] = Channel(1024 if is_streaming_input else 1) - output_stream: Channel[Any] = Channel( - 1024 if is_streaming_output else 1 - ) - await input_stream.put(msg.payload) - if not stream: - # We'll need to save it for later. - self.streams[msg.streamId] = input_stream - # Start the handler. 
- self._create_task( - handler_func(msg.from_, input_stream, output_stream), tg - ) - self._create_task( - self.send_responses( - msg, websocket, output_stream, is_streaming_output - ), - tg, - ) - - else: - # messages after stream is opened - if not stream: - logging.warning("no stream for %s", msg.streamId) - continue - if not ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - await stream.put(msg.payload) - - if msg.controlFlags & STREAM_CLOSED_BIT != 0: - if stream: - stream.close() - del self.streams[msg.streamId] - - async def serve(self) -> None: - try: - async with asyncio.TaskGroup() as tg: - try: - await self.handle_messages_from_ws(self.websocket, tg) - except ConnectionClosedError as e: - # This is fine. - logging.debug(f"ConnectionClosedError while serving: {e}") - pass - except FailedSendingMessageException as e: - # Expected error if the connection is closed. - logging.debug(f"FailedSendingMessageException while serving: {e}") - pass - except Exception: - logging.exception("caught exception at message iterator") - finally: - await self.close() - except ExceptionGroup as eg: - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosedError)) - if unhandled: - raise ExceptionGroup( - "Unhandled exceptions on River server", unhandled.exceptions - ) - - async def close(self) -> None: - for previous_input in self.streams.values(): - previous_input.close() - self.streams.clear() - for task in self.background_tasks: - task.cancel() - if self.websocket: - await self.websocket.close() + else: + # If the instance id is the same, we reuse the session and assign + # a new websocket to it. 
+ logging.debug( + 'Reuse old session with "%s" using new ws: %s', + to_id, + websocket.id, + ) + try: + await old_session.replace_with_new_websocket(websocket) + new_session = old_session + except FailedSendingMessageException as e: + raise e + if session_to_close: + logging.info( + f"Closing stale session {session_to_close.advertised_session_id}" + ) + await session_to_close.close(False) + logging.info( + f"Closed stale session {session_to_close.advertised_session_id}" + ) + await self._set_session(new_session) + return new_session diff --git a/replit_river/transport_options.py b/replit_river/transport_options.py new file mode 100644 index 0000000..ed020c0 --- /dev/null +++ b/replit_river/transport_options.py @@ -0,0 +1,47 @@ +import os + +from pydantic import BaseModel + +CROSIS_PREFIX_BYTES = b"\x00\x00" +PID2_PREFIX_BYTES = b"\xff\xff" +MAX_MESSAGE_BUFFER_SIZE = 1024 + + +class ConnectionRetryOptions(BaseModel): + base_interval_ms: int = 250 + max_jitter_ms: int = 200 + max_backoff_ms: float = 32_000 + attempt_budget_capacity: float = 5 + budget_restore_interval_ms: float = 200 + max_retry: int = 1_000 + + +# setup in replit web can be found at +# https://github.com/replit/repl-it-web/blob/main/pkg/pid2/src/entrypoints/protocol.ts#L13 +class TransportOptions(BaseModel): + session_disconnect_grace_ms: float = 5_000 + heartbeat_ms: float = 500 + heartbeats_until_dead: int = 2 + use_prefix_bytes: bool = False + close_session_check_interval_ms: float = 100 + connection_retry_options: ConnectionRetryOptions = ConnectionRetryOptions() + buffer_size: int = 1_000 + + def get_prefix_bytes(self) -> bytes: + return PID2_PREFIX_BYTES if self.use_prefix_bytes else b"" + + def websocket_disconnect_grace_ms(self) -> float: + return self.heartbeat_ms * self.heartbeats_until_dead + + @classmethod + def create_from_env(cls) -> "TransportOptions": + session_disconnect_grace_ms = float( + os.getenv("SESSION_DISCONNECT_GRACE_MS", 5_000) + ) + heartbeat_ms = 
float(os.getenv("HEARTBEAT_MS", 2000)) + heartbeats_to_dead = int(os.getenv("HEARTBEATS_UNTIL_DEAD", 2)) + return TransportOptions( + session_disconnect_grace_ms=session_disconnect_grace_ms, + heartbeat_ms=heartbeat_ms, + heartbeats_until_dead=heartbeats_to_dead, + ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c658d6f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,149 @@ +import logging +from typing import Any, AsyncGenerator + +import nanoid # type: ignore +import pytest +from websockets.server import serve + +from replit_river.client import Client +from replit_river.error_schema import RiverError +from replit_river.rpc import ( + GrpcContext, + TransportMessage, + rpc_method_handler, + stream_method_handler, + subscription_method_handler, + upload_method_handler, +) +from replit_river.server import Server +from replit_river.transport_options import TransportOptions + + +def transport_message( + seq: int = 0, + ack: int = 0, + streamId: str = "test_stream", + from_: str = "client", + to: str = "server", + control_flag: int = 0, + payload: Any = {}, +) -> TransportMessage: + return TransportMessage( + id=str(nanoid.generate()), + from_=from_, + to=to, + streamId=streamId, + seq=seq, + ack=ack, + payload=payload, + controlFlags=control_flag, + ) + + +def serialize_request(request: str) -> dict: + return {"data": request} + + +def deserialize_request(request: dict) -> str: + return request["data"] or "" + + +def serialize_response(response: str) -> dict: + return {"data": response} + + +def deserialize_response(response: dict) -> str: + return response["data"] or "" + + +def deserialize_error(response: dict) -> RiverError: + return RiverError.model_validate(response) + + +# RPC method handlers for testing +async def rpc_handler(request: str, context: GrpcContext) -> str: + return f"Hello, {request}!" 
+ + +async def subscription_handler( + request: str, context: GrpcContext +) -> AsyncGenerator[str, None]: + for i in range(5): + yield f"Subscription message {i} for {request}" + + +async def upload_handler( + request: AsyncGenerator[str, None], context: GrpcContext +) -> str: + uploaded_data = [] + async for data in request: + uploaded_data.append(data) + return f"Uploaded: {', '.join(uploaded_data)}" + + +async def stream_handler( + request: AsyncGenerator[str, None], context: GrpcContext +) -> AsyncGenerator[str, None]: + async for data in request: + yield f"Stream response for {data}" + + +@pytest.fixture +def transport_options() -> TransportOptions: + return TransportOptions() + + +@pytest.fixture +def server(transport_options: TransportOptions) -> Server: + server = Server(server_id="test_server", transport_options=transport_options) + server.add_rpc_handlers( + { + ("test_service", "rpc_method"): ( + "rpc", + rpc_method_handler( + rpc_handler, deserialize_request, serialize_response + ), + ), + ("test_service", "subscription_method"): ( + "subscription", + subscription_method_handler( + subscription_handler, deserialize_request, serialize_response + ), + ), + ("test_service", "upload_method"): ( + "upload", + upload_method_handler( + upload_handler, deserialize_request, serialize_response + ), + ), + ("test_service", "stream_method"): ( + "stream", + stream_method_handler( + stream_handler, deserialize_request, serialize_response + ), + ), + } + ) + return server + + +@pytest.fixture +async def client( + server: Server, transport_options: TransportOptions +) -> AsyncGenerator[Client, None]: + try: + async with serve(server.serve, "localhost", 8765): + client = Client( + "ws://localhost:8765", + client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + yield client + finally: + logging.debug("Start closing test client : %s", "test_client") + await client.close() + finally: + logging.debug("Start closing test 
server") + await server.close() diff --git a/tests/test_communication.py b/tests/test_communication.py new file mode 100644 index 0000000..eeeedd2 --- /dev/null +++ b/tests/test_communication.py @@ -0,0 +1,131 @@ +import asyncio +from typing import AsyncGenerator + +import pytest + +from replit_river.client import Client +from replit_river.error_schema import RiverError +from tests.conftest import deserialize_error, deserialize_response, serialize_request + + +@pytest.mark.asyncio +async def test_rpc_method(client: Client) -> None: + response = await client.send_rpc( + "test_service", + "rpc_method", + "Alice", + serialize_request, + deserialize_response, + deserialize_error, + ) # type: ignore + assert response == "Hello, Alice!" + + +@pytest.mark.asyncio +async def test_upload_method(client: Client) -> None: + async def upload_data() -> AsyncGenerator[str, None]: + yield "Data 1" + yield "Data 2" + yield "Data 3" + + response = await client.send_upload( + "test_service", + "upload_method", + "Initial Data", + upload_data(), + serialize_request, + serialize_request, + deserialize_response, + deserialize_response, + ) # type: ignore + assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3" + + +@pytest.mark.asyncio +async def test_subscription_method(client: Client) -> None: + async for response in await client.send_subscription( + "test_service", + "subscription_method", + "Bob", + serialize_request, + deserialize_response, + deserialize_error, + ): + assert "Subscription message" in response + + +@pytest.mark.asyncio +async def test_stream_method(client: Client) -> None: + async def stream_data() -> AsyncGenerator[str, None]: + yield "Stream 1" + yield "Stream 2" + yield "Stream 3" + + responses = [] + async for response in await client.send_stream( + "test_service", + "stream_method", + "Initial Stream Data", + stream_data(), + serialize_request, + serialize_request, + deserialize_response, + deserialize_error, + ): + responses.append(response) + + 
@pytest.mark.asyncio
async def test_multiplexing(client: "Client") -> None:
    """An upload and a stream can run concurrently over one client."""

    async def upload_data() -> AsyncGenerator[str, None]:
        yield "Upload Data 1"
        yield "Upload Data 2"

    async def stream_data() -> AsyncGenerator[str, None]:
        yield "Stream Data 1"
        yield "Stream Data 2"

    # Kick off the upload in the background, then open the stream while
    # the upload is still in flight.
    upload_task = asyncio.create_task(
        client.send_upload(
            "test_service",
            "upload_method",
            "Initial Upload Data",
            upload_data(),
            serialize_request,
            serialize_request,
            deserialize_response,
            deserialize_error,
        )
    )
    stream_iterator = await client.send_stream(
        "test_service",
        "stream_method",
        "Initial Stream Data",
        stream_data(),
        serialize_request,
        serialize_request,
        deserialize_response,
        deserialize_error,
    )

    uploaded: str = await upload_task
    assert (
        uploaded == "Uploaded: Initial Upload Data, Upload Data 1, Upload Data 2"
    )

    streamed: list[str | RiverError] = [reply async for reply in stream_iterator]
    assert streamed == [
        "Stream response for Initial Stream Data",
        "Stream response for Stream Data 1",
        "Stream response for Stream Data 2",
    ]


@pytest.fixture
def options() -> "ConnectionRetryOptions":
    """Small, fast retry/budget settings so limiter tests finish quickly."""
    return ConnectionRetryOptions(
        base_interval_ms=100,
        max_jitter_ms=50,
        max_backoff_ms=1000,
        attempt_budget_capacity=5,
        budget_restore_interval_ms=200,
    )


@pytest.fixture
def rate_limiter(options: "ConnectionRetryOptions") -> "LeakyBucketRateLimit":
    """A fresh LeakyBucketRateLimit built from the options fixture."""
    return LeakyBucketRateLimit(options)
@pytest.mark.asyncio
async def test_initial_budget(rate_limiter: "LeakyBucketRateLimit") -> None:
    """A user that has never consumed anything starts with a full budget."""
    user: str = "user1"
    assert rate_limiter.has_budget(user), "User should initially have full budget"


@pytest.mark.asyncio
async def test_consume_budget(rate_limiter: "LeakyBucketRateLimit") -> None:
    """Each consume_budget call increments the per-user consumed count."""
    user: str = "user2"
    rate_limiter.consume_budget(user)
    assert (
        rate_limiter.get_budget_consumed(user) == 1
    ), "Budget consumed should be incremented"


@pytest.mark.asyncio
async def test_restore_budget(rate_limiter: "LeakyBucketRateLimit") -> None:
    """Consumed budget leaks back to zero once the restore interval elapses."""
    user: str = "user3"
    rate_limiter.consume_budget(user)
    rate_limiter.start_restoring_budget(user)
    await asyncio.sleep(0.3)  # Wait more than budget restore interval
    assert (
        rate_limiter.get_budget_consumed(user) == 0
    ), "Budget should be restored after interval"


@pytest.mark.asyncio
async def test_concurrent_access(rate_limiter: "LeakyBucketRateLimit") -> None:
    """Interleaved consumers on one event loop all get counted."""
    user: str = "user4"

    async def consume_budget() -> None:
        for _ in range(5):
            rate_limiter.consume_budget(user)
            await asyncio.sleep(0.01)  # simulate some delay

    await asyncio.gather(consume_budget(), consume_budget())
    assert (
        rate_limiter.get_budget_consumed(user) == 10
    ), "Concurrent access should be handled correctly"


# FIX: the original declared this as a *sync* `def` decorated with
# @pytest.mark.asyncio; pytest-asyncio refuses to run a non-async function
# carrying that marker, so the test silently never exercised close().
# The body is fully synchronous, so the marker is dropped instead.
def test_close(rate_limiter: "LeakyBucketRateLimit") -> None:
    """close() cancels every in-flight budget-restore task."""
    rate_limiter.close()
    assert not rate_limiter.tasks, "All tasks should be cancelled upon close"


@pytest.mark.asyncio
async def test_initial_sequence_and_ack_numbers() -> None:
    """A fresh SeqManager starts with seq == 0 and ack == 0."""
    manager = SeqManager()
    assert await manager.get_seq() == 0, "Initial sequence number should be 0"
    assert await manager.get_ack() == 0, "Initial acknowledgment number should be 0"
@pytest.mark.asyncio
async def test_sequence_number_increment() -> None:
    """get_seq_and_increment returns the old value and bumps the counter."""
    manager = SeqManager()
    before = await manager.get_seq_and_increment()
    assert before == 0, "Sequence number should start at 0"
    after = await manager.get_seq()
    assert after == 1, "Sequence number should increment to 1"


@pytest.mark.asyncio
async def test_message_reception() -> None:
    """In-order messages advance ack; duplicates and gaps raise."""
    manager = SeqManager()
    msg = transport_message(seq=0, ack=0, from_="client")
    # No error should be raised for the correct sequence
    await manager.check_seq_and_update(msg)
    assert await manager.get_ack() == 1, "Acknowledgment should be set to 1"

    # Replaying the same message must be ignored, not processed twice.
    with pytest.raises(IgnoreMessageException):
        await manager.check_seq_and_update(msg)

    # A message that skips ahead of the expected sequence is invalid.
    msg.seq = 2
    with pytest.raises(InvalidMessageException):
        await manager.check_seq_and_update(msg)


@pytest.mark.asyncio
async def test_acknowledgment_setting() -> None:
    """Processing one in-order message moves the acknowledgment number to 1."""
    manager = SeqManager()
    message = transport_message(seq=0, ack=0, from_="client")
    await manager.check_seq_and_update(message)
    assert await manager.get_ack() == 1, "Acknowledgment number should be updated"


@pytest.mark.asyncio
async def test_concurrent_access_to_sequence() -> None:
    """Ten concurrent increments hand out ten distinct sequence numbers."""
    manager = SeqManager()
    pending = [manager.get_seq_and_increment() for _ in range(10)]
    numbers = await asyncio.gather(*pending)
    assert (
        len(set(numbers)) == 10
    ), "Each increment call should return a unique sequence number"
    assert (
        await manager.get_seq() == 10
    ), "Final sequence number should be 10 after 10 increments"
54c1ea6..0000000 --- a/tests/unit/test_communication.py +++ /dev/null @@ -1,276 +0,0 @@ -import asyncio -from typing import AsyncGenerator - -import pytest -import websockets -from websockets.server import serve - -from replit_river.client import Client -from replit_river.error_schema import RiverError, RiverException -from replit_river.rpc import ( - GrpcContext, - rpc_method_handler, - stream_method_handler, - subscription_method_handler, - upload_method_handler, -) -from replit_river.server import Server - - -# Helper functions for testing -def serialize_request(request: str) -> dict: - return {"data": request} - - -def deserialize_request(request: dict) -> str: - return request["data"] or "" - - -def serialize_response(response: str) -> dict: - return {"data": response} - - -def deserialize_response(response: dict) -> str: - return response["data"] or "" - - -def deserialize_error(response: dict) -> RiverError: - return RiverError.model_validate(response) - - -# RPC method handlers for testing -async def rpc_handler(request: str, context: GrpcContext) -> str: - return f"Hello, {request}!" 
- - -async def subscription_handler( - request: str, context: GrpcContext -) -> AsyncGenerator[str, None]: - for i in range(5): - yield f"Subscription message {i} for {request}" - - -async def upload_handler( - request: AsyncGenerator[str, None], context: GrpcContext -) -> str: - uploaded_data = [] - async for data in request: - uploaded_data.append(data) - return f"Uploaded: {', '.join(uploaded_data)}" - - -async def stream_handler( - request: AsyncGenerator[str, None], context: GrpcContext -) -> AsyncGenerator[str, None]: - async for data in request: - yield f"Stream response for {data}" - - -@pytest.fixture -def server() -> Server: - server = Server(server_id="test_server") - server.add_rpc_handlers( - { - ("test_service", "rpc_method"): ( - "rpc", - rpc_method_handler( - rpc_handler, deserialize_request, serialize_response - ), - ), - ("test_service", "subscription_method"): ( - "subscription", - subscription_method_handler( - subscription_handler, deserialize_request, serialize_response - ), - ), - ("test_service", "upload_method"): ( - "upload", - upload_method_handler( - upload_handler, deserialize_request, serialize_response - ), - ), - ("test_service", "stream_method"): ( - "stream", - stream_method_handler( - stream_handler, deserialize_request, serialize_response - ), - ), - } - ) - return server - - -@pytest.fixture -async def client(server: Server) -> AsyncGenerator[Client, None]: - async with serve(server.serve, "localhost", 8765): - async with websockets.connect("ws://localhost:8765") as websocket: - client = Client(websocket, use_prefix_bytes=False) - try: - yield client - finally: - await websocket.close() - - -@pytest.mark.asyncio -async def test_rpc_method(client: Client) -> None: - response = await client.send_rpc( - "test_service", - "rpc_method", - "Alice", - serialize_request, - deserialize_response, - deserialize_error, - ) # type: ignore - assert response == "Hello, Alice!" 
- - -@pytest.mark.asyncio -async def test_upload_method(client: Client) -> None: - async def upload_data() -> AsyncGenerator[str, None]: - yield "Data 1" - yield "Data 2" - yield "Data 3" - - response = await client.send_upload( - "test_service", - "upload_method", - "Initial Data", - upload_data(), - serialize_request, - serialize_request, - deserialize_response, - deserialize_response, - ) # type: ignore - assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3" - - -@pytest.mark.asyncio -async def test_subscription_method(client: Client) -> None: - async for response in client.send_subscription( - "test_service", - "subscription_method", - "Bob", - serialize_request, - deserialize_response, - deserialize_error, - ): - assert "Subscription message" in response - - -@pytest.mark.asyncio -async def test_stream_method(client: Client) -> None: - async def stream_data() -> AsyncGenerator[str, None]: - yield "Stream 1" - yield "Stream 2" - yield "Stream 3" - - responses = [] - async for response in client.send_stream( - "test_service", - "stream_method", - "Initial Stream Data", - stream_data(), - serialize_request, - serialize_request, - deserialize_response, - deserialize_error, - ): - responses.append(response) - - assert responses == [ - "Stream response for Initial Stream Data", - "Stream response for Stream 1", - "Stream response for Stream 2", - "Stream response for Stream 3", - ] - - -@pytest.mark.asyncio -async def test_multiplexing(client: Client) -> None: - async def upload_data() -> AsyncGenerator[str, None]: - yield "Upload Data 1" - yield "Upload Data 2" - - async def stream_data() -> AsyncGenerator[str, None]: - yield "Stream Data 1" - yield "Stream Data 2" - - upload_task = asyncio.create_task( - client.send_upload( - "test_service", - "upload_method", - "Initial Upload Data", - upload_data(), - serialize_request, - serialize_request, - deserialize_response, - deserialize_error, - ) - ) - stream_task = client.send_stream( - "test_service", - 
"stream_method", - "Initial Stream Data", - stream_data(), - serialize_request, - serialize_request, - deserialize_response, - deserialize_error, - ) - - upload_response: str = await upload_task - assert ( - upload_response == "Uploaded: Initial Upload Data, Upload Data 1, Upload Data 2" - ) - - stream_responses: list[str | RiverError] = [] - async for response in stream_task: - stream_responses.append(response) - - assert stream_responses == [ - "Stream response for Initial Stream Data", - "Stream response for Stream Data 1", - "Stream response for Stream Data 2", - ] - - -@pytest.mark.asyncio -async def test_close_old_websocket_rpc(server: Server) -> None: - async with serve(server.serve, "localhost", 8765): - async with serve(server.serve, "localhost", 8766): - async with websockets.connect("ws://localhost:8765") as websocket1: - async with websockets.connect("ws://localhost:8766") as websocket2: - websockets_list = [websocket1, websocket2] - clients: list[Client] = [] - num_clients = 2 - - async def create_clients() -> None: - for i in range(num_clients): - client = Client( - websockets_list[i], - use_prefix_bytes=False, - client_id=f"client-{i}", - ) - clients.append(client) - # We set it to be the same instance for testing - client._from = "test_user" - - await create_clients() - with pytest.raises(RiverException): - await clients[0].send_rpc( - "test_service", - "rpc_method", - clients[0]._instance_id, - serialize_request, - deserialize_response, - deserialize_error, - ) - response = await clients[1].send_rpc( - "test_service", - "rpc_method", - clients[1]._instance_id, - serialize_request, - deserialize_response, - deserialize_error, - ) - assert response == f"Hello, client-{1}!"