Skip to content

Commit

Permalink
pass all tests except for error notify and MismatchedClientInstanceDo…
Browse files Browse the repository at this point in the history
…esntGetResentStaleMessagesFromServer
  • Loading branch information
zhenthebuilder committed Apr 21, 2024
1 parent b96e2d6 commit 0f9ab48
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 83 deletions.
7 changes: 3 additions & 4 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from collections.abc import AsyncIterable, AsyncIterator
import logging
from typing import Any, Callable, Optional, Union
Expand Down Expand Up @@ -38,7 +39,8 @@ async def close(self) -> None:
logging.info(f"river client {self._client_id} closed")

async def _get_or_create_session(self) -> ClientSession:
return await self._transport._get_or_create_session_with_retry()
ret = await self._transport._get_or_create_session_with_retry()
return ret

async def send_rpc(
self,
Expand All @@ -50,9 +52,6 @@ async def send_rpc(
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
session = await self._get_or_create_session()
logging.debug(
f"## send_rpc : {session._state}, {session._ws_state}, {session._ws.id}"
)
return await session.send_rpc(
service_name,
procedure_name,
Expand Down
58 changes: 21 additions & 37 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@


class ClientSession(Session):

async def send_rpc(
self,
service_name: str,
Expand All @@ -36,20 +35,14 @@ async def send_rpc(
stream_id = nanoid.generate()
output: Channel[Any] = Channel(1)
self._streams[stream_id] = output
try:
await self.send_message(
ws=self._ws,
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT,
payload=request_serializer(request),
service_name=service_name,
procedure_name=procedure_name,
)
except FailedSendingMessageException:
raise RiverException(
ERROR_CODE_STREAM_CLOSED, "Stream closed before response"
)

await self.send_message(
ws=self._ws,
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT,
payload=request_serializer(request),
service_name=service_name,
procedure_name=procedure_name,
)
# Handle potential errors during communication
try:
try:
Expand All @@ -58,7 +51,7 @@ async def send_rpc(
# if the stream is closed before we get a response, we will get a
# RuntimeError: RuntimeError: Event loop is closed
raise RiverException(
ERROR_CODE_STREAM_CLOSED, "Stream closed before response"
ERROR_CODE_STREAM_CLOSED, f"Stream closed before response {e}"
)
if not response.get("ok", False):
try:
Expand Down Expand Up @@ -119,10 +112,8 @@ async def send_upload(
control_flags=control_flags,
payload=request_serializer(item),
)
except FailedSendingMessageException:
raise RiverException(
ERROR_CODE_STREAM_CLOSED, "Stream closed before response"
)
except Exception as e:
raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e))
await self.send_close_stream(service_name, procedure_name, stream_id)

# Handle potential errors during communication
Expand Down Expand Up @@ -167,19 +158,14 @@ async def send_subscription(
stream_id = nanoid.generate()
output: Channel[Any] = Channel(1024)
self._streams[stream_id] = output
try:
await self.send_message(
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT,
payload=request_serializer(request),
)
except FailedSendingMessageException:
raise RiverException(
ERROR_CODE_STREAM_CLOSED, "Stream closed before response"
)
await self.send_message(
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT,
payload=request_serializer(request),
)

# Handle potential errors during communication
try:
Expand Down Expand Up @@ -246,10 +232,8 @@ async def send_stream(
payload=request_serializer(first),
)

except FailedSendingMessageException:
raise RiverException(
ERROR_CODE_STREAM_CLOSED, "Stream closed before response"
)
except Exception as e:
raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e))

# Create the encoder task
async def _encode_stream() -> None:
Expand Down
33 changes: 20 additions & 13 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
send_transport_message,
)
from replit_river.rpc import (
ACK_BIT,
ControlMessageHandshakeRequest,
ControlMessageHandshakeResponse,
TransportMessage,
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(
self._rate_limiter = LeakyBucketRateLimit(
transport_options.connection_retry_options
)
self._ws: Optional[WebSocketCommonProtocol] = None

async def _get_existing_session(self) -> Optional[ClientSession]:
if not self._sessions:
Expand Down Expand Up @@ -98,9 +100,11 @@ async def _get_or_create_session(self) -> ClientSession:
existing_session = await self._get_existing_session()
if not existing_session:
logging.debug("Client no existing session, creating new one")
self._rate_limiter.consume_budget(self._client_id)
return await self._create_session()
if not existing_session.is_websocket_open():
logging.debug("Client session exists but websocket closed, reconnect one")
self._rate_limiter.consume_budget(self._client_id)
ws = await websockets.connect(self._websocket_uri)
logging.debug(f"new ws : {ws.id} {ws.state}")
self._ws = ws
Expand All @@ -115,13 +119,13 @@ async def _get_or_create_session_with_retry(self) -> ClientSession:
rate_limit = self._rate_limiter
user_id = self._client_id
for i in range(self._transport_options.connection_retry_options.max_retry):
logging.info(f"Client retry build sessions {i} times")
if i > 0:
logging.info(f"Client retry build sessions {i} times")
if rate_limit.has_budget(user_id):
rate_limit.consume_budget(user_id)
backoff_time = rate_limit.get_backoff_ms(user_id)
try:
return await self._get_or_create_session()
except RiverException as e:
except Exception as e:
logging.error(
f"Error creating session: {e}, start backoff {backoff_time} ms"
)
Expand All @@ -139,13 +143,14 @@ async def _get_or_create_session_with_retry(self) -> ClientSession:
"Failed to create session after retrying max number of times",
)

async def _on_websocket_closed(self, session: Optional[Session]) -> None:
if session and session.is_session_open():
# TODO: do the retry correctly here
logging.error("Client session websocket closed, retrying")
async def _on_websocket_closed(
self, session: Optional[Session], should_retry: bool
) -> None:
if not should_retry:
logging.error("Client websocket closed, not retrying")
return
self._ws = await websockets.connect(self._websocket_uri)
await session.replace_with_new_websocket(self._ws)
if session and session.is_session_open():
await self._get_or_create_session_with_retry()

async def _send_handshake_request(
self,
Expand Down Expand Up @@ -176,13 +181,16 @@ async def _send_handshake_request(
),
ws=websocket,
prefix_bytes=self._transport_options.get_prefix_bytes(),
websocket_closed_callback=lambda: self._on_websocket_closed(None),
websocket_closed_callback=lambda: self._on_websocket_closed(
None, False
),
)
except ConnectionClosed:
raise RiverException(ERROR_HANDSHAKE, "Hand shake failed")

async def close_session_callback(self, session: Session) -> None:
logging.info(f"Client session {session._instance_id} closed")
await self._delete_session(session)

async def _send_handshake(
self,
Expand Down Expand Up @@ -213,10 +221,8 @@ async def _send_handshake(
logging.debug(
"Connection closed during waiting for handshake response : {e}"
)
await self._on_websocket_closed(None)
await self._on_websocket_closed(None, False)
raise RiverException(ERROR_HANDSHAKE, "Hand shake failed")

logging.error(f"Got something")
try:
first_message = parse_transport_msg(data, self._transport_options)
except IgnoreTransportMessageException as e:
Expand Down Expand Up @@ -255,6 +261,7 @@ async def create_client_session(
instance_id=instance_id,
websocket=websocket,
)
logging.error("##### handshake success")
return ClientSession(
transport_id=transport_id,
to_id=to_id,
Expand Down
1 change: 1 addition & 0 deletions replit_river/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import random
from threading import Timer
from replit_river.transport_options import ConnectionRetryOptions
Expand Down
12 changes: 9 additions & 3 deletions replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@
)

import grpc
from aiochannel import Channel
from aiochannel import Channel, ChannelClosed
from pydantic import BaseModel, ConfigDict, Field

from replit_river.error_schema import RiverError
from replit_river.error_schema import (
ERROR_CODE_STREAM_CLOSED,
RiverError,
RiverException,
)

InitType = TypeVar("InitType")
RequestType = TypeVar("RequestType")
Expand Down Expand Up @@ -289,8 +293,10 @@ async def _convert_outputs() -> None:
await output.put(
get_response_or_error_payload(response, response_serializer)
)
except ChannelClosed:
raise RiverException(ERROR_CODE_STREAM_CLOSED, "Channel closed")
except Exception as e:
logging.exception("Uncaught exception in river server upload")
logging.error("Uncaught exception in river server upload")
await output.put(
{
"ok": False,
Expand Down
50 changes: 25 additions & 25 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
is_server: bool,
handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]],
close_websocket_callback: Optional[
Callable[["Session"], Coroutine[Any, Any, None]]
Callable[["Session", bool], Coroutine[Any, Any, None]]
] = None,
) -> None:
self._transport_id = transport_id
Expand Down Expand Up @@ -176,16 +176,14 @@ async def replace_with_new_websocket(
async with self._ws_lock:
old_ws = self._ws
self._ws_state = WsState.CLOSING
logging.info("replacing with new websocket")
if new_ws.id != old_ws.id:
self.reset_session_close_countdown()
await self.close_websocket(old_ws)
logging.debug("Old websocket closed")
if new_ws.id != old_ws.id:
self.reset_session_close_countdown()
await self.close_websocket(old_ws, should_retry=False)
async with self._ws_lock:
self._ws = new_ws
self._ws_state = WsState.OPEN
await self.start_serve_messages()
await self._send_buffered_messages(new_ws)
logging.debug("Websocket replace success")
await self.start_serve_messages()
await self._send_buffered_messages(new_ws)

async def _get_current_time(self) -> float:
return asyncio.get_event_loop().time()
Expand All @@ -206,7 +204,6 @@ async def begin_close_session_countdown(self) -> None:
def reset_session_close_countdown(self) -> None:
self.heartbeat_misses = 0
self._close_session_after_time_secs = None
logging.info(f"Countdown reset for session to {self._transport_id}")

async def _check_to_close_session(self) -> None:
while True:
Expand All @@ -216,9 +213,6 @@ async def _check_to_close_session(self) -> None:
if not self._close_session_after_time_secs:
continue
current_time = await self._get_current_time()
# logging.error(
# f":### checking to close session {current_time - self._close_session_after_time_secs:.2f}"
# )
if current_time > self._close_session_after_time_secs:
logging.info(
"Grace period ended for :" f" {self._transport_id}, closing session"
Expand Down Expand Up @@ -257,7 +251,9 @@ async def _heartbeat(
):
logging.debug("closing websocket because of heartbeat misses")
await self.begin_close_session_countdown()
await self.close_websocket(self._ws)
await self.close_websocket(
self._ws, should_retry=not self._is_server
)
continue
except FailedSendingMessageException:
# this is expected during websocket closed period
Expand Down Expand Up @@ -326,9 +322,13 @@ async def send_message(
prefix_bytes=self._transport_options.get_prefix_bytes(),
)
except ConnectionClosed as e:
raise FailedSendingMessageException(e)
logging.error(
f"Connection closed while sending message : {e}, waiting for retry from buffer"
)
except FailedSendingMessageException as e:
raise e
logging.error(
f"Failed sending message : {e}, waiting for retry from buffer"
)
finally:
try:
await self._buffer.put(msg)
Expand Down Expand Up @@ -369,14 +369,14 @@ async def send_responses(
except Exception as e:
logging.error(f"Unknown error while river sending responses back : {e}")

async def close_websocket(self, ws: WebSocketCommonProtocol) -> None:
async def close_websocket(
self, ws: WebSocketCommonProtocol, should_retry: bool
) -> None:
logging.debug(f"{self._transport_id} is closing websocket {ws.id} {ws.state}")
async with self._ws_lock:
if self._ws_state != WsState.OPEN:
return
self._ws_state = WsState.CLOSING
if self._close_websocket_callback:
await self._close_websocket_callback(self)
logging.error(f"closing websocket {ws.id} state: {ws.state}")
if ws:
logging.info(
Expand All @@ -390,6 +390,8 @@ async def close_websocket(self, ws: WebSocketCommonProtocol) -> None:
lambda _: logging.debug(f"old websocket closed, {ws.id}")
)
self._ws_state = WsState.CLOSED
if self._close_websocket_callback:
await self._close_websocket_callback(self, should_retry)

async def _open_stream_and_call_handler(
self,
Expand Down Expand Up @@ -472,15 +474,13 @@ async def close(self, is_unexpected_close: bool) -> None:
return
self._state = SessionState.CLOSING
self.reset_session_close_countdown()
await self.close_websocket(self._ws)
await self.close_websocket(self._ws, should_retry=not self._is_server)
# Clear the session in transports
await self._close_session_callback(self)
await self._task_manager.cancel_all_tasks()
for previous_input in self._streams.values():
previous_input.close()
logging.error("###" * 20 + "closing streams")
for stream in self._streams.values():
stream.close()
async with self._stream_lock:
self._streams.clear()
self._state = SessionState.CLOSED
logging.info(
f"################ {self._transport_id} closed session to {self._to_id}"
)
2 changes: 1 addition & 1 deletion replit_river/transport_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ConnectionRetryOptions(BaseModel):

class TransportOptions(BaseModel):
session_disconnect_grace_ms: float = 5_000
heartbeat_ms: float = 2000
heartbeat_ms: float = 500
heartbeats_until_dead: int = 2
use_prefix_bytes: bool = False
close_session_check_interval_ms: float = 100
Expand Down

0 comments on commit 0f9ab48

Please sign in to comment.