Skip to content

Commit

Permalink
Proactively tell the server about our expected state (#32)
Browse files Browse the repository at this point in the history
Why
===

We have a few cases where the server recreates a connection that was
meant to be a reconnect.

What changed
============

This change brings the Python implementation in line with the TypeScript
one so that the server can proactively tell clients to recreate their
sessions. See replit/river#212 for reference.

Test plan
=========

Logs like
https://app.datadoghq.com/logs?query=%40replid%3Ae2698176-9a19-4058-8619-7f5793b02cde&agg_m=count&agg_m_source=base&agg_t=count&cols=host%2Cservice%2C%40river.sessionId&event=AgAAAZAssKXztskj6QAAAAAAAAAYAAAAAEFaQXNzS2NBQUFBR1hKdG9Bd1NLLVFBbwAAACQAAAAAMDE5MDJjYjItYzE2Ny00YmZiLThjMjctNzdkNjI5Y2NkY2Vj&fromUser=true&messageDisplay=inline&refresh_mode=paused&storage=hot&stream_sort=desc&viz=stream&from_ts=1718734624299&to_ts=1718737623927&live=false
should completely go away
  • Loading branch information
lhchavez authored Jun 18, 2024
1 parent 0f7245a commit 5aa11a6
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name="replit-river"
version="0.2.6"
version="0.2.7"
description="Replit river toolkit for Python"
authors = ["Replit <eng@replit.com>"]
license = "LICENSE"
Expand Down
13 changes: 13 additions & 0 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from replit_river.rpc import (
ControlMessageHandshakeRequest,
ControlMessageHandshakeResponse,
ExpectedSessionState,
TransportMessage,
)
from replit_river.seq_manager import (
Expand Down Expand Up @@ -114,6 +115,7 @@ async def _establish_new_connection(
session_id,
self._handshake_metadata,
ws,
old_session,
)
rate_limit.start_restoring_budget(client_id)
return ws, handshake_request, handshake_response
Expand Down Expand Up @@ -183,12 +185,14 @@ async def _send_handshake_request(
session_id: str,
handshake_metadata: Optional[Any],
websocket: WebSocketCommonProtocol,
expected_session_state: ExpectedSessionState,
) -> ControlMessageHandshakeRequest:
handshake_request = ControlMessageHandshakeRequest(
type="HANDSHAKE_REQ",
protocolVersion=PROTOCOL_VERSION,
sessionId=session_id,
metadata=handshake_metadata,
expectedSessionState=expected_session_state,
)
stream_id = self.generate_nanoid()

Expand Down Expand Up @@ -244,6 +248,7 @@ async def _establish_handshake(
session_id: str,
handshake_metadata: Optional[Any],
websocket: WebSocketCommonProtocol,
old_session: Optional[ClientSession],
) -> Tuple[ControlMessageHandshakeRequest, ControlMessageHandshakeResponse]:
try:
handshake_request = await self._send_handshake_request(
Expand All @@ -252,6 +257,14 @@ async def _establish_handshake(
session_id=session_id,
handshake_metadata=handshake_metadata,
websocket=websocket,
expected_session_state=ExpectedSessionState(
reconnect=old_session is not None,
nextExpectedSeq=(
await old_session.get_next_expected_seq()
if old_session is not None
else 0
),
),
)
except FailedSendingMessageException:
raise RiverException(
Expand Down
7 changes: 6 additions & 1 deletion replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,16 @@
# Equivalent of https://github.com/replit/river/blob/c1345f1ff6a17a841d4319fad5c153b5bda43827/transport/message.ts#L23-L33


class ExpectedSessionState(BaseModel):
    """Session state a client advertises in its handshake request.

    Carried in ``ControlMessageHandshakeRequest.expectedSessionState``.
    """

    # True when the client is reconnecting with an already existing session.
    reconnect: bool
    # Next sequence number the client expects; 0 when there is no old session.
    nextExpectedSeq: int


class ControlMessageHandshakeRequest(BaseModel):
    """Payload of the ``HANDSHAKE_REQ`` control message."""

    type: Literal["HANDSHAKE_REQ"] = "HANDSHAKE_REQ"
    protocolVersion: str
    sessionId: str
    # Defaults to None so that requests which omit the field are still valid;
    # the server handles None as a client that has not been upgraded yet.
    expectedSessionState: Optional[ExpectedSessionState] = None
    metadata: Optional[Any] = None


Expand Down Expand Up @@ -178,7 +184,6 @@ def rpc_method_handler(
request_deserializer: Callable[[Any], RequestType],
response_serializer: Callable[[ResponseType], Any],
) -> GenericRpcHandler:

async def wrapped(
peer: str,
input: Channel[Any],
Expand Down
31 changes: 23 additions & 8 deletions replit_river/server_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ async def handshake_to_get_session(
try:
msg = parse_transport_msg(message, self._transport_options)
(
_,
handshake_request,
handshake_response,
) = await self._establish_handshake(msg, websocket)
Expand All @@ -58,7 +57,7 @@ async def handshake_to_get_session(
raise InvalidMessageException("No session id in handshake response")
advertised_session_id = handshake_request.sessionId
try:
session = await self.get_or_create_session(
return await self.get_or_create_session(
transport_id,
to_id,
session_id,
Expand All @@ -72,7 +71,6 @@ async def handshake_to_get_session(
f" error: {e}"
)
raise InvalidMessageException(error_msg)
return session
raise WebsocketClosedException("No handshake message received")

async def _send_handshake_response(
Expand Down Expand Up @@ -113,7 +111,6 @@ async def websocket_closed_callback() -> None:
async def _establish_handshake(
self, request_message: TransportMessage, websocket: WebSocketCommonProtocol
) -> Tuple[
WebSocketCommonProtocol,
ControlMessageHandshakeRequest,
ControlMessageHandshakeResponse,
]:
Expand Down Expand Up @@ -148,12 +145,30 @@ async def _establish_handshake(
websocket,
)
raise InvalidMessageException("handshake request to wrong server")
my_session_id = await self._get_or_create_session_id(
request_message.from_, handshake_request.sessionId
)
if handshake_request.expectedSessionState is None:
# TODO: remove once we have upgraded all clients
my_session_id = await self._get_or_create_session_id(
request_message.from_, handshake_request.sessionId
)
elif handshake_request.expectedSessionState.reconnect:
maybe_my_session_id = await self._get_existing_session_id(
request_message.from_,
handshake_request.sessionId,
handshake_request.expectedSessionState.nextExpectedSeq,
)
if maybe_my_session_id is None:
handshake_response = await self._send_handshake_response(
request_message,
HandShakeStatus(ok=False, reason="session state mismatch"),
websocket,
)
raise InvalidMessageException("session state mismatch")
my_session_id = maybe_my_session_id
else:
my_session_id = self.generate_session_id()
handshake_response = await self._send_handshake_response(
request_message,
HandShakeStatus(ok=True, sessionId=my_session_id),
websocket,
)
return websocket, handshake_request, handshake_response
return handshake_request, handshake_response
8 changes: 8 additions & 0 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,14 @@ async def _send_transport_message(
except FailedSendingMessageException as e:
raise e

async def get_next_expected_seq(self) -> int:
    """Return the next sequence number expected from the server.

    The value is read from the seq manager's ack counter.
    """
    seq_manager = self._seq_manager
    next_seq = await seq_manager.get_ack()
    return next_seq

async def get_next_expected_ack(self) -> int:
    """Return the next ack that the client expects.

    The value is read from the seq manager's seq counter.
    """
    seq_manager = self._seq_manager
    next_ack = await seq_manager.get_seq()
    return next_ack

async def send_message(
self,
stream_id: str,
Expand Down
20 changes: 20 additions & 0 deletions replit_river/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@ def _set_session(self, session: Session) -> None:
def generate_nanoid(self) -> str:
    """Return a newly generated nanoid, converted to ``str``."""
    fresh_id = nanoid.generate()
    return str(fresh_id)

async def _get_existing_session_id(
self,
to_id: str,
advertised_session_id: str,
next_expected_seq: int,
) -> Optional[str]:
try:
async with self._session_lock:
old_session = self._sessions.get(to_id, None)
if (
old_session is None
or await old_session.get_next_expected_ack() < next_expected_seq
or old_session.advertised_session_id != advertised_session_id
):
return None
return old_session.session_id
except Exception as e:
logging.error(f"Error getting existing session id {e}")
raise e

async def _get_or_create_session_id(
self,
to_id: str,
Expand Down

0 comments on commit 5aa11a6

Please sign in to comment.