Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix connecting to npipe://, tcp://, and unix:// urls #8632

Merged
merged 35 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
53d2b3e
Fix connecting to unix:// urls
bdraco Aug 7, 2024
7c01a11
changelog
bdraco Aug 7, 2024
fe4833e
fix "npipe" as well
bdraco Aug 7, 2024
b6cd175
Update CHANGES/8632.bugfix.rst
bdraco Aug 7, 2024
c8ae1f2
Make allowed protocols per connector, allow all in base connector
bdraco Aug 7, 2024
9de9578
Merge remote-tracking branch 'upstream/fix_unix_tcp' into fix_unix_tcp
bdraco Aug 7, 2024
857b6c7
Make allowed protocols per connector, allow all in base connector
bdraco Aug 7, 2024
64292c0
Update aiohttp/connector.py
bdraco Aug 7, 2024
1d79ef2
Make allowed protocols per connector, allow all in base connector
bdraco Aug 7, 2024
d713232
ensure base connector allows all protocols by default
bdraco Aug 7, 2024
83fd6f0
ensure base connector allows all protocols by default
bdraco Aug 7, 2024
9d2fe1a
fixes
bdraco Aug 7, 2024
883c2f6
fixes
bdraco Aug 7, 2024
1b65087
fixes
bdraco Aug 7, 2024
fba214f
type
bdraco Aug 7, 2024
af4cc59
fix tcp:// as well
bdraco Aug 7, 2024
ce827f1
Update CHANGES/8632.bugfix.rst
bdraco Aug 7, 2024
709f824
fix tcp:// as well
bdraco Aug 7, 2024
a0aeb7c
Merge remote-tracking branch 'upstream/fix_unix_tcp' into fix_unix_tcp
bdraco Aug 7, 2024
5967e5a
rename to BASE_PROTOCOL_SCHEMA_SET
bdraco Aug 7, 2024
412e12f
Merge remote-tracking branch 'upstream/master' into fix_unix_tcp
bdraco Aug 7, 2024
71bf4ae
drop now
bdraco Aug 7, 2024
9c95012
drop now
bdraco Aug 7, 2024
ad51c59
Update aiohttp/client.py
Dreamsorcerer Aug 7, 2024
2256884
drop ones that are no longer in the base connector
bdraco Aug 7, 2024
cde13a6
Update aiohttp/connector.py
bdraco Aug 7, 2024
374821b
kiss
bdraco Aug 7, 2024
232d377
preen
bdraco Aug 7, 2024
f7996bf
Update tests/test_client_session.py
bdraco Aug 7, 2024
258eae9
explict unix test
bdraco Aug 7, 2024
4652d6b
Merge remote-tracking branch 'upstream/fix_unix_tcp' into fix_unix_tcp
bdraco Aug 7, 2024
f5d3bb4
Update tests/test_client_session.py
bdraco Aug 7, 2024
83db4a1
fix type
bdraco Aug 7, 2024
5b22c58
Merge remote-tracking branch 'upstream/fix_unix_tcp' into fix_unix_tcp
bdraco Aug 7, 2024
3bbf448
fix type
bdraco Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/8632.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed connecting to ``npipe://``, ``tcp://``, and ``unix://`` urls -- by :user:`bdraco`.
17 changes: 10 additions & 7 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@
ClientWebSocketResponse,
ClientWSTimeout,
)
from .connector import BaseConnector, NamedPipeConnector, TCPConnector, UnixConnector
from .connector import (
HTTP_AND_EMPTY_SCHEMA_SET,
BaseConnector,
NamedPipeConnector,
TCPConnector,
UnixConnector,
)
from .cookiejar import CookieJar
from .helpers import (
_SENTINEL,
Expand Down Expand Up @@ -210,9 +216,6 @@ class ClientTimeout:

# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
HTTP_SCHEMA_SET = frozenset({"http", "https", ""})
WS_SCHEMA_SET = frozenset({"ws", "wss"})
ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET

Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
_RetType = TypeVar("_RetType")
_CharsetResolver = Callable[[ClientResponse, bytes], str]
Expand Down Expand Up @@ -466,7 +469,8 @@ async def _request(
except ValueError as e:
raise InvalidUrlClientError(str_or_url) from e

if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET:
assert self._connector is not None
if url.scheme not in self._connector.allowed_protocol_schema_set:
raise NonHttpUrlClientError(url)

skip_headers = set(self._skip_auto_headers)
Expand Down Expand Up @@ -597,7 +601,6 @@ async def _request(
real_timeout.connect,
ceil_threshold=real_timeout.ceil_threshold,
):
assert self._connector is not None
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
)
Expand Down Expand Up @@ -693,7 +696,7 @@ async def _request(
) from e

scheme = parsed_redirect_url.scheme
if scheme not in HTTP_SCHEMA_SET:
if scheme not in HTTP_AND_EMPTY_SCHEMA_SET:
resp.close()
raise NonHttpUrlRedirectClientError(r_url)
elif not scheme:
Expand Down
16 changes: 16 additions & 0 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@
SSLContext = object # type: ignore[misc,assignment]


EMPTY_SCHEMA_SET = frozenset({""})
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
HTTP_SCHEMA_SET = frozenset({"http", "https"})
WS_SCHEMA_SET = frozenset({"ws", "wss"})

HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET
HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET


__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")


Expand Down Expand Up @@ -190,6 +198,8 @@ class BaseConnector:
# abort transport after 2 seconds (cleanup broken connections)
_cleanup_closed_period = 2.0

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET

def __init__(
self,
*,
Expand Down Expand Up @@ -741,6 +751,8 @@ class TCPConnector(BaseConnector):
loop - Optional event loop.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})

