diff --git a/CHANGES/8481.bugfix.rst b/CHANGES/8481.bugfix.rst new file mode 100644 index 00000000000..b185780174e --- /dev/null +++ b/CHANGES/8481.bugfix.rst @@ -0,0 +1,2 @@ +Fixed the incorrect rejection of ``ws://`` and ``wss://`` urls +-- by :user:` AraHaan`. diff --git a/aiohttp/client.py b/aiohttp/client.py index 4c0ad893bbb..25026920206 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -210,6 +210,8 @@ class ClientTimeout: # https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) HTTP_SCHEMA_SET = frozenset({"http", "https", ""}) +WS_SCHEMA_SET = frozenset({"ws", "wss"}) +ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET _RetType = TypeVar("_RetType") _CharsetResolver = Callable[[ClientResponse, bytes], str] @@ -452,7 +454,7 @@ async def _request( except ValueError as e: raise InvalidUrlClientError(str_or_url) from e - if url.scheme not in HTTP_SCHEMA_SET: + if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET: raise NonHttpUrlClientError(url) skip_headers = set(self._skip_auto_headers) diff --git a/tests/conftest.py b/tests/conftest.py index fb294bd2cad..6310a1dd765 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,11 @@ # type: ignore import asyncio +import base64 import os import socket import ssl import sys -from hashlib import md5, sha256 +from hashlib import md5, sha1, sha256 from pathlib import Path from tempfile import TemporaryDirectory from typing import Any, List @@ -13,6 +14,7 @@ import pytest +from aiohttp.http import WS_KEY from aiohttp.test_utils import loop_context try: @@ -218,3 +220,18 @@ def start_connection(): spec_set=True, ) as start_connection_mock: yield start_connection_mock + + +@pytest.fixture +def key_data(): + return os.urandom(16) + + +@pytest.fixture +def key(key_data: Any): + return base64.b64encode(key_data) + + +@pytest.fixture +def ws_key(key: Any): + return base64.b64encode(sha1(key + WS_KEY).digest()).decode() diff --git a/tests/test_client_session.py b/tests/test_client_session.py index e08e7d10414..0125f611b98 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -467,6 +467,60 @@ async def create_connection(req, traces, timeout): c.__del__() +@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss"]) +async def test_ws_connect_allowed_protocols( + create_session: Any, + create_mocked_conn: Any, + protocol: str, + ws_key: Any, + key_data: Any, +) -> None: + resp = mock.create_autospec(aiohttp.ClientResponse) + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + resp.url = URL(f"{protocol}://example.com") + resp.cookies = SimpleCookie() + resp.start = mock.AsyncMock() + + req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) + req_factory = mock.Mock(return_value=req) + req.send = mock.AsyncMock(return_value=resp) + + session = await create_session(request_class=req_factory) + + connections = [] + original_connect = session._connector.connect + + async def connect(req, traces, timeout): + conn = await original_connect(req, traces, timeout) + connections.append(conn) + return conn + + async def create_connection(req, traces, timeout): + return create_mocked_conn() + + connector = session._connector + with mock.patch.object(connector, "connect", connect), mock.patch.object( + connector, "_create_connection", create_connection + ), mock.patch.object(connector, "_release"), mock.patch( + "aiohttp.client.os" + ) as m_os: + m_os.urandom.return_value = key_data + await session.ws_connect(f"{protocol}://example.com") + + # normally called during garbage collection. triggers an exception + # if the connection wasn't already closed + for c in connections: + c.close() + del c + + await session.close() + + async def test_cookie_jar_usage(loop: Any, aiohttp_client: Any) -> None: req_url = None diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 06cf2a12066..3e469df208f 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -15,21 +15,6 @@ from aiohttp.test_utils import make_mocked_coro -@pytest.fixture -def key_data(): - return os.urandom(16) - - -@pytest.fixture -def key(key_data: Any): - return base64.b64encode(key_data) - - -@pytest.fixture -def ws_key(key: Any): - return base64.b64encode(hashlib.sha1(key + WS_KEY).digest()).decode() - - async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None: resp = mock.Mock() resp.status = 101