diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index c9eb2b44..1485a7d3 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -18,6 +18,12 @@ This library adheres to `Semantic Versioning 2.0 `_. - Fixed the return type annotations of ``readinto()`` and ``readinto1()`` methods in the ``anyio.AsyncFile`` class (`#825 `_) +- Fixed ``TaskInfo.has_pending_cancellation()`` on asyncio returning false positives in + cleanup code on Python >= 3.11 + (`#832 `_; PR by @gschaffner) +- Fixed cancelled cancel scopes on asyncio calling ``asyncio.Task.uncancel`` when + propagating a ``CancelledError`` on exit to a cancelled parent scope + (`#790 `_; PR by @gschaffner) **4.6.2** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index c1fd0d1e..0b7479d2 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -372,11 +372,22 @@ def _task_started(task: asyncio.Task) -> bool: def is_anyio_cancellation(exc: CancelledError) -> bool: - return ( - bool(exc.args) - and isinstance(exc.args[0], str) - and exc.args[0].startswith("Cancelled by cancel scope ") - ) + # Sometimes third party frameworks catch a CancelledError and raise a new one, so as + # a workaround we have to look at the previous ones in __context__ too for a + # matching cancel message + while True: + if ( + exc.args + and isinstance(exc.args[0], str) + and exc.args[0].startswith("Cancelled by cancel scope ") + ): + return True + + if isinstance(exc.__context__, CancelledError): + exc = exc.__context__ + continue + + return False class CancelScope(BaseCancelScope): @@ -397,8 +408,10 @@ def __init__(self, deadline: float = math.inf, shield: bool = False): self._cancel_handle: asyncio.Handle | None = None self._tasks: set[asyncio.Task] = set() self._host_task: asyncio.Task | None = None - self._cancel_calls: int = 0 - self._cancelling: int | None = None + if sys.version_info >= (3, 11): + self._pending_uncancellations: int | None = 0 + else: + self._pending_uncancellations = None def __enter__(self) -> CancelScope: if self._active: @@ -424,8 +437,6 @@ def __enter__(self) -> CancelScope: self._timeout() self._active = True - if sys.version_info >= (3, 11): - self._cancelling = self._host_task.cancelling() # Start cancelling the host task if the scope was cancelled before entering if self._cancel_called: @@ -470,30 +481,41 @@ def __exit__( host_task_state.cancel_scope = self._parent_scope - # Undo all cancellations done by this scope - if self._cancelling is not None: - while self._cancel_calls: - self._cancel_calls -= 1 - if self._host_task.uncancel() <= self._cancelling: - break + # Restart the cancellation effort in the closest visible, cancelled parent + # scope if necessary + self._restart_cancellation_in_parent() # We only swallow the exception iff it was an AnyIO CancelledError, either # directly as exc_val or inside an exception group and there are no cancelled # parent cancel scopes visible to us here - not_swallowed_exceptions = 0 - swallow_exception = False - if exc_val is not None: - for exc in iterate_exceptions(exc_val): - if self._cancel_called and isinstance(exc, CancelledError): - if not (swallow_exception := self._uncancel(exc)): - not_swallowed_exceptions += 1 - else: - not_swallowed_exceptions += 1 + if self._cancel_called and not self._parent_cancellation_is_visible_to_us: + # For each level-cancel() call made on the host task, call uncancel() + while self._pending_uncancellations: + self._host_task.uncancel() + self._pending_uncancellations -= 1 + + # Update cancelled_caught and check for exceptions we must not swallow + cannot_swallow_exc_val = False + if exc_val is not None: + for exc in iterate_exceptions(exc_val): + if isinstance(exc, CancelledError) and is_anyio_cancellation( + exc + ): + self._cancelled_caught = True + else: + cannot_swallow_exc_val = True - # Restart the cancellation effort in the closest visible, cancelled parent - # scope if necessary - self._restart_cancellation_in_parent() - return swallow_exception and not not_swallowed_exceptions + return self._cancelled_caught and not cannot_swallow_exc_val + else: + if self._pending_uncancellations: + assert self._parent_scope is not None + assert self._parent_scope._pending_uncancellations is not None + self._parent_scope._pending_uncancellations += ( + self._pending_uncancellations + ) + self._pending_uncancellations = 0 + + return False finally: self._host_task = None del exc_val @@ -520,31 +542,6 @@ def _parent_cancellation_is_visible_to_us(self) -> bool: and self._parent_scope._effectively_cancelled ) - def _uncancel(self, cancelled_exc: CancelledError) -> bool: - if self._host_task is None: - self._cancel_calls = 0 - return True - - while True: - if is_anyio_cancellation(cancelled_exc): - # Only swallow the cancellation exception if it's an AnyIO cancel - # exception and there are no other cancel scopes down the line pending - # cancellation - self._cancelled_caught = ( - self._effectively_cancelled - and not self._parent_cancellation_is_visible_to_us - ) - return self._cancelled_caught - - # Sometimes third party frameworks catch a CancelledError and raise a new - # one, so as a workaround we have to look at the previous ones in - # __context__ too for a matching cancel message - if isinstance(cancelled_exc.__context__, CancelledError): - cancelled_exc = cancelled_exc.__context__ - continue - - return False - def _timeout(self) -> None: if self._deadline != math.inf: loop = get_running_loop() @@ -576,8 +573,11 @@ def _deliver_cancellation(self, origin: CancelScope) -> bool: waiter = task._fut_waiter # type: ignore[attr-defined] if not isinstance(waiter, asyncio.Future) or not waiter.done(): task.cancel(f"Cancelled by cancel scope {id(origin):x}") - if task is origin._host_task: - origin._cancel_calls += 1 + if ( + task is origin._host_task + and origin._pending_uncancellations is not None + ): + origin._pending_uncancellations += 1 # Deliver cancellation to child scopes that aren't shielded or running their own # cancellation callbacks @@ -2154,12 +2154,11 @@ def has_pending_cancellation(self) -> bool: # If the task isn't around anymore, it won't have a pending cancellation return False - if sys.version_info >= (3, 11): - if task.cancelling(): - return True + if task._must_cancel: # type: ignore[attr-defined] + return True elif ( - isinstance(task._fut_waiter, asyncio.Future) - and task._fut_waiter.cancelled() + isinstance(task._fut_waiter, asyncio.Future) # type: ignore[attr-defined] + and task._fut_waiter.cancelled() # type: ignore[attr-defined] ): return True diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 84101e47..1f536940 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -673,6 +673,38 @@ async def test_cancel_shielded_scope() -> None: await checkpoint() +async def test_shielded_cleanup_after_cancel() -> None: + """Regression test for #832.""" + with CancelScope() as outer_scope: + outer_scope.cancel() + try: + await checkpoint() + finally: + assert current_effective_deadline() == -math.inf + assert get_current_task().has_pending_cancellation() + + with CancelScope(shield=True): # noqa: ASYNC100 + assert current_effective_deadline() == math.inf + assert not get_current_task().has_pending_cancellation() + + assert current_effective_deadline() == -math.inf + assert get_current_task().has_pending_cancellation() + + +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +async def test_cleanup_after_native_cancel() -> None: + """Regression test for #832.""" + # See also https://github.com/python/cpython/pull/102815. + task = asyncio.current_task() + assert task + task.cancel() + with pytest.raises(asyncio.CancelledError): + try: + await checkpoint() + finally: + assert not get_current_task().has_pending_cancellation() + + async def test_cancelled_not_caught() -> None: with CancelScope() as scope: # noqa: ASYNC100 scope.cancel() @@ -1488,6 +1520,26 @@ async def taskfunc() -> None: assert str(exc_info.value.exceptions[0]) == "dummy error" assert not cast(asyncio.Task, asyncio.current_task()).cancelling() + async def test_uncancel_cancelled_scope_based_checkpoint(self) -> None: + """See also test_cancelled_scope_based_checkpoint.""" + task = asyncio.current_task() + assert task + + with CancelScope() as outer_scope: + outer_scope.cancel() + + try: + # The following three lines are a way to implement a checkpoint + # function. See also https://github.com/python-trio/trio/issues/860. + with CancelScope() as inner_scope: + inner_scope.cancel() + await sleep_forever() + finally: + assert isinstance(sys.exc_info()[1], asyncio.CancelledError) + assert task.cancelling() + + assert not task.cancelling() + async def test_cancel_before_entering_task_group() -> None: with CancelScope() as scope: