Skip to content

Commit

Permalink
Raise Disconnect on send() when client disconnected (#2218)
Browse files Browse the repository at this point in the history
* Raise `Disconnect` on `send()` when client disconnected

* Remove unnucessary variable

* Remove unnucessary sleep

* Undo transport close changes

* Rename Disconnect to ClientDisconnect
  • Loading branch information
Kludex authored Jan 19, 2024
1 parent baf4ea4 commit afed732
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 52 deletions.
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,12 @@ def unused_tcp_port() -> int:
marks=pytest.mark.skipif(
not importlib.util.find_spec("wsproto"), reason="wsproto not installed."
),
id="wsproto",
),
pytest.param(
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
id="websockets",
),
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
]
)
def ws_protocol_cls(request: pytest.FixtureRequest):
Expand All @@ -269,8 +273,9 @@ def ws_protocol_cls(request: pytest.FixtureRequest):
not importlib.util.find_spec("httptools"),
reason="httptools not installed.",
),
id="httptools",
),
"uvicorn.protocols.http.h11_impl:H11Protocol",
pytest.param("uvicorn.protocols.http.h11_impl:H11Protocol", id="h11"),
]
)
def http_protocol_cls(request: pytest.FixtureRequest):
Expand Down
36 changes: 36 additions & 0 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,42 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
assert got_disconnect_event_before_shutdown is True


@pytest.mark.anyio
async def test_client_connection_lost_on_send(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
disconnect = asyncio.Event()
got_disconnect_event = False

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal got_disconnect_event
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
try:
await disconnect.wait()
await send({"type": "websocket.send", "text": "123"})
except IOError:
got_disconnect_event = True

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
url = f"ws://127.0.0.1:{unused_tcp_port}"
async with websockets.client.connect(url):
await asyncio.sleep(0.1)
disconnect.set()

assert got_disconnect_event is True


@pytest.mark.anyio
async def test_connection_lost_before_handshake_complete(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
Expand Down
4 changes: 4 additions & 0 deletions uvicorn/protocols/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from uvicorn._types import WWWScope


class ClientDisconnected(IOError):
...


def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
if socket_info is not None:
Expand Down
45 changes: 26 additions & 19 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
ClientDisconnected,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
Expand Down Expand Up @@ -252,6 +253,9 @@ async def run_asgi(self) -> None:
"""
try:
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
except ClientDisconnected:
self.closed_event.set()
self.transport.close()
except BaseException as exc:
self.closed_event.set()
msg = "Exception in ASGI application\n"
Expand Down Expand Up @@ -336,26 +340,29 @@ async def asgi_send(self, message: "ASGISendEvent") -> None:
elif not self.closed_event.is_set() and self.initial_response is None:
await self.handshake_completed_event.wait()

if message_type == "websocket.send":
message = cast("WebSocketSendEvent", message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
await self.send(data) # type: ignore[arg-type]

elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
await self.close(code, reason)
self.closed_event.set()
try:
if message_type == "websocket.send":
message = cast("WebSocketSendEvent", message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
await self.send(data) # type: ignore[arg-type]

elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
await self.close(code, reason)
self.closed_event.set()

else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)
else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)
except ConnectionClosed as exc:
raise ClientDisconnected from exc

elif self.initial_response is not None:
if message_type == "websocket.http.response.body":
Expand Down
68 changes: 37 additions & 31 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from wsproto import ConnectionType, events
from wsproto.connection import ConnectionState
from wsproto.extensions import Extension, PerMessageDeflate
from wsproto.utilities import RemoteProtocolError
from wsproto.utilities import LocalProtocolError, RemoteProtocolError

from uvicorn._types import (
ASGISendEvent,
Expand All @@ -25,6 +25,7 @@
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
ClientDisconnected,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
Expand Down Expand Up @@ -236,6 +237,8 @@ def send_500_response(self) -> None:
async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
except ClientDisconnected:
self.transport.close()
except BaseException:
self.logger.exception("Exception in ASGI application\n")
self.send_500_response()
Expand Down Expand Up @@ -325,36 +328,39 @@ async def send(self, message: ASGISendEvent) -> None:
raise RuntimeError(msg % message_type)

elif not self.close_sent and not self.response_started:
if message_type == "websocket.send":
message = typing.cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
output = self.conn.send(
wsproto.events.Message(data=data) # type: ignore[type-var]
)
if not self.transport.is_closing():
self.transport.write(output)

elif message_type == "websocket.close":
message = typing.cast(WebSocketCloseEvent, message)
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
output = self.conn.send(
wsproto.events.CloseConnection(code=code, reason=reason)
)
if not self.transport.is_closing():
self.transport.write(output)
self.transport.close()

else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)
try:
if message_type == "websocket.send":
message = typing.cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore
if not self.transport.is_closing():
self.transport.write(output)

elif message_type == "websocket.close":
message = typing.cast(WebSocketCloseEvent, message)
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait(
{"type": "websocket.disconnect", "code": code}
)
output = self.conn.send(
wsproto.events.CloseConnection(code=code, reason=reason)
)
if not self.transport.is_closing():
self.transport.write(output)
self.transport.close()

else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)
except LocalProtocolError as exc:
raise ClientDisconnected from exc
elif self.response_started:
if message_type == "websocket.http.response.body":
message = typing.cast("WebSocketResponseBodyEvent", message)
Expand Down

0 comments on commit afed732

Please sign in to comment.