From 91c2cea308bb7bbf6cbf7e35c81720eef55d57e0 Mon Sep 17 00:00:00 2001 From: Sihan Wang Date: Sun, 7 Jan 2024 14:20:45 -0800 Subject: [PATCH] [Serve] Improve handling the websocket server disconnect scenario (#42130) The server disconnect message is not guaranteed to be sent, so we need to handle its absence when finishing the websocket connection. --------- Signed-off-by: Sihan Wang --- python/ray/serve/_private/proxy.py | 29 +++++++++---- python/ray/serve/tests/test_proxy.py | 65 +++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 9 deletions(-) diff --git a/python/ray/serve/_private/proxy.py b/python/ray/serve/_private/proxy.py index d2853c8703e0..f310812ff36d 100644 --- a/python/ray/serve/_private/proxy.py +++ b/python/ray/serve/_private/proxy.py @@ -1047,24 +1047,37 @@ async def send_request_to_replica( ) finally: + # For a websocket connection, the queue receive task is done when receiving + # a disconnect message from the client. + receive_client_disconnect_msg = False if not proxy_asgi_receive_task.done(): proxy_asgi_receive_task.cancel() else: - # If the server disconnects, status_code is set above from the - # disconnect message. Otherwise the disconnect code comes from - # a client message via the receive interface. - if ( - status is None - and proxy_request.request_type == "websocket" - and proxy_asgi_receive_task.exception() is None - ): + receive_client_disconnect_msg = True + + # If the server disconnects, status_code can be set above from the + # disconnect message. + # If the client disconnects, the disconnect code comes from + # a client message via the receive interface. + if status is None and proxy_request.request_type == "websocket": + if receive_client_disconnect_msg: + # The disconnect message is sent from the client. status = ResponseStatus( code=str(proxy_asgi_receive_task.result()), is_error=True, ) + else: + # The server disconnected without sending a disconnect message + # (otherwise the `status` would be set). 
+ status = ResponseStatus( + code="1000", # [Sihan] is there a better code for this? + is_error=True, + ) del self.asgi_receive_queues[request_id] + # The status code should always be set. + assert status is not None yield status diff --git a/python/ray/serve/tests/test_proxy.py b/python/ray/serve/tests/test_proxy.py index 567f1edf9df1..ba3c22a48759 100644 --- a/python/ray/serve/tests/test_proxy.py +++ b/python/ray/serve/tests/test_proxy.py @@ -174,7 +174,10 @@ def __init__(self, messages=None): self.messages = messages or [] async def __call__(self): - return self.messages.pop() + while True: + if self.messages: + return self.messages.pop() + await asyncio.sleep(0.1) class FakeHttpSend: @@ -551,6 +554,66 @@ async def test_proxy_asgi_receive(self): queue.close.assert_called_once() + @pytest.mark.asyncio + @pytest.mark.parametrize( + "disconnect", + [ + "client", + "server_with_disconnect_message", + "server_without_disconnect_message", + ], + ) + async def test_websocket_call(self, disconnect: str): + """Test HTTPProxy websocket __call__ calls proxy_request.""" + + if disconnect == "client": + receive = FakeHttpReceive( + [{"type": "websocket.disconnect", "code": "1000"}] + ) + expected_messages = [ + {"type": "websocket.accept"}, + {"type": "websocket.send"}, + ] + elif disconnect == "server_with_disconnect_message": + receive = FakeHttpReceive() + expected_messages = [ + {"type": "websocket.accept"}, + {"type": "websocket.send"}, + {"type": "websocket.disconnect", "code": "1000"}, + ] + else: + receive = FakeHttpReceive() + expected_messages = [ + {"type": "websocket.accept"}, + {"type": "websocket.send"}, + ] + + http_proxy = self.create_http_proxy() + http_proxy.proxy_router.route = "route" + http_proxy.proxy_router.handle = FakeHTTPHandle(messages=expected_messages) + http_proxy.proxy_router.app_is_cross_language = False + + scope = { + "type": "websocket", + "headers": [ + ( + b"x-request-id", + b"fake_request_id", + ), + ], + } + send = FakeHttpSend() + + # 
Before calling __call__, send.messages should be empty. + assert send.messages == [] + await http_proxy( + scope=scope, + receive=receive, + send=send, + ) + # After calling __call__, send.messages should match the expected messages. + assert send.messages == expected_messages + class TestTimeoutKeepAliveConfig: """Test setting keep_alive_timeout_s in config and env."""