From 01a37c603d55605e0d6f21b7d43828f4738c7a3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 21 Sep 2024 13:15:59 +0300 Subject: [PATCH] Fixed TaskGroup and CancelScope exit issues on asyncio (#774) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ganden Schaffner --- .github/workflows/test.yml | 6 +- docs/versionhistory.rst | 16 +++ pyproject.toml | 5 +- src/anyio/_backends/_asyncio.py | 220 +++++++++++++++++++------------ src/anyio/_backends/_trio.py | 21 +-- src/anyio/_core/_fileio.py | 3 +- src/anyio/_core/_signals.py | 6 +- src/anyio/_core/_streams.py | 4 +- src/anyio/_core/_subprocesses.py | 21 +-- src/anyio/abc/_eventloop.py | 8 +- src/anyio/abc/_sockets.py | 8 +- src/anyio/from_thread.py | 16 ++- src/anyio/pytest_plugin.py | 4 +- src/anyio/streams/tls.py | 6 +- tests/streams/test_stapled.py | 3 +- tests/streams/test_tls.py | 6 +- tests/test_from_thread.py | 4 +- tests/test_signals.py | 2 +- tests/test_sockets.py | 10 +- tests/test_subprocesses.py | 53 ++++---- tests/test_taskgroups.py | 211 +++++++++++++++++++++++++---- tests/test_typedattr.py | 3 +- 22 files changed, 427 insertions(+), 209 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6e3783e6..9c54c39f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -61,14 +61,14 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", pypy-3.10] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", pypy-3.10] include: - os: macos-latest - python-version: "3.8" + python-version: "3.9" - os: macos-latest python-version: "3.12" - os: windows-latest - python-version: "3.8" + python-version: "3.9" - os: windows-latest python-version: "3.12" runs-on: ${{ matrix.os }} diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 81da27ee..ef7ac30c 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -3,6 +3,22 @@ Version history This library adheres to `Semantic Versioning 2.0 `_. +**UNRELEASED** + +- Dropped support for Python 3.8 + (as `#698 `_ cannot be resolved + without cancel message support) +- Fixed 100% CPU use on asyncio while waiting for an exiting task group to finish while + said task group is within a cancelled cancel scope + (`#695 `_) +- Fixed cancel scopes on asyncio not reraising ``CancelledError`` on exit while the + enclosing cancel scope has been effectively cancelled + (`#698 `_) +- Fixed asyncio task groups not yielding control to the event loop at exit if there were + no child tasks to wait on +- Fixed inconsistent task uncancellation with asyncio cancel scopes belonging to a + task group when said task group has child tasks running + **4.5.0** - Improved the performance of ``anyio.Lock`` and ``anyio.Semaphore`` on asyncio (even up diff --git a/pyproject.toml b/pyproject.toml index 3bea40c1..4e726f4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,14 +19,13 @@ classifiers = [ "Typing :: Typed", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", ] -requires-python = ">= 3.8" +requires-python = ">= 3.9" dependencies = [ "exceptiongroup >= 1.0.2; python_version < '3.11'", "idna >= 2.8", @@ -128,7 +127,7 @@ show_missing = true [tool.tox] legacy_tox_ini = """ [tox] -envlist = pre-commit, py38, py39, py310, py311, py312, py313, pypy3 +envlist = pre-commit, py39, py310, py311, py312, py313, pypy3 skip_missing_interpreters = true minversion = 4.0.0 diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 0d4cdf65..9342fab8 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -20,9 +20,18 @@ ) from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined] from collections import OrderedDict, deque -from collections.abc import AsyncIterator, Iterable +from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Collection, + Coroutine, + Iterable, + Sequence, +) from concurrent.futures import Future -from contextlib import suppress +from contextlib import AbstractContextManager, suppress from contextvars import Context, copy_context from dataclasses import dataclass from functools import partial, wraps @@ -42,15 +51,7 @@ from typing import ( IO, Any, - AsyncGenerator, - Awaitable, - Callable, - Collection, - ContextManager, - Coroutine, Optional, - Sequence, - Tuple, TypeVar, cast, ) @@ -358,6 +359,14 @@ def _task_started(task: asyncio.Task) -> bool: # +def is_anyio_cancellation(exc: CancelledError) -> bool: + return ( + bool(exc.args) + and isinstance(exc.args[0], str) + and exc.args[0].startswith("Cancelled by cancel scope ") + ) + + class CancelScope(BaseCancelScope): def __new__( cls, *, deadline: float = math.inf, shield: bool = False @@ -444,35 +453,77 @@ def __exit__( host_task_state.cancel_scope = self._parent_scope - # Restart the cancellation effort in the closest directly cancelled parent - # scope if this one was shielded - self._restart_cancellation_in_parent() + # Undo all cancellations done by this scope + if self._cancelling is not None: + while self._cancel_calls: + self._cancel_calls -= 1 + if self._host_task.uncancel() <= self._cancelling: + break - if self._cancel_called and exc_val is not None: + # We only swallow the exception iff it was an AnyIO CancelledError, either + # directly as exc_val or inside an exception group and there are no cancelled + # parent cancel scopes visible to us here + not_swallowed_exceptions = 0 + swallow_exception = False + if exc_val is not None: for exc in iterate_exceptions(exc_val): - if isinstance(exc, CancelledError): - self._cancelled_caught = self._uncancel(exc) - if self._cancelled_caught: - break + if self._cancel_called and isinstance(exc, CancelledError): + if not (swallow_exception := self._uncancel(exc)): + not_swallowed_exceptions += 1 + else: + not_swallowed_exceptions += 1 + + # Restart the cancellation effort in the closest visible, cancelled parent + # scope if necessary + self._restart_cancellation_in_parent() + return swallow_exception and not not_swallowed_exceptions - return self._cancelled_caught + @property + def _effectively_cancelled(self) -> bool: + cancel_scope: CancelScope | None = self + while cancel_scope is not None: + if cancel_scope._cancel_called: + return True - return None + if cancel_scope.shield: + return False + + cancel_scope = cancel_scope._parent_scope + + return False + + @property + def _parent_cancellation_is_visible_to_us(self) -> bool: + return ( + self._parent_scope is not None + and not self.shield + and self._parent_scope._effectively_cancelled + ) def _uncancel(self, cancelled_exc: CancelledError) -> bool: - if sys.version_info < (3, 9) or self._host_task is None: + if self._host_task is None: self._cancel_calls = 0 return True - # Undo all cancellations done by this scope - if self._cancelling is not None: - while self._cancel_calls: - self._cancel_calls -= 1 - if self._host_task.uncancel() <= self._cancelling: - return True + while True: + if is_anyio_cancellation(cancelled_exc): + # Only swallow the cancellation exception if it's an AnyIO cancel + # exception and there are no other cancel scopes down the line pending + # cancellation + self._cancelled_caught = ( + self._effectively_cancelled + and not self._parent_cancellation_is_visible_to_us + ) + return self._cancelled_caught - self._cancel_calls = 0 - return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args + # Sometimes third party frameworks catch a CancelledError and raise a new + # one, so as a workaround we have to look at the previous ones in + # __context__ too for a matching cancel message + if isinstance(cancelled_exc.__context__, CancelledError): + cancelled_exc = cancelled_exc.__context__ + continue + + return False def _timeout(self) -> None: if self._deadline != math.inf: @@ -496,19 +547,17 @@ def _deliver_cancellation(self, origin: CancelScope) -> bool: should_retry = False current = current_task() for task in self._tasks: + should_retry = True if task._must_cancel: # type: ignore[attr-defined] continue # The task is eligible for cancellation if it has started - should_retry = True if task is not current and (task is self._host_task or _task_started(task)): waiter = task._fut_waiter # type: ignore[attr-defined] if not isinstance(waiter, asyncio.Future) or not waiter.done(): - origin._cancel_calls += 1 - if sys.version_info >= (3, 9): - task.cancel(f"Cancelled by cancel scope {id(origin):x}") - else: - task.cancel() + task.cancel(f"Cancelled by cancel scope {id(origin):x}") + if task is origin._host_task: + origin._cancel_calls += 1 # Deliver cancellation to child scopes that aren't shielded or running their own # cancellation callbacks @@ -546,17 +595,6 @@ def _restart_cancellation_in_parent(self) -> None: scope = scope._parent_scope - def _parent_cancelled(self) -> bool: - # Check whether any parent has been cancelled - cancel_scope = self._parent_scope - while cancel_scope is not None and not cancel_scope._shield: - if cancel_scope._cancel_called: - return True - else: - cancel_scope = cancel_scope._parent_scope - - return False - def cancel(self) -> None: if not self._cancel_called: if self._timeout_handle: @@ -663,38 +701,50 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: - ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb) if exc_val is not None: self.cancel_scope.cancel() if not isinstance(exc_val, CancelledError): self._exceptions.append(exc_val) - cancelled_exc_while_waiting_tasks: CancelledError | None = None - while self._tasks: - try: - await asyncio.wait(self._tasks) - except CancelledError as exc: - # This task was cancelled natively; reraise the CancelledError later - # unless this task was already interrupted by another exception - self.cancel_scope.cancel() - if cancelled_exc_while_waiting_tasks is None: - cancelled_exc_while_waiting_tasks = exc + try: + if self._tasks: + with CancelScope() as wait_scope: + while self._tasks: + try: + await asyncio.wait(self._tasks) + except CancelledError as exc: + # Shield the scope against further cancellation attempts, + # as they're not productive (#695) + wait_scope.shield = True + self.cancel_scope.cancel() + + # Set exc_val from the cancellation exception if it was + # previously unset. However, we should not replace a native + # cancellation exception with one raise by a cancel scope. + if exc_val is None or ( + isinstance(exc_val, CancelledError) + and not is_anyio_cancellation(exc) + ): + exc_val = exc + else: + # If there are no child tasks to wait on, run at least one checkpoint + # anyway + await AsyncIOBackend.cancel_shielded_checkpoint() - self._active = False - if self._exceptions: - raise BaseExceptionGroup( - "unhandled errors in a TaskGroup", self._exceptions - ) + self._active = False + if self._exceptions: + raise BaseExceptionGroup( + "unhandled errors in a TaskGroup", self._exceptions + ) + elif exc_val: + raise exc_val + except BaseException as exc: + if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__): + return True - # Raise the CancelledError received while waiting for child tasks to exit, - # unless the context manager itself was previously exited with another - # exception, or if any of the child tasks raised an exception other than - # CancelledError - if cancelled_exc_while_waiting_tasks: - if exc_val is None or ignore_exception: - raise cancelled_exc_while_waiting_tasks + raise - return ignore_exception + return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb) def _spawn( self, @@ -730,7 +780,7 @@ def task_done(_task: asyncio.Task) -> None: if not isinstance(exc, CancelledError): self._exceptions.append(exc) - if not self.cancel_scope._parent_cancelled(): + if not self.cancel_scope._effectively_cancelled: self.cancel_scope.cancel() else: task_status_future.set_exception(exc) @@ -806,7 +856,7 @@ async def start( # Threads # -_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]] +_Retval_Queue_Type = tuple[Optional[T_Retval], Optional[BaseException]] class WorkerThread(Thread): @@ -955,7 +1005,7 @@ class Process(abc.Process): _stderr: StreamReaderWrapper | None async def aclose(self) -> None: - with CancelScope(shield=True): + with CancelScope(shield=True) as scope: if self._stdin: await self._stdin.aclose() if self._stdout: @@ -963,14 +1013,14 @@ async def aclose(self) -> None: if self._stderr: await self._stderr.aclose() - try: - await self.wait() - except BaseException: - self.kill() - with CancelScope(shield=True): + scope.shield = False + try: await self.wait() - - raise + except BaseException: + scope.shield = True + self.kill() + await self.wait() + raise async def wait(self) -> int: return await self._process.wait() @@ -2015,9 +2065,7 @@ def has_pending_cancellation(self) -> bool: if task_state := _task_states.get(task): if cancel_scope := task_state.cancel_scope: - return cancel_scope.cancel_called or ( - not cancel_scope.shield and cancel_scope._parent_cancelled() - ) + return cancel_scope._effectively_cancelled return False @@ -2111,7 +2159,7 @@ async def _call_in_runner_task( ) -> T_Retval: if not self._runner_task: self._send_stream, receive_stream = create_memory_object_stream[ - Tuple[Awaitable[Any], asyncio.Future] + tuple[Awaitable[Any], asyncio.Future] ](1) self._runner_task = self.get_loop().create_task( self._run_tests_and_fixtures(receive_stream) @@ -2473,7 +2521,7 @@ async def connect_tcp( cls, host: str, port: int, local_address: IPSockAddrType | None = None ) -> abc.SocketStream: transport, protocol = cast( - Tuple[asyncio.Transport, StreamProtocol], + tuple[asyncio.Transport, StreamProtocol], await get_running_loop().create_connection( StreamProtocol, host, port, local_addr=local_address ), @@ -2652,7 +2700,7 @@ def current_default_thread_limiter(cls) -> CapacityLimiter: @classmethod def open_signal_receiver( cls, *signals: Signals - ) -> ContextManager[AsyncIterator[Signals]]: + ) -> AbstractContextManager[AsyncIterator[Signals]]: return _SignalReceiver(signals) @classmethod diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 9b8369d4..de2189ce 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -7,8 +7,18 @@ import sys import types import weakref -from collections.abc import AsyncIterator, Iterable +from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Collection, + Coroutine, + Iterable, + Sequence, +) from concurrent.futures import Future +from contextlib import AbstractContextManager from dataclasses import dataclass from functools import partial from io import IOBase @@ -19,15 +29,8 @@ from typing import ( IO, Any, - AsyncGenerator, - Awaitable, - Callable, - Collection, - ContextManager, - Coroutine, Generic, NoReturn, - Sequence, TypeVar, cast, overload, @@ -1273,7 +1276,7 @@ def current_default_thread_limiter(cls) -> CapacityLimiter: @classmethod def open_signal_receiver( cls, *signals: Signals - ) -> ContextManager[AsyncIterator[Signals]]: + ) -> AbstractContextManager[AsyncIterator[Signals]]: return _SignalReceiver(signals) @classmethod diff --git a/src/anyio/_core/_fileio.py b/src/anyio/_core/_fileio.py index 9503d944..23ccb0d6 100644 --- a/src/anyio/_core/_fileio.py +++ b/src/anyio/_core/_fileio.py @@ -3,7 +3,7 @@ import os import pathlib import sys -from collections.abc import Callable, Iterable, Iterator, Sequence +from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence from dataclasses import dataclass from functools import partial from os import PathLike @@ -12,7 +12,6 @@ TYPE_CHECKING, Any, AnyStr, - AsyncIterator, Final, Generic, overload, diff --git a/src/anyio/_core/_signals.py b/src/anyio/_core/_signals.py index 115c749b..f3451d30 100644 --- a/src/anyio/_core/_signals.py +++ b/src/anyio/_core/_signals.py @@ -1,13 +1,15 @@ from __future__ import annotations from collections.abc import AsyncIterator +from contextlib import AbstractContextManager from signal import Signals -from typing import ContextManager from ._eventloop import get_async_backend -def open_signal_receiver(*signals: Signals) -> ContextManager[AsyncIterator[Signals]]: +def open_signal_receiver( + *signals: Signals, +) -> AbstractContextManager[AsyncIterator[Signals]]: """ Start receiving operating system signals. diff --git a/src/anyio/_core/_streams.py b/src/anyio/_core/_streams.py index aa6b0c22..6a9814e5 100644 --- a/src/anyio/_core/_streams.py +++ b/src/anyio/_core/_streams.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from typing import Tuple, TypeVar +from typing import TypeVar from warnings import warn from ..streams.memory import ( @@ -14,7 +14,7 @@ class create_memory_object_stream( - Tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]], + tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]], ): """ Create a memory object stream. diff --git a/src/anyio/_core/_subprocesses.py b/src/anyio/_core/_subprocesses.py index 1ac2d549..7ba41a5b 100644 --- a/src/anyio/_core/_subprocesses.py +++ b/src/anyio/_core/_subprocesses.py @@ -160,38 +160,25 @@ async def open_process( child process prior to the execution of the subprocess. (POSIX only) :param pass_fds: sequence of file descriptors to keep open between the parent and child processes. (POSIX only) - :param user: effective user to run the process as (Python >= 3.9; POSIX only) - :param group: effective group to run the process as (Python >= 3.9; POSIX only) - :param extra_groups: supplementary groups to set in the subprocess (Python >= 3.9; - POSIX only) + :param user: effective user to run the process as (POSIX only) + :param group: effective group to run the process as (POSIX only) + :param extra_groups: supplementary groups to set in the subprocess (POSIX only) :param umask: if not negative, this umask is applied in the child process before - running the given command (Python >= 3.9; POSIX only) + running the given command (POSIX only) :return: an asynchronous process object """ kwargs: dict[str, Any] = {} if user is not None: - if sys.version_info < (3, 9): - raise TypeError("the 'user' argument requires Python 3.9 or later") - kwargs["user"] = user if group is not None: - if sys.version_info < (3, 9): - raise TypeError("the 'group' argument requires Python 3.9 or later") - kwargs["group"] = group if extra_groups is not None: - if sys.version_info < (3, 9): - raise TypeError("the 'extra_groups' argument requires Python 3.9 or later") - kwargs["extra_groups"] = group if umask >= 0: - if sys.version_info < (3, 9): - raise TypeError("the 'umask' argument requires Python 3.9 or later") - kwargs["umask"] = umask return await get_async_backend().open_process( diff --git a/src/anyio/abc/_eventloop.py b/src/anyio/abc/_eventloop.py index 2c73bb9f..93d0e9d2 100644 --- a/src/anyio/abc/_eventloop.py +++ b/src/anyio/abc/_eventloop.py @@ -3,7 +3,8 @@ import math import sys from abc import ABCMeta, abstractmethod -from collections.abc import AsyncIterator, Awaitable +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence +from contextlib import AbstractContextManager from os import PathLike from signal import Signals from socket import AddressFamily, SocketKind, socket @@ -11,9 +12,6 @@ IO, TYPE_CHECKING, Any, - Callable, - ContextManager, - Sequence, TypeVar, Union, overload, @@ -352,7 +350,7 @@ def current_default_thread_limiter(cls) -> CapacityLimiter: @abstractmethod def open_signal_receiver( cls, *signals: Signals - ) -> ContextManager[AsyncIterator[Signals]]: + ) -> AbstractContextManager[AsyncIterator[Signals]]: pass @classmethod diff --git a/src/anyio/abc/_sockets.py b/src/anyio/abc/_sockets.py index b321225a..1c6a450c 100644 --- a/src/anyio/abc/_sockets.py +++ b/src/anyio/abc/_sockets.py @@ -8,7 +8,7 @@ from ipaddress import IPv4Address, IPv6Address from socket import AddressFamily from types import TracebackType -from typing import Any, Tuple, TypeVar, Union +from typing import Any, TypeVar, Union from .._core._typedattr import ( TypedAttributeProvider, @@ -19,10 +19,10 @@ from ._tasks import TaskGroup IPAddressType = Union[str, IPv4Address, IPv6Address] -IPSockAddrType = Tuple[str, int] +IPSockAddrType = tuple[str, int] SockAddrType = Union[IPSockAddrType, str] -UDPPacketType = Tuple[bytes, IPSockAddrType] -UNIXDatagramPacketType = Tuple[bytes, str] +UDPPacketType = tuple[bytes, IPSockAddrType] +UNIXDatagramPacketType = tuple[bytes, str] T_Retval = TypeVar("T_Retval") diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index b8785845..93a4cfe8 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -3,15 +3,17 @@ import sys from collections.abc import Awaitable, Callable, Generator from concurrent.futures import Future -from contextlib import AbstractContextManager, contextmanager +from contextlib import ( + AbstractAsyncContextManager, + AbstractContextManager, + contextmanager, +) from dataclasses import dataclass, field from inspect import isawaitable from threading import Lock, Thread, get_ident from types import TracebackType from typing import ( Any, - AsyncContextManager, - ContextManager, Generic, TypeVar, cast, @@ -87,7 +89,9 @@ class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager): type[BaseException] | None, BaseException | None, TracebackType | None ] = (None, None, None) - def __init__(self, async_cm: AsyncContextManager[T_co], portal: BlockingPortal): + def __init__( + self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal + ): self._async_cm = async_cm self._portal = portal @@ -374,8 +378,8 @@ def task_done(future: Future[T_Retval]) -> None: return f, task_status_future.result() def wrap_async_context_manager( - self, cm: AsyncContextManager[T_co] - ) -> ContextManager[T_co]: + self, cm: AbstractAsyncContextManager[T_co] + ) -> AbstractContextManager[T_co]: """ Wrap an async context manager as a synchronous context manager via this portal. diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py index 558c72ec..c9fe1bde 100644 --- a/src/anyio/pytest_plugin.py +++ b/src/anyio/pytest_plugin.py @@ -4,7 +4,7 @@ from collections.abc import Iterator from contextlib import ExitStack, contextmanager from inspect import isasyncgenfunction, iscoroutinefunction -from typing import Any, Dict, Tuple, cast +from typing import Any, cast import pytest import sniffio @@ -27,7 +27,7 @@ def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]: return backend, {} elif isinstance(backend, tuple) and len(backend) == 2: if isinstance(backend[0], str) and isinstance(backend[1], dict): - return cast(Tuple[str, Dict[str, Any]], backend) + return cast(tuple[str, dict[str, Any]], backend) raise TypeError("anyio_backend must be either a string or tuple of (string, dict)") diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py index e913eedb..83240b4d 100644 --- a/src/anyio/streams/tls.py +++ b/src/anyio/streams/tls.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass from functools import wraps -from typing import Any, Tuple, TypeVar +from typing import Any, TypeVar from .. import ( BrokenResourceError, @@ -25,8 +25,8 @@ T_Retval = TypeVar("T_Retval") PosArgsT = TypeVarTuple("PosArgsT") -_PCTRTT = Tuple[Tuple[str, str], ...] -_PCTRTTT = Tuple[_PCTRTT, ...] +_PCTRTT = tuple[tuple[str, str], ...] +_PCTRTTT = tuple[_PCTRTT, ...] class TLSAttribute(TypedAttributeSet): diff --git a/tests/streams/test_stapled.py b/tests/streams/test_stapled.py index d7614314..b032e215 100644 --- a/tests/streams/test_stapled.py +++ b/tests/streams/test_stapled.py @@ -1,8 +1,9 @@ from __future__ import annotations from collections import deque +from collections.abc import Iterable from dataclasses import InitVar, dataclass, field -from typing import Iterable, TypeVar +from typing import TypeVar import pytest diff --git a/tests/streams/test_tls.py b/tests/streams/test_tls.py index 9846e0c1..90307657 100644 --- a/tests/streams/test_tls.py +++ b/tests/streams/test_tls.py @@ -2,9 +2,9 @@ import socket import ssl -from contextlib import ExitStack +from contextlib import AbstractContextManager, ExitStack from threading import Thread -from typing import ContextManager, NoReturn +from typing import NoReturn import pytest from pytest_mock import MockerFixture @@ -210,7 +210,7 @@ def serve_sync() -> None: finally: conn.close() - client_cm: ContextManager = ExitStack() + client_cm: AbstractContextManager = ExitStack() if client_compatible and not server_compatible: client_cm = pytest.raises(BrokenResourceError) diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index f69f513d..c37614e7 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -4,12 +4,12 @@ import sys import threading import time -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from concurrent import futures from concurrent.futures import CancelledError, Future from contextlib import asynccontextmanager, suppress from contextvars import ContextVar -from typing import Any, AsyncGenerator, Literal, NoReturn, TypeVar +from typing import Any, Literal, NoReturn, TypeVar import pytest import sniffio diff --git a/tests/test_signals.py b/tests/test_signals.py index 16861b82..161633d2 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -3,7 +3,7 @@ import os import signal import sys -from typing import AsyncIterable +from collections.abc import AsyncIterable import pytest diff --git a/tests/test_sockets.py b/tests/test_sockets.py index 832ae6bc..42937a36 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -10,12 +10,13 @@ import tempfile import threading import time +from collections.abc import Generator, Iterable, Iterator from contextlib import suppress from pathlib import Path from socket import AddressFamily from ssl import SSLContext, SSLError from threading import Thread -from typing import Any, Generator, Iterable, Iterator, NoReturn, TypeVar, cast +from typing import Any, NoReturn, TypeVar, cast import psutil import pytest @@ -1158,9 +1159,10 @@ async def handle(stream: SocketStream) -> None: async with stream: await stream.send(b"Hello\n") - async with await create_unix_listener( - socket_path - ) as listener, create_task_group() as tg: + async with ( + await create_unix_listener(socket_path) as listener, + create_task_group() as tg, + ): tg.start_soon(listener.serve, handle) await wait_all_tasks_blocked() diff --git a/tests/test_subprocesses.py b/tests/test_subprocesses.py index b1ff553d..adf029a3 100644 --- a/tests/test_subprocesses.py +++ b/tests/test_subprocesses.py @@ -4,7 +4,6 @@ import platform import sys from collections.abc import Callable -from contextlib import ExitStack from pathlib import Path from subprocess import CalledProcessError from textwrap import dedent @@ -135,9 +134,11 @@ async def test_run_process_connect_to_file(tmp_path: Path) -> None: stdinfile.write_text("Hello, process!\n") stdoutfile = tmp_path / "stdout" stderrfile = tmp_path / "stderr" - with stdinfile.open("rb") as fin, stdoutfile.open("wb") as fout, stderrfile.open( - "wb" - ) as ferr: + with ( + stdinfile.open("rb") as fin, + stdoutfile.open("wb") as fout, + stderrfile.open("wb") as ferr, + ): async with await open_process( [ sys.executable, @@ -271,30 +272,21 @@ async def test_py39_arguments( anyio_backend_name: str, anyio_backend_options: dict[str, Any], ) -> None: - with ExitStack() as stack: - if sys.version_info < (3, 9): - stack.enter_context( - pytest.raises( - TypeError, - match=rf"the {argname!r} argument requires Python 3.9 or later", - ) - ) - - try: - await run_process( - [sys.executable, "-c", "print('hello')"], - **{argname: argvalue_factory()}, - ) - except ValueError as exc: - if ( - "unexpected kwargs" in str(exc) - and anyio_backend_name == "asyncio" - and anyio_backend_options["loop_factory"] - and anyio_backend_options["loop_factory"].__module__ == "uvloop" - ): - pytest.skip(f"the {argname!r} argument is not supported by uvloop yet") + try: + await run_process( + [sys.executable, "-c", "print('hello')"], + **{argname: argvalue_factory()}, + ) + except ValueError as exc: + if ( + "unexpected kwargs" in str(exc) + and anyio_backend_name == "asyncio" + and anyio_backend_options["loop_factory"] + and anyio_backend_options["loop_factory"].__module__ == "uvloop" + ): + pytest.skip(f"the {argname!r} argument is not supported by uvloop yet") - raise + raise async def test_close_early() -> None: @@ -316,9 +308,10 @@ async def test_close_while_reading() -> None: time.sleep(3) """) - async with await open_process( - [sys.executable, "-c", code] - ) as process, create_task_group() as tg: + async with ( + await open_process([sys.executable, "-c", code]) as process, + create_task_group() as tg, + ): assert process.stdout tg.start_soon(process.stdout.aclose) with pytest.raises(ClosedResourceError): diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index a4603612..31490572 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -4,11 +4,13 @@ import math import sys import time +from asyncio import CancelledError from collections.abc import AsyncGenerator, Coroutine, Generator from typing import Any, NoReturn, cast import pytest from exceptiongroup import catch +from pytest_mock import MockerFixture import anyio from anyio import ( @@ -257,6 +259,36 @@ async def taskfunc() -> None: await task +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +async def test_cancel_with_nested_task_groups(mocker: MockerFixture) -> None: + """Regression test for #695.""" + + async def shield_task() -> None: + with CancelScope(shield=True) as scope: + shielded_cancel_spy = mocker.spy(scope, "_deliver_cancellation") + await sleep(0.5) + + assert len(outer_cancel_spy.call_args_list) < 10 + shielded_cancel_spy.assert_not_called() + + async def middle_task() -> None: + try: + async with create_task_group() as tg: + middle_cancel_spy = mocker.spy(tg.cancel_scope, "_deliver_cancellation") + tg.start_soon(shield_task, name="shield task") + finally: + assert len(middle_cancel_spy.call_args_list) < 10 + assert len(outer_cancel_spy.call_args_list) < 10 + + async with create_task_group() as tg: + outer_cancel_spy = mocker.spy(tg.cancel_scope, "_deliver_cancellation") + tg.start_soon(middle_task, name="middle task") + await wait_all_tasks_blocked() + tg.cancel_scope.cancel() + + assert len(outer_cancel_spy.call_args_list) < 10 + + async def test_start_exception_delivery(anyio_backend_name: str) -> None: def task_fn(*, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: task_status.started("hello") @@ -394,7 +426,7 @@ async def g() -> NoReturn: async with create_task_group(): await sleep(1) - assert False + pytest.fail("Execution should not reach this point") async with create_task_group() as tg: tg.start_soon(g) @@ -455,19 +487,6 @@ async def test_cancel_before_entering_scope() -> None: pytest.fail("execution should not reach this point") -@pytest.mark.xfail( - sys.version_info < (3, 11), reason="Requires asyncio.Task.cancelling()" -) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) -async def test_cancel_counter_nested_scopes() -> None: - with CancelScope() as root_scope: - with CancelScope(): - root_scope.cancel() - await sleep(0.5) - - assert not cast(asyncio.Task, asyncio.current_task()).cancelling() - - async def test_exception_group_children() -> None: with pytest.raises(BaseExceptionGroup) as exc: async with create_task_group() as tg: @@ -660,17 +679,92 @@ async def test_cancelled_not_caught() -> None: assert not scope.cancelled_caught +async def test_cancelled_scope_based_checkpoint() -> None: + """Regression test closely related to #698.""" + with CancelScope() as outer_scope: + outer_scope.cancel() + + # The following three lines are a way to implement a checkpoint function. + # See also https://github.com/python-trio/trio/issues/860. + with CancelScope() as inner_scope: + inner_scope.cancel() + await sleep_forever() + + pytest.fail("checkpoint should have raised") + + assert not inner_scope.cancelled_caught + assert outer_scope.cancelled_caught + + +async def test_cancelled_raises_beyond_origin_unshielded() -> None: + with CancelScope() as outer_scope: + with CancelScope() as inner_scope: + inner_scope.cancel() + try: + await checkpoint() + finally: + outer_scope.cancel() + + pytest.fail("checkpoint should have raised") + + pytest.fail("exiting the inner scope should've raised a cancellation error") + + # Here, the outer scope is responsible for the cancellation, so the inner scope + # won't catch the cancellation exception, but the outer scope will + assert not inner_scope.cancelled_caught + assert outer_scope.cancelled_caught + + +async def test_cancelled_raises_beyond_origin_shielded() -> None: + code_between_scopes_was_run = False + with CancelScope() as outer_scope: + with CancelScope(shield=True) as inner_scope: + inner_scope.cancel() + try: + await checkpoint() + finally: + outer_scope.cancel() + + pytest.fail("checkpoint should have raised") + + code_between_scopes_was_run = True + + # Here, the inner scope is the one responsible for cancellation, and given that the + # outer scope was also cancelled, it is not considered to have "caught" the + # cancellation, even though it swallows it, because the inner scope triggered it + assert code_between_scopes_was_run + assert inner_scope.cancelled_caught + assert not outer_scope.cancelled_caught + + +async def test_empty_taskgroup_contains_yield_point() -> None: + """ + Test that a task group yields at exit at least once, even with no child tasks to + wait on. + + """ + outer_task_ran = False + + async def outer_task() -> None: + nonlocal outer_task_ran + outer_task_ran = True + + async with create_task_group() as tg_outer: + for _ in range(2): # this is to make sure Trio actually schedules outer_task() + async with create_task_group(): + tg_outer.start_soon(outer_task) + + assert outer_task_ran + + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) async def test_cancel_host_asyncgen() -> None: done = False async def host_task() -> None: nonlocal done - async with create_task_group() as tg: - with CancelScope(shield=True) as inner_scope: - assert inner_scope.shield - tg.cancel_scope.cancel() - await checkpoint() + with CancelScope() as inner_scope: + inner_scope.cancel() with pytest.raises(get_cancelled_exc_class()): await checkpoint() @@ -780,12 +874,12 @@ async def child(fail: bool) -> None: async def test_cancel_cascade() -> None: async def do_something() -> NoReturn: async with create_task_group() as tg2: - tg2.start_soon(sleep, 1) + tg2.start_soon(sleep, 1, name="sleep") - raise Exception("foo") + pytest.fail("Execution should not reach this point") async with create_task_group() as tg: - tg.start_soon(do_something) + tg.start_soon(do_something, name="do_something") await wait_all_tasks_blocked() tg.cancel_scope.cancel() @@ -970,7 +1064,7 @@ async def g() -> NoReturn: tg2.start_soon(anyio.sleep, 10) await anyio.sleep(1) - assert False + pytest.fail("Execution should not have reached this line") async with anyio.create_task_group() as tg: tg.start_soon(g) @@ -1321,6 +1415,77 @@ async def test_cancel_message_replaced(self) -> None: except asyncio.CancelledError: pytest.fail("Should have swallowed the CancelledError") + async def test_cancel_counter_nested_scopes(self) -> None: + with CancelScope() as root_scope: + with CancelScope(): + root_scope.cancel() + await checkpoint() + + assert not cast(asyncio.Task, asyncio.current_task()).cancelling() + + async def test_uncancel_after_taskgroup_cancelled(self) -> None: + """ + Test that a cancel scope only uncancels the host task as many times as it has + cancelled that specific task, and won't count child task cancellations towards + that amount. + """ + + async def child_task(task_status: TaskStatus[None]) -> None: + async with create_task_group() as tg: + tg.start_soon(sleep, 3) + await wait_all_tasks_blocked() + task_status.started() + + task = asyncio.current_task() + assert task + with pytest.raises(CancelledError): + async with create_task_group() as tg: + await tg.start(child_task) + task.cancel() + + assert task.cancelling() == 1 + + async def test_uncancel_after_group_aexit_native_cancel(self) -> None: + """Closely related to #695.""" + done = anyio.Event() + + async def shield_task() -> None: + with CancelScope(shield=True): + await done.wait() + + async def middle_task() -> None: + async with create_task_group() as tg: + tg.start_soon(shield_task) + + task = asyncio.get_running_loop().create_task(middle_task()) + try: + await wait_all_tasks_blocked() + task.cancel("native 1") + await sleep(0.1) + task.cancel("native 2") + finally: + done.set() + + with pytest.raises(asyncio.CancelledError) as exc: + await task + + # Neither native cancellation should have been uncancelled, and the latest + # cancellation message should be the one coming out of the task group. + assert task.cancelling() == 2 + assert str(exc.value) == "native 2" + + async def test_uncancel_after_child_task_failed(self) -> None: + async def taskfunc() -> None: + raise Exception("dummy error") + + with pytest.raises(ExceptionGroup) as exc_info: + async with create_task_group() as tg: + tg.start_soon(taskfunc) + + assert len(exc_info.value.exceptions) == 1 + assert str(exc_info.value.exceptions[0]) == "dummy error" + assert not cast(asyncio.Task, asyncio.current_task()).cancelling() + async def test_cancel_before_entering_task_group() -> None: with CancelScope() as scope: diff --git a/tests/test_typedattr.py b/tests/test_typedattr.py index 9930996a..48e175d5 100644 --- a/tests/test_typedattr.py +++ b/tests/test_typedattr.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Callable, Mapping +from collections.abc import Mapping +from typing import Any, Callable import pytest