Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenthebuilder committed Apr 24, 2024
1 parent 1e37254 commit a886d11
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 49 deletions.
27 changes: 10 additions & 17 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException
from replit_river.session import Session
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE

from .rpc import (
STREAM_CLOSED_BIT,
Expand Down Expand Up @@ -48,12 +49,12 @@ async def send_rpc(
try:
try:
response = await output.get()
except (RuntimeError, ChannelClosed) as e:
# if the stream is closed before we get a response, we will get a
# RuntimeError: RuntimeError: Event loop is closed
except ChannelClosed as e:
raise RiverException(
ERROR_CODE_STREAM_CLOSED, f"Stream closed before response {e}"
)
except RuntimeError as e:
raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e))
if not response.get("ok", False):
try:
error = error_deserializer(response["payload"])
Expand All @@ -64,8 +65,6 @@ async def send_rpc(
except RiverException as e:
raise e
except Exception as e:
# Log the error and return an appropriate error response
logging.exception("Error during RPC communication")
raise e

async def send_upload(
Expand Down Expand Up @@ -100,7 +99,7 @@ async def send_upload(
)
first_message = False
# If this request is not closed and the session is killed, we should
# throws exception here
# throw exception here
async for item in request:
control_flags = 0
if first_message:
Expand All @@ -123,12 +122,12 @@ async def send_upload(
try:
try:
response = await output.get()
except (RuntimeError, ChannelClosed):
# if the stream is closed before we get a response, we will get a
# RuntimeError: RuntimeError: Event loop is closed
except ChannelClosed:
raise RiverException(
ERROR_CODE_STREAM_CLOSED, "Stream closed before response"
)
except RuntimeError as e:
raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e))
if not response.get("ok", False):
try:
error = error_deserializer(response["payload"])
Expand All @@ -140,8 +139,6 @@ async def send_upload(
except RiverException as e:
raise e
except Exception as e:
# Log the error and return an appropriate error response
logging.exception("Error during upload communication")
raise e

async def send_subscription(
Expand All @@ -158,7 +155,7 @@ async def send_subscription(
Expects the input and output be messages that will be msgpacked.
"""
stream_id = nanoid.generate()
output: Channel[Any] = Channel(1024)
output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE)
self._streams[stream_id] = output
await self.send_message(
ws=self._ws,
Expand Down Expand Up @@ -188,8 +185,6 @@ async def send_subscription(
ERROR_CODE_STREAM_CLOSED, "Stream closed before response"
)
except Exception as e:
# Log the error and yield an appropriate error response
logging.exception(f"Error during subscription communication : {item}")
raise e

async def send_stream(
Expand All @@ -209,7 +204,7 @@ async def send_stream(
"""

stream_id = nanoid.generate()
output: Channel[Any] = Channel(1024)
output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE)
self._streams[stream_id] = output
try:
if init and init_serializer:
Expand Down Expand Up @@ -273,8 +268,6 @@ async def _encode_stream() -> None:
ERROR_CODE_STREAM_CLOSED, "Stream closed before response"
)
except Exception as e:
# Log the error and yield an appropriate error response
logging.exception("Error during stream communication")
raise e

async def send_close_stream(
Expand Down
2 changes: 1 addition & 1 deletion replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ async def _send_handshake_request(
)
stream_id = self.generate_nanoid()

def websocket_closed_callback():
def websocket_closed_callback() -> None:
raise RiverException(ERROR_SESSION, "Session closed while sending")

try:
Expand Down
6 changes: 3 additions & 3 deletions replit_river/codegen/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def message_type(
"required": [],
}
# Non-oneof fields.
oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = (
collections.defaultdict(list)
)
oneofs: DefaultDict[
int, List[descriptor_pb2.FieldDescriptorProto]
] = collections.defaultdict(list)
for field in m.field:
if field.HasField("oneof_index"):
oneofs[field.oneof_index].append(field)
Expand Down
12 changes: 6 additions & 6 deletions replit_river/codegen/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def message_decoder(
" return m",
]
# Non-oneof fields.
oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = (
collections.defaultdict(list)
)
oneofs: DefaultDict[
int, List[descriptor_pb2.FieldDescriptorProto]
] = collections.defaultdict(list)
for field in m.field:
if field.HasField("oneof_index"):
oneofs[field.oneof_index].append(field)
Expand Down Expand Up @@ -157,9 +157,9 @@ def message_encoder(
" d: Dict[str, Any] = {}",
]
# Non-oneof fields.
oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = (
collections.defaultdict(list)
)
oneofs: DefaultDict[
int, List[descriptor_pb2.FieldDescriptorProto]
] = collections.defaultdict(list)
for field in m.field:
if field.HasField("oneof_index"):
oneofs[field.oneof_index].append(field)
Expand Down
3 changes: 2 additions & 1 deletion replit_river/message_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from typing import Optional

from replit_river.rpc import TransportMessage
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE


class MessageBuffer:
"""A buffer to store messages and support current updates"""

def __init__(self, max_num_messages: int = 1000):
def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE):
self.max_size = max_num_messages
self.buffer: list[TransportMessage] = []
self._lock = asyncio.Lock()
Expand Down
2 changes: 1 addition & 1 deletion replit_river/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, options: ConnectionRetryOptions):
self.budget_consumed: Dict[str, int] = {}
self.tasks: Dict[str, asyncio.Task] = {}

def get_backoff_ms(self, user: str) -> int:
def get_backoff_ms(self, user: str) -> float:
"""Calculate the backoff time in milliseconds for a user.
Args:
Expand Down
6 changes: 3 additions & 3 deletions replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
RiverException,
)
from replit_river.task_manager import BackgroundTaskManager
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE

InitType = TypeVar("InitType")
RequestType = TypeVar("RequestType")
Expand Down Expand Up @@ -280,7 +281,7 @@ async def wrapped(
task_manager = BackgroundTaskManager()
try:
context = GrpcContext(peer)
request: Channel[RequestType] = Channel(1024)
request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE)

async def _convert_inputs() -> None:
try:
Expand Down Expand Up @@ -341,7 +342,6 @@ def stream_method_handler(
request_deserializer: Callable[[Any], RequestType],
response_serializer: Callable[[ResponseType], Any],
) -> GenericRpcHandler:

async def wrapped(
peer: str,
input: Channel[Any],
Expand All @@ -350,7 +350,7 @@ async def wrapped(
task_manager = BackgroundTaskManager()
try:
context = GrpcContext(peer)
request: Channel[RequestType] = Channel(1024)
request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE)

async def _convert_inputs() -> None:
try:
Expand Down
9 changes: 5 additions & 4 deletions replit_river/server_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@


class ServerTransport(Transport):

async def handshake_to_get_session(
self,
websocket: WebSocketServerProtocol,
) -> Session:
async for message in websocket:
try:
msg = parse_transport_msg(message, self._transport_options)
_, handshake_request, handshake_response = (
await self._establish_handshake(msg, websocket)
)
(
_,
handshake_request,
handshake_response,
) = await self._establish_handshake(msg, websocket)
except IgnoreMessageException:
continue
except InvalidMessageException:
Expand Down
15 changes: 11 additions & 4 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
SeqManager,
)
from replit_river.task_manager import BackgroundTaskManager
from replit_river.transport_options import TransportOptions
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions

from .rpc import (
ACK_BIT,
Expand Down Expand Up @@ -267,6 +267,8 @@ async def _heartbeat(
try:
await self.send_message(
str(nanoid.generate()),
# TODO: make this a message class
# https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
{
"ack": 0,
},
Expand Down Expand Up @@ -380,9 +382,10 @@ async def _send_responses_from_output_stream(
) -> None:
"""Send serialized messages to the websockets."""
try:
# TODO: This blocking is not ideal, we should close this task when websocket
# is closed, and start another one when websocket reconnects.
ws = None
async for payload in output:
# TODO: what if the websocket changed during this?
ws = self._ws
while self._ws_state != WsState.OPEN:
await asyncio.sleep(
Expand Down Expand Up @@ -459,8 +462,12 @@ async def _open_stream_and_call_handler(
"stream",
)
# New channel pair.
input_stream: Channel[Any] = Channel(1024 if is_streaming_input else 1)
output_stream: Channel[Any] = Channel(1024 if is_streaming_output else 1)
input_stream: Channel[Any] = Channel(
MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1
)
output_stream: Channel[Any] = Channel(
MAX_MESSAGE_BUFFER_SIZE if is_streaming_output else 1
)
try:
await input_stream.put(msg.payload)
except (RuntimeError, ChannelClosed) as e:
Expand Down
6 changes: 3 additions & 3 deletions replit_river/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ async def _close_all_sessions(self) -> None:
f"{len(sessions)}"
)
sessions_to_close = list(sessions)
for session in sessions_to_close:
await session.close(False)
tasks = [session.close(False) for session in sessions_to_close]
await asyncio.gather(*tasks)
logging.info(f"Transport closed {self._transport_id}")

async def _delete_session(self, session: Session) -> None:
Expand All @@ -61,7 +61,7 @@ async def _get_or_create_session_id(
self,
to_id: str,
advertised_session_id: str,
):
) -> str:
try:
async with self._session_lock:
if to_id not in self._sessions:
Expand Down
1 change: 1 addition & 0 deletions replit_river/transport_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

CROSIS_PREFIX_BYTES = b"\x00\x00"
PID2_PREFIX_BYTES = b"\xff\xff"
MAX_MESSAGE_BUFFER_SIZE = 1024


class ConnectionRetryOptions(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import Any, AsyncGenerator

import nanoid
import nanoid # type: ignore
import pytest
from websockets.server import serve

Expand Down
10 changes: 5 additions & 5 deletions tests/test_seq_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@


@pytest.mark.asyncio
async def test_initial_sequence_and_ack_numbers():
async def test_initial_sequence_and_ack_numbers() -> None:
manager = SeqManager()
assert await manager.get_seq() == 0, "Initial sequence number should be 0"
assert await manager.get_ack() == 0, "Initial acknowledgment number should be 0"


@pytest.mark.asyncio
async def test_sequence_number_increment():
async def test_sequence_number_increment() -> None:
manager = SeqManager()
initial_seq = await manager.get_seq_and_increment()
assert initial_seq == 0, "Sequence number should start at 0"
Expand All @@ -27,7 +27,7 @@ async def test_sequence_number_increment():


@pytest.mark.asyncio
async def test_message_reception():
async def test_message_reception() -> None:
manager = SeqManager()
msg = transport_message(seq=0, ack=0, from_="client")
await manager.check_seq_and_update(
Expand All @@ -46,15 +46,15 @@ async def test_message_reception():


@pytest.mark.asyncio
async def test_acknowledgment_setting():
async def test_acknowledgment_setting() -> None:
manager = SeqManager()
msg = transport_message(seq=0, ack=0, from_="client")
await manager.check_seq_and_update(msg)
assert await manager.get_ack() == 1, "Acknowledgment number should be updated"


@pytest.mark.asyncio
async def test_concurrent_access_to_sequence():
async def test_concurrent_access_to_sequence() -> None:
manager = SeqManager()
tasks = [manager.get_seq_and_increment() for _ in range(10)]
results = await asyncio.gather(*tasks)
Expand Down

0 comments on commit a886d11

Please sign in to comment.