Skip to content

Commit

Permalink
A few smol fixes
Browse files Browse the repository at this point in the history
The Python implementation of River probably has a few bugs, since we see
a large number of invariant violation errors.

This change:

* Stops a few resource leaks: the heartbeat tasks could get into a state
  where they never terminated, so they would consume a task forever.
* Stops a couple of TOCTOU races where an `await` point sat between the
  check and the assignment of a test-and-replace operation (which would
  otherwise be "atomic", since asyncio tasks on a single event loop can only
  be interleaved at `await` points).
* Adds a comment clarifying the transitions between the session states.
* Deindents a small block for clarity.
  • Loading branch information
lhchavez committed Jun 17, 2024
1 parent f6a73d8 commit 3d7f3a6
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name="replit-river"
version="0.2.5"
version="0.2.6"
description="Replit river toolkit for Python"
authors = ["Replit <eng@replit.com>"]
license = "LICENSE"
Expand Down
20 changes: 8 additions & 12 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,19 +166,15 @@ async def _get_or_create_session(self) -> ClientSession:
is_ws_open = await existing_session.is_websocket_open()
if is_ws_open:
return existing_session
new_ws, _, hs_response = await self._establish_new_connection(
existing_session
)
if hs_response.status.sessionId == existing_session.advertised_session_id:
await existing_session.replace_with_new_websocket(new_ws)
return existing_session
else:
new_ws, _, hs_response = await self._establish_new_connection(
existing_session
)
if (
hs_response.status.sessionId
== existing_session.advertised_session_id
):
await existing_session.replace_with_new_websocket(new_ws)
return existing_session
else:
await existing_session.close(is_unexpected_close=False)
return await self._create_new_session()
await existing_session.close(is_unexpected_close=False)
return await self._create_new_session()

async def _send_handshake_request(
self,
Expand Down
29 changes: 22 additions & 7 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@


class SessionState(enum.Enum):
"""The state a session can be in.
Can only transition from ACTIVE to CLOSING to CLOSED.
"""

ACTIVE = 0
CLOSING = 1
CLOSED = 2
Expand Down Expand Up @@ -108,6 +113,13 @@ async def _begin_close_session_countdown(self) -> None:
"""Begin the countdown to close session, this should be called when
websocket is closed.
"""
# calculate the value now before establishing it so that there are no
# await points between the check and the assignment to avoid a TOCTOU
# race.
grace_period_ms = self._transport_options.session_disconnect_grace_ms
close_session_after_time_secs = (
await self._get_current_time() + grace_period_ms / 1000
)
if self._close_session_after_time_secs is not None:
# already in grace period, no need to set again
return
Expand All @@ -116,10 +128,7 @@ async def _begin_close_session_countdown(self) -> None:
self._transport_id,
self._to_id,
)
grace_period_ms = self._transport_options.session_disconnect_grace_ms
self._close_session_after_time_secs = (
await self._get_current_time() + grace_period_ms / 1000
)
self._close_session_after_time_secs = close_session_after_time_secs
await self.close_websocket(self._ws_wrapper, should_retry=not self._is_server)

async def serve(self) -> None:
Expand Down Expand Up @@ -228,9 +237,15 @@ async def _check_to_close_session(self) -> None:
await asyncio.sleep(
self._transport_options.close_session_check_interval_ms / 1000
)
if self._state != SessionState.ACTIVE:
# already closing
return
# calculate the value now before comparing it so that there are no
# await points between the check and the comparison to avoid a TOCTOU
# race.
current_time = await self._get_current_time()
if not self._close_session_after_time_secs:
continue
current_time = await self._get_current_time()
if current_time > self._close_session_after_time_secs:
logging.debug(
"Grace period ended for %s, closing session", self._transport_id
Expand All @@ -251,8 +266,8 @@ async def _heartbeat(
{self._state},
{self._close_session_after_time_secs},
)
# session is closing, no need to send heartbeat
continue
# session is closing / closed, no need to send heartbeat anymore
return
try:
await self.send_message(
str(nanoid.generate()),
Expand Down

0 comments on commit 3d7f3a6

Please sign in to comment.