Skip to content

Commit

Permalink
Send a 400 if data is received before the websocket is accepted
Browse files Browse the repository at this point in the history
This is necessary as the data is not known to be either websocket or
http data as the server must accept or reject the connection and hence
it is a bad request.

In practice this is rare as the upgrade request is a GET request which
rarely has body data and most websocket clients wait for acceptance.
  • Loading branch information
pgjones committed May 26, 2024
1 parent d264794 commit 81bbb32
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/hypercorn/protocol/ws_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class FrameTooLargeError(Exception):

class Handshake:
def __init__(self, headers: List[Tuple[bytes, bytes]], http_version: str) -> None:
self.accepted = False
self.http_version = http_version
self.connection_tokens: Optional[List[str]] = None
self.extensions: Optional[List[str]] = None
Expand Down Expand Up @@ -129,6 +130,7 @@ def accept(

headers.append((name, value))

self.accepted = True
return status_code, headers, Connection(ConnectionType.SERVER, extensions)


Expand Down Expand Up @@ -232,6 +234,9 @@ async def handle(self, event: Event) -> None:
self.app, self.config, self.scope, self.app_send
)
await self.app_put({"type": "websocket.connect"})
elif isinstance(event, (Body, Data)) and not self.handshake.accepted:
await self._send_error_response(400)
self.closed = True
elif isinstance(event, (Body, Data)):
self.connection.receive_data(event.data)
await self._handle_events()
Expand Down
29 changes: 29 additions & 0 deletions tests/protocol/test_ws_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,35 @@ async def test_handle_request(stream: WSStream) -> None:
}


@pytest.mark.asyncio
async def test_handle_data_before_acceptance(stream: WSStream) -> None:
await stream.handle(
Request(
stream_id=1,
http_version="2",
headers=[(b"sec-websocket-version", b"13")],
raw_path=b"/?a=b",
method="GET",
)
)
await stream.handle(
Data(
stream_id=1,
data=b"X",
)
)
assert stream.send.call_args_list == [ # type: ignore
call(
Response(
stream_id=1,
headers=[(b"content-length", b"0"), (b"connection", b"close")],
status_code=400,
)
),
call(EndBody(stream_id=1)),
]


@pytest.mark.asyncio
async def test_handle_connection(stream: WSStream) -> None:
await stream.handle(
Expand Down

0 comments on commit 81bbb32

Please sign in to comment.