Skip to content

Commit

Permalink
[Serve] Improve handling the websocket server disconnect scenario (ra…
Browse files Browse the repository at this point in the history
…y-project#42130)

Server disconnect message is not guaranteed to be sent. So we need to handle it when finish the websocket connection.

---------

Signed-off-by: Sihan Wang <sihanwang41@gmail.com>
  • Loading branch information
sihanwang41 committed Jan 8, 2024
1 parent 3bb96d2 commit 91c2cea
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 9 deletions.
29 changes: 21 additions & 8 deletions python/ray/serve/_private/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,24 +1047,37 @@ async def send_request_to_replica(
)

finally:
# For websocket connection, queue receive task is done when receiving
# disconnect message from client.
receive_client_disconnect_msg = False
if not proxy_asgi_receive_task.done():
proxy_asgi_receive_task.cancel()
else:
# If the server disconnects, status_code is set above from the
# disconnect message. Otherwise the disconnect code comes from
# a client message via the receive interface.
if (
status is None
and proxy_request.request_type == "websocket"
and proxy_asgi_receive_task.exception() is None
):
receive_client_disconnect_msg = True

# If the server disconnects, status_code can be set above from the
# disconnect message.
# If client disconnects, the disconnect code comes from
# a client message via the receive interface.
if status is None and proxy_request.request_type == "websocket":
if receive_client_disconnect_msg:
# The disconnect message is sent from the client.
status = ResponseStatus(
code=str(proxy_asgi_receive_task.result()),
is_error=True,
)
else:
# The server disconnect without sending a disconnect message
# (otherwise the `status` would be set).
status = ResponseStatus(
code="1000", # [Sihan] is there a better code for this?
is_error=True,
)

del self.asgi_receive_queues[request_id]

# The status code should always be set.
assert status is not None
yield status


Expand Down
65 changes: 64 additions & 1 deletion python/ray/serve/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ def __init__(self, messages=None):
self.messages = messages or []

async def __call__(self):
return self.messages.pop()
while True:
if self.messages:
return self.messages.pop()
await asyncio.sleep(0.1)


class FakeHttpSend:
Expand Down Expand Up @@ -551,6 +554,66 @@ async def test_proxy_asgi_receive(self):

queue.close.assert_called_once()

@pytest.mark.asyncio
@pytest.mark.parametrize(
"disconnect",
[
"client",
"server_with_disconnect_message",
"server_without_disconnect_message",
],
)
async def test_websocket_call(self, disconnect: str):
"""Test HTTPProxy websocket __call__ calls proxy_request."""

if disconnect == "client":
receive = FakeHttpReceive(
[{"type": "websocket.disconnect", "code": "1000"}]
)
expected_messages = [
{"type": "websocket.accept"},
{"type": "websocket.send"},
]
elif disconnect == "server_with_disconnect_message":
receive = FakeHttpReceive()
expected_messages = [
{"type": "websocket.accept"},
{"type": "websocket.send"},
{"type": "websocket.disconnect", "code": "1000"},
]
else:
receive = FakeHttpReceive()
expected_messages = [
{"type": "websocket.accept"},
{"type": "websocket.send"},
]

http_proxy = self.create_http_proxy()
http_proxy.proxy_router.route = "route"
http_proxy.proxy_router.handle = FakeHTTPHandle(messages=expected_messages)
http_proxy.proxy_router.app_is_cross_language = False

scope = {
"type": "websocket",
"headers": [
(
b"x-request-id",
b"fake_request_id",
),
],
}
send = FakeHttpSend()

# Ensure before calling __call__, send.messages should be empty.
assert send.messages == []
await http_proxy(
scope=scope,
receive=receive,
send=send,
)
# Ensure after calling __call__, send.messages should be expected messages.
assert send.messages == expected_messages


class TestTimeoutKeepAliveConfig:
"""Test setting keep_alive_timeout_s in config and env."""
Expand Down

0 comments on commit 91c2cea

Please sign in to comment.