Skip to content

Commit

Permalink
Support the WebSocket Denial Response ASGI extension (#1916)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
kristjanvalur and Kludex authored Dec 17, 2023
1 parent 7d274ed commit 6568184
Show file tree
Hide file tree
Showing 5 changed files with 447 additions and 15 deletions.
349 changes: 345 additions & 4 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
from websockets.typing import Subprotocol

from tests.response import Response
from tests.utils import run_server
from uvicorn._types import (
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendCallable,
Scope,
WebSocketCloseEvent,
WebSocketDisconnectEvent,
WebSocketResponseStartEvent,
)
from uvicorn.config import Config
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
Expand Down Expand Up @@ -55,6 +58,21 @@ async def asgi(self):
break


async def wsresponse(url):
"""
A simple websocket connection request and response helper
"""
url = url.replace("ws:", "http:")
headers = {
"connection": "upgrade",
"upgrade": "websocket",
"Sec-WebSocket-Key": "x3JJHMbDL1EzLkh9GBhXDw==",
"Sec-WebSocket-Version": "13",
}
async with httpx.AsyncClient() as client:
return await client.get(url, headers=headers)


@pytest.mark.anyio
async def test_invalid_upgrade(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
Expand Down Expand Up @@ -942,7 +960,10 @@ async def test_server_reject_connection(
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
disconnected_message: ASGIReceiveEvent = {} # type: ignore

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal disconnected_message
assert scope["type"] == "websocket"

# Pull up first recv message.
Expand All @@ -955,15 +976,241 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable

# This doesn't raise `TypeError`:
# See https://github.com/encode/uvicorn/issues/244
disconnected_message = await receive()

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 403

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
async def test_server_reject_connection_with_response(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
disconnected_message = {}

async def app(scope, receive, send):
nonlocal disconnected_message
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.disconnect"
assert message["type"] == "websocket.connect"

# Reject the connection with a response
response = Response(b"goodbye", status_code=400)
await response(scope, receive, send)
disconnected_message = await receive()

async def websocket_session(url):
response = await wsresponse(url)
assert response.status_code == 400
assert response.content == b"goodbye"

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
async def test_server_reject_connection_with_multibody_response(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
disconnected_message: ASGIReceiveEvent = {} # type: ignore

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal disconnected_message
assert scope["type"] == "websocket"
assert "extensions" in scope
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"
await send(
{
"type": "websocket.http.response.start",
"status": 400,
"headers": [
(b"Content-Length", b"20"),
(b"Content-Type", b"text/plain"),
],
}
)
await send(
{
"type": "websocket.http.response.body",
"body": b"x" * 10,
"more_body": True,
}
)
await send({"type": "websocket.http.response.body", "body": b"y" * 10})
disconnected_message = await receive()

async def websocket_session(url: str):
try:
response = await wsresponse(url)
assert response.status_code == 400
assert response.content == (b"x" * 10) + (b"y" * 10)

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
async def test_server_reject_connection_with_invalid_status(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
# this test checks that even if there is an error in the response, the server
# can successfully send a 500 error back to the client
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 700, # invalid status code
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
message = {
"type": "websocket.http.response.body",
"body": b"",
}
await send(message)

async def websocket_session(url):
response = await wsresponse(url)
assert response.status_code == 500
assert response.content == b"Internal Server Error"

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
async def test_server_reject_connection_with_body_nolength(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
# test that the server can send a response with a body but no content-length
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "websocket"
assert "extensions" in scope
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

await send(
{
"type": "websocket.http.response.start",
"status": 403,
"headers": [],
}
)
await send({"type": "websocket.http.response.body", "body": b"hardbody"})

async def websocket_session(url):
response = await wsresponse(url)
assert response.status_code == 403
assert response.content == b"hardbody"
if ws_protocol_cls == WSProtocol: # pragma: no cover
# wsproto automatically makes the message chunked
assert response.headers["transfer-encoding"] == "chunked"
else: # pragma: no cover
# websockets automatically adds a content-length
assert response.headers["content-length"] == "8"

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
async def test_server_reject_connection_with_invalid_msg(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
# send invalid message. This will raise an exception here
await send(message)

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
except Exception:
pass
assert exc_info.value.status_code == 404

config = Config(
app=app,
Expand All @@ -976,6 +1223,100 @@ async def websocket_session(url: str):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
async def test_server_reject_connection_with_missing_body(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
# no further message

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 404

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
async def test_server_multiple_websocket_http_response_start_events(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
"""
The server should raise an exception if it sends multiple
websocket.http.response.start events.
"""
exception_message: typing.Optional[str] = None

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal exception_message
assert scope["type"] == "websocket"
assert "extensions" in scope
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

start_event: WebSocketResponseStartEvent = {
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(start_event)
try:
await send(start_event)
except Exception as exc:
exception_message = str(exc)

async def websocket_session(url: str):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass
assert exc_info.value.status_code == 404

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert exception_message == (
"Expected ASGI message 'websocket.http.response.body' but got "
"'websocket.http.response.start'."
)


@pytest.mark.anyio
async def test_server_can_read_messages_in_buffer_after_close(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
Expand Down
Loading

0 comments on commit 6568184

Please sign in to comment.