Skip to content

Commit

Permalink
Add reason support on WebSocketDisconnectEvent (#2324)
Browse files Browse the repository at this point in the history
* Add `reason` support WebSocketDisconnectEvent

* cutify test

---------

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
frankie567 and Kludex authored Jun 14, 2024
1 parent 44a3071 commit 4e9f48d
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 4 deletions.
9 changes: 8 additions & 1 deletion tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,20 +616,25 @@ async def websocket_session(url: str):


async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
disconnect_message: WebSocketDisconnectEvent | None = None

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal disconnect_message
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.receive":
pass
elif message["type"] == "websocket.disconnect":
disconnect_message = message
break

async def websocket_session(url: str):
async with websockets.client.connect(url) as websocket:
await websocket.ping()
await websocket.send("abc")
await websocket.close(code=1001, reason="custom reason")

config = Config(
app=app,
Expand All @@ -641,6 +646,8 @@ async def websocket_session(url: str):
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnect_message == {"type": "websocket.disconnect", "code": 1001, "reason": "custom reason"}


async def test_client_connection_lost(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
Expand Down Expand Up @@ -1262,7 +1269,7 @@ async def send_text(url: str):
await send_text(f"ws://127.0.0.1:{unused_tcp_port}")

assert frames == [b"abc", b"abc", b"abc"]
assert disconnect_message == {"type": "websocket.disconnect", "code": 1000}
assert disconnect_message == {"type": "websocket.disconnect", "code": 1000, "reason": ""}


async def test_default_server_headers(
Expand Down
1 change: 1 addition & 0 deletions uvicorn/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class WebSocketResponseBodyEvent(TypedDict):
class WebSocketDisconnectEvent(TypedDict):
type: Literal["websocket.disconnect"]
code: int
reason: NotRequired[str | None]


class WebSocketCloseEvent(TypedDict):
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ async def asgi_receive(
self.closed_event.set()
if self.ws_server.closing:
return {"type": "websocket.disconnect", "code": 1012}
return {"type": "websocket.disconnect", "code": exc.code}
return {"type": "websocket.disconnect", "code": exc.code, "reason": exc.reason}

if isinstance(data, str):
return {"type": "websocket.receive", "text": data}
Expand Down
4 changes: 2 additions & 2 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def handle_bytes(self, event: events.BytesMessage) -> None:
def handle_close(self, event: events.CloseConnection) -> None:
if self.conn.state == ConnectionState.REMOTE_CLOSING:
self.transport.write(self.conn.send(event.response()))
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code})
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code, "reason": event.reason})
self.transport.close()

def handle_ping(self, event: events.Ping) -> None:
Expand Down Expand Up @@ -336,7 +336,7 @@ async def send(self, message: ASGISendEvent) -> None:
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
if not self.transport.is_closing():
self.transport.write(output)
Expand Down

0 comments on commit 4e9f48d

Please sign in to comment.