def __init__(
self,
*,
Expand Down Expand Up @@ -1342,6 +1354,8 @@ class UnixConnector(BaseConnector):
loop - Optional event loop.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})

def __init__(
self,
path: str,
Expand Down Expand Up @@ -1396,6 +1410,8 @@ class NamedPipeConnector(BaseConnector):
loop - Optional event loop.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})

def __init__(
self,
path: str,
Expand Down
71 changes: 67 additions & 4 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from aiohttp.client import ClientSession
from aiohttp.client_proto import ResponseHandler
from aiohttp.client_reqrep import ClientRequest, ConnectionKey
from aiohttp.connector import BaseConnector, Connection, TCPConnector
from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.test_utils import make_mocked_coro
from aiohttp.tracing import Trace
Expand Down Expand Up @@ -536,15 +536,78 @@ async def test_ws_connect_allowed_protocols(
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.url = URL(f"{protocol}://example.com")
resp.url = URL(f"{protocol}://example")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req.send = mock.AsyncMock(return_value=resp)
# BaseConnector allows all high level protocols by default
connector = BaseConnector()

session = await create_session(request_class=req_factory)
session = await create_session(connector=connector, request_class=req_factory)

connections = []
assert session._connector is not None
original_connect = session._connector.connect

async def connect(
req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout
) -> Connection:
conn = await original_connect(req, traces, timeout)
connections.append(conn)
return conn

async def create_connection(
req: object, traces: object, timeout: object
) -> ResponseHandler:
return create_mocked_conn()

connector = session._connector
with mock.patch.object(connector, "connect", connect), mock.patch.object(
connector, "_create_connection", create_connection
), mock.patch.object(connector, "_release"), mock.patch(
"aiohttp.client.os"
) as m_os:
m_os.urandom.return_value = key_data
await session.ws_connect(f"{protocol}://example")

# normally called during garbage collection. triggers an exception
# if the connection wasn't already closed
for c in connections:
c.close()
c.__del__()
Dismissed Show dismissed Hide dismissed

await session.close()


@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss", "unix"])
async def test_ws_connect_unix_socket_allowed_protocols(
create_session: Callable[..., Awaitable[ClientSession]],
create_mocked_conn: Callable[[], ResponseHandler],
protocol: str,
ws_key: bytes,
key_data: bytes,
) -> None:
resp = mock.create_autospec(aiohttp.ClientResponse)
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.url = URL(f"{protocol}://example")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req.send = mock.AsyncMock(return_value=resp)
# UnixConnector allows all high level protocols by default and unix sockets
session = await create_session(
connector=UnixConnector(path=""), request_class=req_factory
)

connections = []
assert session._connector is not None
Expand All @@ -569,7 +632,7 @@ async def create_connection(
"aiohttp.client.os"
) as m_os:
m_os.urandom.return_value = key_data
await session.ws_connect(f"{protocol}://example.com")
await session.ws_connect(f"{protocol}://example")

# normally called during garbage collection. triggers an exception
# if the connection wasn't already closed
Expand Down
34 changes: 34 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,11 @@ async def test_tcp_connector_ctor(loop: asyncio.AbstractEventLoop) -> None:
assert conn.family == 0


async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None:
conn = aiohttp.TCPConnector()
assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"}


async def test_invalid_ssl_param() -> None:
with pytest.raises(TypeError):
aiohttp.TCPConnector(ssl=object()) # type: ignore[arg-type]
Expand Down Expand Up @@ -1819,6 +1824,19 @@ async def test_ctor_with_default_loop(loop: asyncio.AbstractEventLoop) -> None:
assert loop is conn._loop


async def test_base_connector_allows_high_level_protocols(
loop: asyncio.AbstractEventLoop,
) -> None:
conn = aiohttp.BaseConnector()
assert conn.allowed_protocol_schema_set == {
"",
"http",
"https",
"ws",
"wss",
}


async def test_connect_with_limit(
loop: asyncio.AbstractEventLoop, key: ConnectionKey
) -> None:
Expand Down Expand Up @@ -2621,6 +2639,14 @@ async def handler(request: web.Request) -> web.Response:

connector = aiohttp.UnixConnector(unix_sockname)
assert unix_sockname == connector.path
assert connector.allowed_protocol_schema_set == {
"",
"http",
"https",
"ws",
"wss",
"unix",
}

session = ClientSession(connector=connector)
r = await session.get(url)
Expand Down Expand Up @@ -2648,6 +2674,14 @@ async def handler(request: web.Request) -> web.Response:

connector = aiohttp.NamedPipeConnector(pipe_name)
assert pipe_name == connector.path
assert connector.allowed_protocol_schema_set == {
"",
"http",
"https",
"ws",
"wss",
"npipe",
}

session = ClientSession(connector=connector)
r = await session.get(url)
Expand Down
Loading