Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenthebuilder committed Apr 21, 2024
1 parent 9b7f991 commit a3e29d8
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 63 deletions.
2 changes: 1 addition & 1 deletion replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def _get_or_create_session(self) -> ClientSession:
existing_session = await self._get_existing_session()
if not existing_session:
return await self._create_session()
if not await existing_session.is_websocket_open():
if not existing_session.is_websocket_open():
logging.debug("Client session exists but websocket closed, reconnect one")
self._ws = await websockets.connect(self._websocket_uri)
await existing_session.replace_with_new_websocket(self._ws)
Expand Down
2 changes: 1 addition & 1 deletion replit_river/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def send_transport_message(
websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]],
prefix_bytes: bytes = b"",
) -> None:
logging.debug("sent a message %r", msg)
logging.debug(f"sending a message {msg} to ws {ws.id} with state {ws.state}")
try:
await ws.send(
prefix_bytes
Expand Down
1 change: 0 additions & 1 deletion replit_river/seq_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(
self._ack_lock = asyncio.Lock()
self.ack = 0
self.receiver_ack = 0
self.next_send_seq = 0

async def get_seq_and_increment(self) -> int:
"""Get the current sequence number and increment it.
Expand Down
5 changes: 4 additions & 1 deletion replit_river/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def add_rpc_handlers(
self._transport._handlers.update(rpc_handlers)

async def serve(self, websocket: WebSocketServerProtocol) -> None:
logging.debug("River server started establishing session")
logging.debug(
f"River server started establishing session with ws: {websocket.id}"
f" {websocket.state}"
)
try:
session = await self._transport.handshake_to_get_session(websocket)
except Exception as e:
Expand Down
9 changes: 8 additions & 1 deletion replit_river/server_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ async def get_or_create_session(
session_to_close: Optional[Session] = None
async with self._session_lock:
if to_id not in self._sessions:
logging.debug(
f'Creating new session with "{to_id}" using ws: {websocket.id}'
)
self._sessions[to_id] = Session(
transport_id,
to_id,
Expand Down Expand Up @@ -68,13 +71,16 @@ async def get_or_create_session(
else:
# If the instance id is the same, we reuse the session and assign
# a new websocket to it.
logging.debug(
f'Reuse old session with "{to_id}" using new ws: {websocket.id}'
)
try:
await old_session.replace_with_new_websocket(websocket)
except FailedSendingMessageException as e:
raise e
if session_to_close:
logging.info("Closing stale websocket")
await session_to_close.close()
await session_to_close.close(False)
session = self._sessions[to_id]
return session

Expand Down Expand Up @@ -155,6 +161,7 @@ async def _establish_handshake(
handshake_request = ControlMessageHandshakeRequest(
**request_message.payload
)
logging.debug('Got handshake request "%r"', handshake_request)
except (ValidationError, ValueError):
await self._send_handshake_response(
request_message,
Expand Down
155 changes: 101 additions & 54 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@

class SessionState(enum.Enum):
ACTIVE = 0
CLOSING = 1
CLOSED = 2
CLOSING = 3
CLOSED = 4


class WsState(enum.Enum):
    """States tracked for the websocket attached to a session.

    A session starts with ``OPEN``; ``is_websocket_open()`` reports whether
    the tracked state is still ``OPEN``.
    """

    OPEN = 0
    # NOTE(review): the value 1 is skipped in the original numbering; the
    # gap is preserved here unchanged.
    CLOSING = 2
    CLOSED = 3


class Session(object):
Expand All @@ -63,6 +69,7 @@ def __init__(
self._handlers = handlers

self._state = SessionState.ACTIVE
self._ws_state = WsState.OPEN

self._ws_lock = asyncio.Lock()
self._ws = websocket
Expand All @@ -75,18 +82,21 @@ def __init__(
self._seq_manager = SeqManager()
self._stream_lock = asyncio.Lock()
self._task_manager = BackgroundTaskManager()
self._buffer = MessageBuffer()
self._buffer = MessageBuffer(self._transport_options.buffer_size)
self.heartbeat_misses = 0
# should disconnect after this time
self._close_session_after_time_secs: Optional[float] = None
asyncio.create_task(self._task_manager.create_task(self._heartbeat()))
asyncio.create_task(self._setup_heartbeats_task())

async def _setup_heartbeats_task(self) -> None:
    """Register the session's two periodic coroutines with the task manager.

    The heartbeat coroutine is registered first, followed by the coroutine
    that checks whether the session should be closed.
    """
    manager = self._task_manager
    await manager.create_task(self._heartbeat())
    await manager.create_task(self._check_to_close_session())

def is_session_open(self) -> bool:
    """Return True while the session state is ``SessionState.ACTIVE``."""
    session_is_active = self._state == SessionState.ACTIVE
    return session_is_active

async def is_websocket_open(self) -> bool:
async with self._ws_lock:
return self._ws.open
def is_websocket_open(self) -> bool:
return self._ws_state == WsState.OPEN

async def serve(self) -> None:
"""Serve messages from the websocket."""
Expand Down Expand Up @@ -114,7 +124,7 @@ async def _handle_messages_from_ws(
) -> None:
logging.debug(
f'{"server" if self._is_server else "client"} start handling messages from'
" ws"
f" ws {websocket.id}, state {websocket.state}"
)
try:
async for message in websocket:
Expand Down Expand Up @@ -155,7 +165,7 @@ async def _handle_messages_from_ws(
logging.error(
f"Got invalid transport message, closing session : {e}"
)
await self.close()
await self.close(True)
return
except ConnectionClosed as e:
raise e
Expand All @@ -164,17 +174,21 @@ async def replace_with_new_websocket(
self, websocket: websockets.WebSocketCommonProtocol
) -> None:
logging.info("replacing with new websocket")
await self.close_websocket(self._ws)
self.reset_session_close_countdown()
if websocket.id != self._ws.id:
await self.close_websocket(self._ws)
logging.debug("Old websocket closed")
await self._send_buffered_messages(websocket)
async with self._ws_lock:
self._ws = websocket
self._ws_state = WsState.OPEN
logging.debug("Websocket replace success")

async def _get_current_time(self) -> float:
    """Return the running event loop's clock reading, in seconds.

    This is the loop's own clock (``loop.time()``), so values are only
    meaningful relative to other readings taken from the same loop.
    """
    # Inside a coroutine a loop is always running, so ask for it directly:
    # get_running_loop() never falls back to creating or guessing a loop
    # the way get_event_loop() can.
    return asyncio.get_running_loop().time()

async def begin_close_session_countdown(self) -> None:
if self._close_session_after_time_secs:
logging.debug("begin_close_session_countdown")
if self._close_session_after_time_secs is not None:
return
logging.debug(
f"websocket closed from {self._transport_id} to {self._to_id}, "
Expand All @@ -188,41 +202,53 @@ async def begin_close_session_countdown(self) -> None:
def reset_session_close_countdown(self) -> None:
self.heartbeat_misses = 0
self._close_session_after_time_secs = None
logging.info(f"Grace period cancelled for session to {self._transport_id}")
logging.info(f"Countdown reset for session to {self._transport_id}")

async def _check_to_close_session(self) -> None:
    """Poll the close deadline and close the session once it has passed.

    Sleeps ``close_session_check_interval_ms`` milliseconds between checks.
    While ``_close_session_after_time_secs`` is ``None`` no deadline is armed
    and the loop keeps polling. Once the current loop time is later than the
    deadline, ``close(False)`` is awaited and the coroutine returns.
    """
    while True:
        await asyncio.sleep(
            self._transport_options.close_session_check_interval_ms / 1000
        )
        # Read the deadline once per iteration so the None check and the
        # comparison below always see the same value.
        deadline = self._close_session_after_time_secs
        # Compare against None explicitly, matching the check used when the
        # countdown is armed: a deadline of 0.0 is still a deadline and must
        # not be mistaken for "no countdown in progress".
        if deadline is None:
            continue
        current_time = await self._get_current_time()
        if current_time > deadline:
            logging.info(
                f"Grace period ended for : {self._transport_id}, closing session"
            )
            await self.close(False)
            return

async def _heartbeat(
self,
) -> None:
logging.debug("Start heartbeat")
while True:
await asyncio.sleep(self._transport_options.heartbeat_ms / 1000)
current_time = await self._get_current_time()
if self._close_session_after_time_secs:
if current_time > self._close_session_after_time_secs:
logging.info(
"Grace period ended for :"
f" {self._transport_id}, closing session"
)
await self.close()
return
continue
if self.heartbeat_misses >= self._transport_options.heartbeats_until_dead:
await self.close_websocket(self._ws)
await self.begin_close_session_countdown()
return
if self._state != SessionState.ACTIVE:
if (
self._state != SessionState.ACTIVE
or self._close_session_after_time_secs
):
# session is closing, no need to send heartbeat
continue
try:
await self.send_message(
str(nanoid.generate()),
self._ws,
{
"ack": 0,
},
self._ws,
ACK_BIT,
)
self.heartbeat_misses += 1
if (
self.heartbeat_misses
>= self._transport_options.heartbeats_until_dead
):
logging.debug("closing websocket because of heartbeat misses")
await self.begin_close_session_countdown()
await self.close_websocket(self._ws)
return
except FailedSendingMessageException:
# this is expected during websocket closed period
continue
Expand Down Expand Up @@ -264,8 +290,8 @@ async def _send_transport_message(
async def send_message(
self,
stream_id: str,
payload: Dict | str,
ws: WebSocketCommonProtocol,
payload: Dict,
control_flags: int = 0,
service_name: str | None = None,
procedure_name: str | None = None,
Expand Down Expand Up @@ -298,43 +324,60 @@ async def send_message(
await self._buffer.put(msg)
except Exception:
# We should close the session when there are too many messages in buffer
await self.close()
await self.close(True)
return

async def send_responses(
self,
stream_id: str,
ws: WebSocketCommonProtocol,
output: Channel[Any],
is_streaming_output: bool,
) -> None:
"""Send serialized messages to the websockets."""
logging.debug("sent response of stream %r", stream_id)
try:
ws = self._ws
async for payload in output:
while self._ws_state != WsState.OPEN:
await asyncio.sleep(
self._transport_options.close_session_check_interval_ms / 1000
)
ws = self._ws
if not is_streaming_output:
await self.send_message(stream_id, ws, payload, STREAM_CLOSED_BIT)
await self.send_message(stream_id, payload, ws, STREAM_CLOSED_BIT)
return
await self.send_message(stream_id, ws, payload)
await self.send_message(stream_id, payload, ws)
logging.debug("sent an end of stream %r", stream_id)
await self.send_message(stream_id, ws, {"type": "CLOSE"}, STREAM_CLOSED_BIT)
await self.send_message(stream_id, {"type": "CLOSE"}, ws, STREAM_CLOSED_BIT)
except FailedSendingMessageException as e:
logging.error(f"Error while sending responses back : {e}")
logging.error(
f"Error while sending responses, ws_state: {ws.state}, {type(e)} : {e}"
)
except (RuntimeError, ChannelClosed) as e:
logging.error(f"Error while sending responses back : {e}")
logging.error(
f"Error while sending responses, ws_state: {ws.state} {type(e)} : {e}"
)
except Exception as e:
logging.error(f"Unknown error while river sending responses back : {e}")

async def close_websocket(self, websocket: WebSocketCommonProtocol) -> None:
logging.info(
f"River session from {self._transport_id} to {self._to_id} "
"closing websocket"
)
async with self._ws_lock:
self._ws_state = WsState.CLOSING
self.reset_session_close_countdown()
if self._close_websocket_callback:
await self._close_websocket_callback(self)
logging.error(f"closing websocket {websocket.id} state: {websocket.state}")
if websocket:
await websocket.close()
logging.info(
f"River session from {self._transport_id} to {self._to_id} "
"closing websocket"
)
# TODO: if we wait this to be closed this takes too long
# this could hang?
task = asyncio.create_task(websocket.close())
task.add_done_callback(
lambda _: logging.debug(f"old websocket closed, {websocket.id}")
)
self._ws_state = WsState.CLOSED

async def _open_stream_and_call_handler(
self,
Expand Down Expand Up @@ -380,9 +423,7 @@ async def _open_stream_and_call_handler(
handler_func(msg.from_, input_stream, output_stream), tg
)
await self._task_manager.create_task(
self.send_responses(
msg.streamId, self._ws, output_stream, is_streaming_output
),
self.send_responses(msg.streamId, output_stream, is_streaming_output),
tg,
)
return input_stream
Expand All @@ -409,22 +450,28 @@ async def _update_msg_buffer(self) -> None:
async def start_serve_messages(self) -> None:
    """Register this session's ``serve()`` coroutine with its task manager."""
    register = self._task_manager.create_task
    await register(self.serve())

async def close(self) -> None:
async def close(self, is_unexpected_close: bool) -> None:
"""Close the session and all associated streams."""
logging.info(
f"{self._transport_id} closing session "
f"to {self._to_id} current_state : {self._state}"
f"to {self._to_id} current_state : {self._ws_state}"
)
if self._state == SessionState.CLOSING or self._state == SessionState.CLOSED:
if not self.is_session_open():
return
self._state = SessionState.CLOSING
if is_unexpected_close:
await self.send_message(
str(nanoid.generate()), "UNEXPECTED_DISCONNECT", self._ws
)
await self.close_websocket(self._ws)
# Clear the session in transports
await self._close_session_callback(self)
await self._task_manager.cancel_all_tasks()
for previous_input in self._streams.values():
previous_input.close()
async with self._stream_lock:
self._streams.clear()
await self._task_manager.cancel_all_tasks()
await self.close_websocket(self._ws)
# Clear the session in transports
await self._close_session_callback(self)
self._state = SessionState.CLOSED
logging.info(f"{self._transport_id} closed session " f"to {self._to_id}")
logging.info(
f"################ {self._transport_id} closed session to {self._to_id}"
)
6 changes: 3 additions & 3 deletions replit_river/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ async def close_all_sessions(self) -> None:
f"{len(sessions)}"
)
for session in sessions:
await session.close()
await session.close(False)
logging.info(f"Transport closed {self._transport_id}")

async def _delete_session(self, session: Session) -> None:
async with self._session_lock:
if session._transport_id in self._sessions:
del self._sessions[session._transport_id]
if session._to_id in self._sessions:
del self._sessions[session._to_id]

def generate_nanoid(self) -> str:
    """Return a newly generated nanoid, converted to ``str``."""
    fresh_id = nanoid.generate()
    return str(fresh_id)
2 changes: 2 additions & 0 deletions replit_river/transport_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class TransportOptions(BaseModel):
heartbeat_ms: float = 2000
heartbeats_until_dead: int = 2
use_prefix_bytes: bool = False
close_session_check_interval_ms: float = 100
connection_retry_options: ConnectionRetryOptions = ConnectionRetryOptions()
buffer_size: int = 1000

def get_prefix_bytes(self) -> bytes:
    """Return ``PID2_PREFIX_BYTES`` when ``use_prefix_bytes`` is set, else ``b""``."""
    if self.use_prefix_bytes:
        return PID2_PREFIX_BYTES
    return b""
Expand Down
Loading

0 comments on commit a3e29d8

Please sign in to comment.