diff --git a/CHANGES/8641.bugfix.rst b/CHANGES/8641.bugfix.rst new file mode 100644 index 00000000000..9c85ac04419 --- /dev/null +++ b/CHANGES/8641.bugfix.rst @@ -0,0 +1,3 @@ +Fixed WebSocket ping tasks being prematurely garbage collected -- by :user:`bdraco`. + +There was a small risk that WebSocket ping tasks would be prematurely garbage collected because the event loop only holds a weak reference to the task. The garbage collection risk has been fixed by holding a strong reference to the task. Additionally, the task is now scheduled eagerly with Python 3.12+ to increase the chance it can be completed immediately and avoid having to hold any references to the task. diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 516ad586f70..247f62c758e 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -72,6 +72,7 @@ def __init__( self._exception: Optional[BaseException] = None self._compress = compress self._client_notakeover = client_notakeover + self._ping_task: Optional[asyncio.Task[None]] = None self._reset_heartbeat() @@ -80,6 +81,9 @@ def _cancel_heartbeat(self) -> None: if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None + if self._ping_task is not None: + self._ping_task.cancel() + self._ping_task = None def _cancel_pong_response_cb(self) -> None: if self._pong_response_cb is not None: @@ -118,11 +122,6 @@ def _send_heartbeat(self) -> None: ) return - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. - loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable] - conn = self._conn timeout_ceil_threshold = ( conn._connector._timeout_ceil_threshold if conn is not None else 5 @@ -131,6 +130,22 @@ def _send_heartbeat(self) -> None: self._cancel_pong_response_cb() self._pong_response_cb = loop.call_at(when, self._pong_not_received) + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to send the ping + # immediately to avoid having to schedule + # the task on the event loop. + ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True) + else: + ping_task = loop.create_task(self._writer.ping()) + + if not ping_task.done(): + self._ping_task = ping_task + ping_task.add_done_callback(self._ping_task_done) + + def _ping_task_done(self, task: "asyncio.Task[None]") -> None: + """Callback for when the ping task completes.""" + self._ping_task = None + def _pong_not_received(self) -> None: if not self._closed: self._set_closed() diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 9f71d147997..ba3332715a6 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -96,12 +96,16 @@ def __init__( self._pong_response_cb: Optional[asyncio.TimerHandle] = None self._compress = compress self._max_msg_size = max_msg_size + self._ping_task: Optional[asyncio.Task[None]] = None def _cancel_heartbeat(self) -> None: self._cancel_pong_response_cb() if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None + if self._ping_task is not None: + self._ping_task.cancel() + self._ping_task = None def _cancel_pong_response_cb(self) -> None: if self._pong_response_cb is not None: @@ -141,11 +145,6 @@ def _send_heartbeat(self) -> None: ) return - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. - loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable] - req = self._req timeout_ceil_threshold = ( req._protocol._timeout_ceil_threshold if req is not None else 5 @@ -154,6 +153,22 @@ def _send_heartbeat(self) -> None: self._cancel_pong_response_cb() self._pong_response_cb = loop.call_at(when, self._pong_not_received) + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to send the ping + # immediately to avoid having to schedule + # the task on the event loop. + ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True) + else: + ping_task = loop.create_task(self._writer.ping()) + + if not ping_task.done(): + self._ping_task = ping_task + ping_task.add_done_callback(self._ping_task_done) + + def _ping_task_done(self, task: "asyncio.Task[None]") -> None: + """Callback for when the ping task completes.""" + self._ping_task = None + def _pong_not_received(self) -> None: if self._req is not None and self._req.transport is not None: self._set_closed() diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 5abaf0fefbf..907ae232e9a 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,6 +1,7 @@ import asyncio import sys from typing import Any, NoReturn +from unittest import mock import pytest @@ -727,8 +728,53 @@ async def handler(request): assert isinstance(msg.data, ServerTimeoutError) -async def test_send_recv_compress(aiohttp_client: Any) -> None: - async def handler(request): +async def test_close_websocket_while_ping_inflight( + aiohttp_client: AiohttpClient, +) -> None: + """Test closing the websocket while a ping is in-flight.""" + ping_received = False + + async def handler(request: web.Request) -> NoReturn: + nonlocal ping_received + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.BINARY + msg = await ws.receive() + ping_received = msg.type is aiohttp.WSMsgType.PING + await ws.receive() + assert False + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + await resp.send_bytes(b"ask") + + cancelled = False + ping_stated = False + + async def delayed_ping() -> None: + nonlocal cancelled, ping_stated + ping_stated = True + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + cancelled = True + raise + + with mock.patch.object(resp._writer, "ping", delayed_ping): + await asyncio.sleep(0.1) + + await resp.close() + await asyncio.sleep(0) + assert ping_stated is True + assert cancelled is True + + +async def test_send_recv_compress(aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request)