From 7f714304099f1c74b126c7e986bc6ee7bd388239 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 20 Sep 2024 10:59:07 -0400 Subject: [PATCH] Add async support for cancellation --- pymongo/network_layer.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index eb7e8cd4f0..46805ad1cb 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -139,6 +139,8 @@ def _is_ready(fut: Future) -> None: while read < length: try: read += conn.recv_into(mv[read:]) + if read == 0: + raise OSError("connection closed") except BLOCKING_IO_ERRORS as exc: fd = conn.fileno() # Check for closed socket. @@ -195,11 +197,20 @@ def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: sock.sendall(buf) +async def _poll_cancellation(conn: AsyncConnection) -> None: + while True: + if conn.cancel_context.cancelled: + return + + await asyncio.sleep(_POLL_TIMEOUT) + + async def async_receive_data( conn: AsyncConnection, length: int, deadline: Optional[float] ) -> memoryview: sock = conn.conn sock_timeout = sock.gettimeout() + timeout: Optional[Union[float, int]] if deadline: # When the timeout has expired perform one final check to # see if the socket is readable. This helps avoid spurious @@ -210,14 +221,22 @@ async def async_receive_data( sock.settimeout(0.0) loop = asyncio.get_event_loop() + cancellation_task = asyncio.create_task(_poll_cancellation(conn)) try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - return await asyncio.wait_for(_async_receive_ssl(sock, length, loop), timeout=timeout) + read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] else: - return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] - except asyncio.TimeoutError as exc: - # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands. - raise socket.timeout("timed out") from exc + read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] + tasks = [read_task, cancellation_task] + result = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED) + if len(result[1]) == 2: + raise socket.timeout("timed out") + finished = next(iter(result[0])) + next(iter(result[1])).cancel() + if finished == read_task: + return finished.result() # type: ignore[return-value] + else: + raise _OperationCancelled("operation cancelled") finally: sock.settimeout(sock_timeout)