Skip to content

Commit

Permalink
[PR #9600/e6187f6 backport][3.11] Avoid starting connection timeout w…
Browse files Browse the repository at this point in the history
…hen a connection is already available (#9607)
  • Loading branch information
bdraco authored Nov 1, 2024
1 parent 3d01146 commit f892979
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 52 deletions.
3 changes: 3 additions & 0 deletions CHANGES/9600.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Improved performance of the connector when a connection can be reused -- by :user:`bdraco`.

If ``BaseConnector.connect`` has been subclassed and replaced with custom logic, the ``ceil_timeout`` must be added.
11 changes: 3 additions & 8 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@
DEBUG,
BasicAuth,
TimeoutHandle,
ceil_timeout,
get_env_proxy_for_url,
method_must_be_empty_body,
sentinel,
Expand Down Expand Up @@ -692,13 +691,9 @@ async def _request(

# connection timeout
try:
async with ceil_timeout(
real_timeout.connect,
ceil_threshold=real_timeout.ceil_threshold,
):
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
)
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
)
except asyncio.TimeoutError as exc:
raise ConnectionTimeoutError(
f"Connection timeout to host {url}"
Expand Down
109 changes: 65 additions & 44 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,41 +512,20 @@ async def connect(
"""Get from pool or create new connection."""
key = req.connection_key
available = self._available_connections(key)
wait_for_conn = available <= 0 or key in self._waiters
if not wait_for_conn and (proto := self._get(key)) is not None:
# If we do not have to wait and we can get a connection from the pool
# we can avoid the timeout ceil logic and directly return the connection
return await self._reused_connection(key, proto, traces)

async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
# Wait if there are no available connections or if there are/were
# waiters (i.e. don't steal connection from a waiter about to wake up)
if wait_for_conn:
await self._wait_for_available_connection(key, traces)
if (proto := self._get(key)) is not None:
return await self._reused_connection(key, proto, traces)

# Wait if there are no available connections or if there are/were
# waiters (i.e. don't steal connection from a waiter about to wake up)
if available <= 0 or key in self._waiters:
fut: asyncio.Future[None] = self._loop.create_future()

# This connection will now count towards the limit.
self._waiters[key].append(fut)

if traces:
for trace in traces:
await trace.send_connection_queued_start()

try:
await fut
except BaseException as e:
if key in self._waiters:
# remove a waiter even if it was cancelled, normally it's
# removed when it's notified
try:
self._waiters[key].remove(fut)
except ValueError: # fut may no longer be in list
pass

raise e
finally:
if key in self._waiters and not self._waiters[key]:
del self._waiters[key]

if traces:
for trace in traces:
await trace.send_connection_queued_end()

proto = self._get(key)
if proto is None:
placeholder = cast(ResponseHandler, _TransportPlaceholder())
self._acquired.add(placeholder)
self._acquired_per_host[key].add(placeholder)
Expand Down Expand Up @@ -574,21 +553,63 @@ async def connect(
if traces:
for trace in traces:
await trace.send_connection_create_end()
else:
if traces:
# Acquire the connection to prevent race conditions with limits
placeholder = cast(ResponseHandler, _TransportPlaceholder())
self._acquired.add(placeholder)
self._acquired_per_host[key].add(placeholder)
for trace in traces:
await trace.send_connection_reuseconn()
self._acquired.remove(placeholder)
self._drop_acquired_per_host(key, placeholder)

return self._acquired_connection(proto, key)

async def _reused_connection(
self, key: "ConnectionKey", proto: ResponseHandler, traces: List["Trace"]
) -> Connection:
if traces:
# Acquire the connection to prevent race conditions with limits
placeholder = cast(ResponseHandler, _TransportPlaceholder())
self._acquired.add(placeholder)
self._acquired_per_host[key].add(placeholder)
for trace in traces:
await trace.send_connection_reuseconn()
self._acquired.remove(placeholder)
self._drop_acquired_per_host(key, placeholder)
return self._acquired_connection(proto, key)

def _acquired_connection(
self, proto: ResponseHandler, key: "ConnectionKey"
) -> Connection:
"""Mark proto as acquired and wrap it in a Connection object."""
self._acquired.add(proto)
self._acquired_per_host[key].add(proto)
return Connection(self, key, proto, self._loop)

async def _wait_for_available_connection(
self, key: "ConnectionKey", traces: List["Trace"]
) -> None:
"""Wait until there is an available connection."""
fut: asyncio.Future[None] = self._loop.create_future()

# This connection will now count towards the limit.
self._waiters[key].append(fut)

if traces:
for trace in traces:
await trace.send_connection_queued_start()

try:
await fut
except BaseException as e:
if key in self._waiters:
# remove a waiter even if it was cancelled, normally it's
# removed when it's notified
with suppress(ValueError):
# fut may no longer be in list
self._waiters[key].remove(fut)

raise e
finally:
if key in self._waiters and not self._waiters[key]:
del self._waiters[key]

if traces:
for trace in traces:
await trace.send_connection_queued_end()

def _get(self, key: "ConnectionKey") -> Optional[ResponseHandler]:
try:
conns = self._conns[key]
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ ssl
SSLContext
startup
subapplication
subclassed
subclasses
subdirectory
submodules
Expand Down

0 comments on commit f892979

Please sign in to comment.