Skip to content

Commit

Permalink
0.1.13: Simplify websocket management (#11)
Browse files Browse the repository at this point in the history
* Update

* upgrade

* Fix client session
  • Loading branch information
zhenthebuilder authored May 2, 2024
1 parent e750eb9 commit acb0519
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 79 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name="replit-river"
version="0.1.12"
version="0.1.13"
description="Replit river toolkit for Python"
authors = ["Replit <eng@replit.com>"]
license = "LICENSE"
Expand Down
8 changes: 0 additions & 8 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ async def send_rpc(
output: Channel[Any] = Channel(1)
self._streams[stream_id] = output
await self.send_message(
ws=self._ws,
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT,
payload=request_serializer(request),
Expand Down Expand Up @@ -91,7 +90,6 @@ async def send_upload(
if init and init_serializer:
await self.send_message(
stream_id=stream_id,
ws=self._ws,
control_flags=STREAM_OPEN_BIT,
service_name=service_name,
procedure_name=procedure_name,
Expand All @@ -107,7 +105,6 @@ async def send_upload(
first_message = False
await self.send_message(
stream_id=stream_id,
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
control_flags=control_flags,
Expand Down Expand Up @@ -158,7 +155,6 @@ async def send_subscription(
output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE)
self._streams[stream_id] = output
await self.send_message(
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand Down Expand Up @@ -209,7 +205,6 @@ async def send_stream(
try:
if init and init_serializer:
await self.send_message(
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand All @@ -221,7 +216,6 @@ async def send_stream(
request_iter = aiter(request)
first = await anext(request_iter)
await self.send_message(
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand All @@ -238,7 +232,6 @@ async def _encode_stream() -> None:
if item is None:
continue
await self.send_message(
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand Down Expand Up @@ -275,7 +268,6 @@ async def send_close_stream(
) -> None:
# close stream
await self.send_message(
ws=self._ws,
service_name=service_name,
procedure_name=procedure_name,
stream_id=stream_id,
Expand Down
104 changes: 34 additions & 70 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import nanoid # type: ignore
import websockets
from aiochannel import Channel, ChannelClosed
from websockets import WebSocketCommonProtocol
from websockets.exceptions import ConnectionClosed

from replit_river.message_buffer import MessageBuffer
Expand All @@ -23,6 +22,7 @@
)
from replit_river.task_manager import BackgroundTaskManager
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions
from replit_river.websocket_wrapper import WebsocketWrapper

from .rpc import (
ACK_BIT,
Expand All @@ -39,12 +39,6 @@ class SessionState(enum.Enum):
CLOSED = 2


class WsState(enum.Enum):
OPEN = 0
CLOSING = 1
CLOSED = 2


class Session(object):
"""A transport object that handles the websocket connection with a client."""

Expand Down Expand Up @@ -82,8 +76,7 @@ def __init__(

# ws state
self._ws_lock = asyncio.Lock()
self._ws_state = WsState.OPEN
self._ws = websocket
self._ws_wrapper = WebsocketWrapper(websocket)
self._heartbeat_misses = 0
self._retry_connection_callback = retry_connection_callback

Expand All @@ -109,7 +102,7 @@ async def is_session_open(self) -> bool:

async def is_websocket_open(self) -> bool:
async with self._ws_lock:
return self._ws_state == WsState.OPEN
return await self._ws_wrapper.is_open()

async def _on_websocket_unexpected_close(self) -> None:
"""Handle unexpected websocket close."""
Expand Down Expand Up @@ -141,7 +134,7 @@ async def serve(self) -> None:
try:
async with asyncio.TaskGroup() as tg:
try:
await self._handle_messages_from_ws(self._ws, tg)
await self._handle_messages_from_ws(tg)
except ConnectionClosed as e:
await self._on_websocket_unexpected_close()
logging.debug("ConnectionClosed while serving: %r", e)
Expand All @@ -163,15 +156,15 @@ async def _update_book_keeping(self, msg: TransportMessage) -> None:
self._reset_session_close_countdown()

async def _handle_messages_from_ws(
self, websocket: WebSocketCommonProtocol, tg: Optional[asyncio.TaskGroup] = None
self, tg: Optional[asyncio.TaskGroup] = None
) -> None:
logging.debug(
"%s start handling messages from ws %s",
"server" if self._is_server else "client",
websocket.id,
self._ws_wrapper.id,
)
try:
async for message in websocket:
async for message in self._ws_wrapper.ws:
try:
msg = parse_transport_msg(message, self._transport_options)

Expand Down Expand Up @@ -215,15 +208,13 @@ async def replace_with_new_websocket(
self, new_ws: websockets.WebSocketCommonProtocol
) -> None:
async with self._ws_lock:
old_ws = self._ws
self._ws_state = WsState.CLOSING
if new_ws.id != old_ws.id:
self._reset_session_close_countdown()
await self.close_websocket(old_ws, should_retry=False)
async with self._ws_lock:
self._ws = new_ws
self._ws_state = WsState.OPEN
await self._send_buffered_messages(new_ws)
old_wrapper = self._ws_wrapper
old_ws_id = old_wrapper.ws.id
if new_ws.id != old_ws_id:
self._reset_session_close_countdown()
await old_wrapper.close()
self._ws_wrapper = WebsocketWrapper(new_ws)
await self._send_buffered_messages(new_ws)
# Server will call serve itself.
if not self._is_server:
await self.start_serve_responses()
Expand Down Expand Up @@ -273,7 +264,6 @@ async def _heartbeat(
{
"ack": 0,
},
self._ws,
ACK_BIT,
)
self._heartbeat_misses += 1
Expand All @@ -287,7 +277,7 @@ async def _heartbeat(
)
await self._on_websocket_unexpected_close()
await self.close_websocket(
self._ws, should_retry=not self._is_server
self._ws_wrapper, should_retry=not self._is_server
)
continue
except FailedSendingMessageException:
Expand Down Expand Up @@ -315,12 +305,12 @@ async def _send_buffered_messages(
async def _send_transport_message(
self,
msg: TransportMessage,
ws: WebSocketCommonProtocol,
websocket: websockets.WebSocketCommonProtocol,
prefix_bytes: bytes = b"",
) -> None:
try:
await send_transport_message(
msg, ws, self._on_websocket_unexpected_close, prefix_bytes
msg, websocket, self._on_websocket_unexpected_close, prefix_bytes
)
except WebsocketClosedException as e:
raise e
Expand All @@ -331,7 +321,6 @@ async def send_message(
self,
stream_id: str,
payload: Dict | str,
ws: WebSocketCommonProtocol,
control_flags: int = 0,
service_name: str | None = None,
procedure_name: str | None = None,
Expand Down Expand Up @@ -363,11 +352,14 @@ async def send_message(
# buffer
await self.close(True)
return
await self._send_transport_message(
msg,
ws,
prefix_bytes=self._transport_options.get_prefix_bytes(),
)
async with self._ws_lock:
if await self._ws_wrapper.is_open():
# if it is not open it's fine, we already put it the buffer
await self._send_transport_message(
msg,
self._ws_wrapper.ws,
prefix_bytes=self._transport_options.get_prefix_bytes(),
)
except WebsocketClosedException as e:
logging.debug(
"Connection closed while sending message %r: %r, waiting for "
Expand All @@ -388,25 +380,13 @@ async def _send_responses_from_output_stream(
) -> None:
"""Send serialized messages to the websockets."""
try:
# TODO: This blocking is not ideal, we should close this task when websocket
# is closed, and start another one when websocket reconnects.
ws = None
async for payload in output:
ws = self._ws
while self._ws_state != WsState.OPEN:
await asyncio.sleep(
self._transport_options.close_session_check_interval_ms / 1000
)
ws = self._ws
if not is_streaming_output:
await self.send_message(stream_id, payload, ws, STREAM_CLOSED_BIT)
await self.send_message(stream_id, payload, STREAM_CLOSED_BIT)
return
await self.send_message(stream_id, payload, ws)
if ws:
logging.debug("sent an end of stream %r", stream_id)
await self.send_message(
stream_id, {"type": "CLOSE"}, ws, STREAM_CLOSED_BIT
)
await self.send_message(stream_id, payload)
logging.debug("sent an end of stream %r", stream_id)
await self.send_message(stream_id, {"type": "CLOSE"}, STREAM_CLOSED_BIT)
except FailedSendingMessageException as e:
logging.error(f"Error while sending responses, {type(e)} : {e}")
except (RuntimeError, ChannelClosed) as e:
Expand All @@ -415,28 +395,11 @@ async def _send_responses_from_output_stream(
logging.error(f"Unknown error while river sending responses back : {e}")

async def close_websocket(
self, ws: WebSocketCommonProtocol, should_retry: bool
self, ws_wrapper: WebsocketWrapper, should_retry: bool
) -> None:
"""Mark the websocket as closed, close the websocket, and retry if needed."""
async with self._ws_lock:
if self._ws.id != ws.id:
# already replaced with new ws
return
if self._ws_state != WsState.OPEN:
# Already closed
return
logging.info(
f"River session from {self._transport_id} to {self._to_id} "
f"closing websocket {ws.id}"
)
self._ws_state = WsState.CLOSING
if ws:
# TODO: should we wait here?
task = asyncio.create_task(ws.close())
task.add_done_callback(
lambda _: logging.debug("old websocket %s closed.", ws.id)
)
self._ws_state = WsState.CLOSED
await ws_wrapper.close()
if should_retry and self._retry_connection_callback:
await self._retry_connection_callback(self)

Expand Down Expand Up @@ -519,15 +482,16 @@ async def close(self, is_unexpected_close: bool) -> None:
"""Close the session and all associated streams."""
logging.info(
f"{self._transport_id} closing session "
f"to {self._to_id}, ws: {self._ws.id}, current_state : {self._ws_state}"
f"to {self._to_id}, ws: {self._ws_wrapper.id}, "
f"current_state : {self._ws_wrapper}"
)
async with self._state_lock:
if self._state != SessionState.ACTIVE:
# already closing
return
self._state = SessionState.CLOSING
self._reset_session_close_countdown()
await self.close_websocket(self._ws, should_retry=False)
await self.close_websocket(self._ws_wrapper, should_retry=False)
# Clear the session in transports
await self._close_session_callback(self)
await self._task_manager.cancel_all_tasks()
Expand Down
33 changes: 33 additions & 0 deletions replit_river/websocket_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import asyncio
import enum
import logging

from websockets import WebSocketCommonProtocol


class WsState(enum.Enum):
OPEN = 0
CLOSING = 1
CLOSED = 2


class WebsocketWrapper:
def __init__(self, ws: WebSocketCommonProtocol) -> None:
self.ws = ws
self.ws_state = WsState.OPEN
self.ws_lock = asyncio.Lock()
self.id = ws.id

async def is_open(self) -> bool:
async with self.ws_lock:
return self.ws_state == WsState.OPEN

async def close(self) -> None:
async with self.ws_lock:
if self.ws_state == WsState.OPEN:
self.ws_state = WsState.CLOSING
task = asyncio.create_task(self.ws.close())
task.add_done_callback(
lambda _: logging.debug("old websocket %s closed.", self.ws.id)
)
self.ws_state = WsState.CLOSED

0 comments on commit acb0519

Please sign in to comment.