Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenthebuilder committed May 3, 2024
1 parent c609ca0 commit 75813f0
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 51 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name="replit-river"
version="0.1.16.dev3"
version="0.1.16.dev4"
description="Replit river toolkit for Python"
authors = ["Replit <eng@replit.com>"]
license = "LICENSE"
Expand Down
1 change: 0 additions & 1 deletion replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ async def send_rpc(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
logging.error(f"#send_rpc : {request}")
session = await self._transport._get_or_create_session()
return await session.send_rpc(
service_name,
Expand Down
39 changes: 5 additions & 34 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
self._rate_limiter = LeakyBucketRateLimit(
transport_options.connection_retry_options
)
# We want to make sure there's only one session creation at a time
self._create_session_lock = asyncio.Lock()

async def close(self) -> None:
Expand Down Expand Up @@ -94,25 +95,16 @@ async def _establish_new_connection(
for i in range(max_retry):
if i > 0:
logging.info(f"Retrying build handshake number {i} times")
logging.info(
f"old_session: {old_session}, old_session.is_session_open(): {await old_session.is_session_open() if old_session else None}"
)
if not rate_limit.has_budget(client_id):
logging.debug("No retry budget for %s.", client_id)
break
try:
logging.error(
f"##### _establish_new_connection: old session : {old_session}"
)
ws = await websockets.connect(self._websocket_uri)
session_id = (
self.generate_session_id()
if not old_session
else old_session.session_id
)
logging.error(
f"##### _establish_new_connection: existing session : {old_session}"
)
rate_limit.consume_budget(client_id)
handshake_request, handshake_response = await self._establish_handshake(
self._transport_id, self._server_id, session_id, ws
Expand Down Expand Up @@ -158,49 +150,28 @@ async def _create_new_session(
return new_session

async def _get_or_create_session(self) -> ClientSession:
logging.error(f"####### start get or create session")
async with self._create_session_lock:
existing_session = await self._get_existing_session()
if not existing_session:
logging.error(f"##### _get_or_create_session No existing session")
return await self._create_new_session()
is_session_open = await existing_session.is_session_open()
if not is_session_open:
logging.error(
f"##### _get_or_create_session session open, creating new session"
)
await existing_session.close(
is_unexpected_close=False, acquire_transport_lock=True
)
return await self._create_new_session()
is_ws_open = await existing_session.is_websocket_open()
if is_ws_open:
logging.error(f"##### _get_or_create_session Reuse existing session")
return existing_session
else:
try:
new_ws, _, hs_response = await self._establish_new_connection(
existing_session
)
except RiverException as e:
logging.error(
f"##### _get_or_create_session failed to establish new connection : {e}"
)
return existing_session
new_ws, _, hs_response = await self._establish_new_connection(
existing_session
)
if (
hs_response.status.sessionId
== existing_session.advertised_session_id
):
logging.error(
f"##### _get_or_create_session session open, replacing websocket"
)
await existing_session.replace_with_new_websocket(new_ws)
return existing_session
else:
logging.error(f"##### session open, not same session id, reuse")
await existing_session.close(
is_unexpected_close=False, acquire_transport_lock=True
)
await existing_session.close(is_unexpected_close=False)
return await self._create_new_session()

async def _send_handshake_request(
Expand Down
15 changes: 0 additions & 15 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ async def _begin_close_session_countdown(self) -> None:

async def serve(self) -> None:
"""Serve messages from the websocket."""
logging.error("####### serve started")
try:
async with asyncio.TaskGroup() as tg:
try:
Expand All @@ -143,16 +142,11 @@ async def serve(self) -> None:
raise ExceptionGroup(
"Unhandled exceptions on River server", unhandled.exceptions
)
logging.error("####### serve finished")

async def _update_book_keeping(self, msg: TransportMessage) -> None:
logging.error("####### _update_book_keeping started")
await self._seq_manager.check_seq_and_update(msg)
await self._remove_acked_messages_in_buffer()
self._reset_session_close_countdown()
logging.error(
f"####### _update_book_keeping end: {self._heartbeat_misses} {self._close_session_after_time_secs}"
)

async def _handle_messages_from_ws(
self, tg: Optional[asyncio.TaskGroup] = None
Expand Down Expand Up @@ -213,17 +207,14 @@ async def _handle_messages_from_ws(
async def replace_with_new_websocket(
self, new_ws: websockets.WebSocketCommonProtocol
) -> None:
logging.debug(f"#### replace_with_new_websocket 1 : {new_ws.id}")
async with self._ws_lock:
logging.debug(f"#### replace_with_new_websocket 2 : {new_ws.id}")
old_wrapper = self._ws_wrapper
old_ws_id = old_wrapper.ws.id
if new_ws.id != old_ws_id:
self._reset_session_close_countdown()
await old_wrapper.close()
self._ws_wrapper = WebsocketWrapper(new_ws)
await self._send_buffered_messages(new_ws)
logging.debug(f"#### replace_with_new_websocket 3 : {new_ws.id}")
# Server will call serve itself.
if not self._is_server:
await self.start_serve_responses()
Expand All @@ -232,7 +223,6 @@ async def _get_current_time(self) -> float:
return asyncio.get_event_loop().time()

def _reset_session_close_countdown(self) -> None:
logging.debug("#### reset_session_close_countdown")
self._heartbeat_misses = 0
self._close_session_after_time_secs = None

Expand All @@ -241,13 +231,9 @@ async def _check_to_close_session(self) -> None:
await asyncio.sleep(
self._transport_options.close_session_check_interval_ms / 1000
)
logging.error("#### _check_to_close_session")
if not self._close_session_after_time_secs:
continue
current_time = await self._get_current_time()
logging.error(
f"#### _check_to_close_session : current_time: {current_time} self._close_session_after_time_secs: {self._close_session_after_time_secs}, {current_time > self._close_session_after_time_secs}"
)
if current_time > self._close_session_after_time_secs:
logging.debug(
"Grace period ended for %s, closing session", self._transport_id
Expand Down Expand Up @@ -425,7 +411,6 @@ async def close_websocket(
return
await ws_wrapper.close()
if should_retry and self._retry_connection_callback:
logging.error("### running retry_connection_callback")
await self._task_manager.create_task(self._retry_connection_callback(self))

async def _open_stream_and_call_handler(
Expand Down

0 comments on commit 75813f0

Please sign in to comment.