From 04f9ce61b282e5de69aaa3ab5555825c348f64f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 6 Nov 2022 23:08:55 +0200 Subject: [PATCH 01/11] Improved asyncio cancellation semantics The semantics now better match with trio's. --- docs/versionhistory.rst | 2 ++ src/anyio/_backends/_asyncio.py | 30 ++++++++++++------------- src/anyio/from_thread.py | 13 +++++++---- tests/test_taskgroups.py | 39 +++++++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 20 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index dccd1b5b..6088aba8 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -42,6 +42,8 @@ This library adheres to `Semantic Versioning 2.0 `_. the event loop to be closed - Fixed ``current_effective_deadline()`` not returning ``-inf`` on asyncio when the currently active cancel scope has been cancelled (PR by Ganden Schaffner) +- Fixed task group not raising a cancellation exception on asyncio at exit if no child + tasks were spawned and an outer cancellation scope had been cancelled before **3.6.1** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 1c085277..b48350be 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -42,7 +42,6 @@ Coroutine, Deque, Generator, - Iterator, Mapping, Optional, Sequence, @@ -456,14 +455,6 @@ def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException: return excgroup -def walk_exception_group(excgroup: BaseExceptionGroup) -> Iterator[BaseException]: - for exc in excgroup.exceptions: - if isinstance(exc, BaseExceptionGroup): - yield from walk_exception_group(exc) - else: - yield exc - - def is_anyio_cancelled_exc(exc: BaseException) -> bool: return isinstance(exc, CancelledError) and not exc.args @@ -490,11 +481,21 @@ async def __aexit__( self.cancel_scope.cancel() self._exceptions.append(exc_val) - while self.cancel_scope._tasks: + if self.cancel_scope._tasks: + while self.cancel_scope._tasks: + try: + await asyncio.wait(self.cancel_scope._tasks) + except asyncio.CancelledError as e: + if exc_val is None: + self.cancel_scope.cancel() + self._exceptions.append(e) + exc_val = e + else: + # No tasks to wait on, but we still need to check for cancellation here try: - await asyncio.wait(self.cancel_scope._tasks) - except asyncio.CancelledError: - self.cancel_scope.cancel() + await AsyncIOBackend.checkpoint() + except CancelledError as e: + self._exceptions.append(e) self._active = False if self._exceptions: @@ -504,9 +505,6 @@ async def __aexit__( # If any exceptions other than AnyIO cancellation exceptions have been # received, raise those _, exc = group.split(is_anyio_cancelled_exc) - elif all(is_anyio_cancelled_exc(e) for e in walk_exception_group(group)): - # All tasks were cancelled by AnyIO - exc = CancelledError() else: exc = group diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index f5ddc2e3..da8923e8 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -148,6 +148,10 @@ async def __aexit__( await self.stop() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + @property + def _running(self) -> bool: + return self._event_loop_thread_id is not None + def _check_running(self) -> None: if self._event_loop_thread_id is None: raise RuntimeError("This portal is not running") @@ -202,8 +206,11 @@ def callback(f: Future) -> None: if not future.cancelled(): future.set_exception(exc) - # Let base exceptions fall through + # Let base exceptions fall through, but mark the portal as not running, so + # start_blocking_portal() won't try to stop it since BaseException will + # cause that anyway if not isinstance(exc, Exception): + self._event_loop_thread_id = None raise else: if not future.cancelled(): @@ -413,9 +420,7 @@ async def run_portal() -> None: cancel_remaining_tasks = True raise finally: - try: + if not run_future.done() and portal._running: portal.call(portal.stop, cancel_remaining_tasks) - except RuntimeError: - pass run_future.result() diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index adf2e79d..d5213bcb 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -396,6 +396,45 @@ async def test_cancel_before_entering_scope() -> None: pytest.fail("execution should not reach this point") +async def test_cancel_outer_scope_no_tasks() -> None: + """ + Test that a task group raises an exception group containing one cancellation error + from __aexit__() if the outer cancel scope was cancelled. + + """ + with CancelScope() as outer_scope: + try: + async with anyio.create_task_group(): + outer_scope.cancel() + except BaseException as exc: + if not isinstance(exc, get_cancelled_exc_class()): + pytest.fail("should have raised a cancellation exception") + + raise + else: + pytest.fail("should have raised an exception") + + +async def test_cancel_outer_scope_one_task() -> None: + """ + Test that a task group propagates a cancellation error (wrapped in an exception + group) from __aexit__() that was not intended for the task group's cancel scope. + + """ + with CancelScope() as outer_scope: + try: + async with anyio.create_task_group() as tg: + tg.start_soon(sleep, 3) + outer_scope.cancel() + except BaseExceptionGroup as excgrp: + assert len(excgrp.exceptions) == 2 + raise + except get_cancelled_exc_class(): + pytest.fail("task group raised a plain cancellation error") + else: + pytest.fail("should have raised an exception group") + + async def test_exception_group_children() -> None: with pytest.raises(BaseExceptionGroup) as exc: async with create_task_group() as tg: From 91a71641d3b1bfa416b50814b0834ef12ee6143e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 6 Nov 2022 23:39:39 +0200 Subject: [PATCH 02/11] Properly detect cancellation in asyncio CancelScope --- src/anyio/_backends/_asyncio.py | 15 ++++++++++----- tests/test_taskgroups.py | 27 +++++++++++++++------------ 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index b48350be..2b21023b 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -77,7 +77,7 @@ from ..lowlevel import RunVar if sys.version_info < (3, 11): - from exceptiongroup import BaseExceptionGroup, ExceptionGroup + from exceptiongroup import BaseExceptionGroup if sys.version_info >= (3, 8): @@ -263,10 +263,15 @@ def __exit__( self._deliver_cancellation_to_parent() if exc_val is not None: - exceptions = ( - exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val] - ) - if all(isinstance(exc, CancelledError) for exc in exceptions): + cancelled = False + if isinstance(exc_val, CancelledError): + cancelled = True + elif isinstance(exc_val, BaseExceptionGroup): + matched, unmatched = exc_val.split(CancelledError) + if matched and not unmatched: + cancelled = True + + if cancelled: if self._timeout_expired: return True elif not self._cancel_called: diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index d5213bcb..d648a2d3 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -421,18 +421,21 @@ async def test_cancel_outer_scope_one_task() -> None: group) from __aexit__() that was not intended for the task group's cancel scope. """ - with CancelScope() as outer_scope: - try: - async with anyio.create_task_group() as tg: - tg.start_soon(sleep, 3) - outer_scope.cancel() - except BaseExceptionGroup as excgrp: - assert len(excgrp.exceptions) == 2 - raise - except get_cancelled_exc_class(): - pytest.fail("task group raised a plain cancellation error") - else: - pytest.fail("should have raised an exception group") + try: + with CancelScope() as outer_scope: + try: + async with anyio.create_task_group() as tg: + tg.start_soon(sleep, 3) + outer_scope.cancel() + except BaseExceptionGroup as excgrp: + assert len(excgrp.exceptions) == 2 + raise + except get_cancelled_exc_class(): + pytest.fail("task group raised a plain cancellation error") + else: + pytest.fail("should have raised an exception group") + except BaseException: + pytest.fail("the cancel scope should have swallowed the exceptions") async def test_exception_group_children() -> None: From 4e53d646ca361f363f8644cfac8bbc5b0255143a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Mon, 7 Nov 2022 11:20:55 +0200 Subject: [PATCH 03/11] WIP cancellation semantics change --- docs/versionhistory.rst | 6 ++++ src/anyio/_backends/_asyncio.py | 58 +++++++-------------------------- src/anyio/_backends/_trio.py | 12 +++++-- tests/streams/test_memory.py | 10 ++++-- tests/test_from_thread.py | 5 ++- tests/test_sockets.py | 5 ++- tests/test_taskgroups.py | 22 ++++++------- 7 files changed, 54 insertions(+), 64 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 6088aba8..90af4b89 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -7,6 +7,12 @@ This library adheres to `Semantic Versioning 2.0 `_. - **BACKWARDS INCOMPATIBLE** Replaced AnyIO's own ``ExceptionGroup`` class with the PEP 654 ``BaseExceptionGroup`` and ``ExceptionGroup`` +- **BACKWARDS INCOMPATIBLE** Changes to cancellation semantics: + + - Any exceptions raising out of a task groups are now nested inside an + ``ExceptionGroup`` (or ``BaseExceptionGroup`` if one or more ``BaseException`` were + included), except when all the exceptions are cancellation exceptions. In that case, + a single cancellation exception is raised instead. - Bumped minimum version of trio to v0.22 - Added ``create_unix_datagram_socket`` and ``create_connected_unix_datagram_socket`` to create UNIX datagram sockets (PR by Jean Hominal) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 2b21023b..6637401b 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -262,24 +262,14 @@ def __exit__( if self._shield: self._deliver_cancellation_to_parent() - if exc_val is not None: - cancelled = False - if isinstance(exc_val, CancelledError): - cancelled = True - elif isinstance(exc_val, BaseExceptionGroup): - matched, unmatched = exc_val.split(CancelledError) - if matched and not unmatched: - cancelled = True - - if cancelled: + if exc_val is not None and isinstance(exc_val, CancelledError): + breakpoint() + if not self._parent_cancelled(): if self._timeout_expired: return True elif not self._cancel_called: # Task was cancelled natively return None - elif not self._parent_cancelled(): - # This scope was directly cancelled - return True return None @@ -442,24 +432,6 @@ def started(self, value: object = None) -> None: _task_states[task].parent_id = self._parent_id -def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException: - exceptions = list(excgroup.exceptions) - modified = False - for i, exc in enumerate(exceptions): - if isinstance(exc, BaseExceptionGroup): - new_exc = collapse_exception_group(exc) - if new_exc is not exc: - modified = True - exceptions[i] = new_exc - - if len(exceptions) == 1: - return exceptions[0] - elif modified: - return excgroup.derive(exceptions) - else: - return excgroup - - def is_anyio_cancelled_exc(exc: BaseException) -> bool: return isinstance(exc, CancelledError) and not exc.args @@ -491,10 +463,8 @@ async def __aexit__( try: await asyncio.wait(self.cancel_scope._tasks) except asyncio.CancelledError as e: - if exc_val is None: - self.cancel_scope.cancel() - self._exceptions.append(e) - exc_val = e + self.cancel_scope.cancel() + self._exceptions.append(e) else: # No tasks to wait on, but we still need to check for cancellation here try: @@ -504,20 +474,14 @@ async def __aexit__( self._active = False if self._exceptions: + # if self._exceptions and (len(self._exceptions) != 1 or self._exceptions[0] is exc_val): exc: BaseException | None group = BaseExceptionGroup("multiple tasks failed", self._exceptions) - if not self.cancel_scope._parent_cancelled(): - # If any exceptions other than AnyIO cancellation exceptions have been - # received, raise those - _, exc = group.split(is_anyio_cancelled_exc) - else: - exc = group - - if isinstance(exc, BaseExceptionGroup): - exc = collapse_exception_group(exc) - - if exc is not None and exc is not exc_val: - raise exc + matched, unmatched = group.split(CancelledError) + if unmatched: + raise unmatched + elif self._exceptions[0] is not exc_val or not ignore_exception: + raise CancelledError from None return ignore_exception diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index f7880b36..67ceec8b 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -134,7 +134,7 @@ def shield(self, value: bool) -> None: class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self._active = False - self._nursery_manager = trio.open_nursery() + self._nursery_manager = trio.open_nursery(strict_exception_groups=True) self.cancel_scope = None # type: ignore[assignment] async def __aenter__(self) -> TaskGroup: @@ -151,6 +151,14 @@ async def __aexit__( ) -> bool | None: try: return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) + except BaseExceptionGroup as excgrp: + matched, unmatched = excgrp.split(trio.Cancelled) + if not unmatched: + raise trio.Cancelled._create() from None + elif unmatched: + raise unmatched from None + else: + raise finally: self._active = False @@ -719,7 +727,7 @@ def __init__(self, **options: Any) -> None: async def _trio_main(self) -> None: self._stop_event = trio.Event() - async with trio.open_nursery() as self._nursery: + async with trio.open_nursery(strict_exception_groups=True) as self._nursery: await self._stop_event.wait() async def _call_func( diff --git a/tests/streams/test_memory.py b/tests/streams/test_memory.py index f322d9c4..b9d39320 100644 --- a/tests/streams/test_memory.py +++ b/tests/streams/test_memory.py @@ -170,21 +170,27 @@ async def test_clone_closed() -> None: async def test_close_send_while_receiving() -> None: send, receive = create_memory_object_stream(1) - with pytest.raises(EndOfStream): + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: tg.start_soon(receive.receive) await wait_all_tasks_blocked() await send.aclose() + matched, unmatched = exc.value.split(EndOfStream) + assert not unmatched + async def test_close_receive_while_sending() -> None: send, receive = create_memory_object_stream(0) - with pytest.raises(BrokenResourceError): + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: tg.start_soon(send.send, "hello") await wait_all_tasks_blocked() await receive.aclose() + matched, unmatched = exc.value.split(BrokenResourceError) + assert not unmatched + async def test_receive_after_send_closed() -> None: send, receive = create_memory_object_stream(1) diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index 93886f53..21a20220 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -390,10 +390,13 @@ async def run_in_context() -> AsyncGenerator[None, None]: yield with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - with pytest.raises(ZeroDivisionError): + with pytest.raises(ExceptionGroup) as exc: with portal.wrap_async_context_manager(run_in_context()): pass + _, unmatched = exc.value.split(ZeroDivisionError) + assert not unmatched + def test_start_no_value( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: diff --git a/tests/test_sockets.py b/tests/test_sockets.py index acffe920..ded3dc29 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -657,13 +657,16 @@ async def test_reuse_port(self, family: AnyIPAddressFamily) -> None: async def test_close_from_other_task(self, family: AnyIPAddressFamily) -> None: listener = await create_tcp_listener(local_host="localhost", family=family) - with pytest.raises(ClosedResourceError): + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: tg.start_soon(listener.serve, lambda stream: None) await wait_all_tasks_blocked() await listener.aclose() tg.cancel_scope.cancel() + _, unmatched = exc.value.split(ClosedResourceError) + assert not unmatched + async def test_send_after_eof(self, family: AnyIPAddressFamily) -> None: async def handle(stream: SocketStream) -> None: async with stream: diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index d648a2d3..05f43912 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -27,7 +27,7 @@ from anyio.lowlevel import checkpoint if sys.version_info < (3, 11): - from exceptiongroup import BaseExceptionGroup + from exceptiongroup import BaseExceptionGroup, ExceptionGroup pytestmark = pytest.mark.anyio @@ -89,7 +89,7 @@ async def task_func() -> None: async def test_start_soon_after_error() -> None: - with pytest.raises(ZeroDivisionError): + with pytest.raises(ExceptionGroup): async with create_task_group() as tg: a = 1 / 0 # noqa: F841 @@ -147,11 +147,11 @@ async def taskfunc(*, task_status: TaskStatus) -> NoReturn: task_status.started(2) raise Exception("foo") - with pytest.raises(Exception) as exc: + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: value = await tg.start(taskfunc) - exc.match("foo") + assert str(exc.value.exceptions[0]) == "foo" assert value == 2 @@ -427,12 +427,9 @@ async def test_cancel_outer_scope_one_task() -> None: async with anyio.create_task_group() as tg: tg.start_soon(sleep, 3) outer_scope.cancel() - except BaseExceptionGroup as excgrp: - assert len(excgrp.exceptions) == 2 - raise except get_cancelled_exc_class(): - pytest.fail("task group raised a plain cancellation error") - else: + pass + except BaseException: pytest.fail("should have raised an exception group") except BaseException: pytest.fail("the cancel scope should have swallowed the exceptions") @@ -787,7 +784,7 @@ async def killer(scope: CancelScope) -> None: await wait_all_tasks_blocked() scope.cancel() - with pytest.raises(TimeoutError): + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: with CancelScope() as scope: with CancelScope(shield=True): @@ -795,6 +792,9 @@ async def killer(scope: CancelScope) -> None: with fail_after(0.2): await sleep(2) + _, unmatched = exc.value.split(TimeoutError) + assert not unmatched + async def test_triple_nested_shield() -> None: """Regression test for #370.""" @@ -862,7 +862,7 @@ async def fn() -> None: assert len(exc.value.exceptions) == 2 assert str(exc.value.exceptions[0]) == "parent task failed" - assert str(exc.value.exceptions[1]) == "child task failed" + assert str(exc.value.exceptions[1].exceptions[0]) == "child task failed" async def test_cancel_propagation_with_inner_spawn() -> None: From 22a38adf00bdbc3e3762abc54dd552039b83ff29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Mon, 14 Nov 2022 00:37:36 +0200 Subject: [PATCH 04/11] Fixed mypy errors --- src/anyio/_backends/_trio.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 9b3d0044..d9f82ce3 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -3,6 +3,7 @@ import array import math import socket +import sys from collections.abc import AsyncIterator, Iterable from concurrent.futures import Future from dataclasses import dataclass @@ -61,6 +62,9 @@ from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType from ..abc._eventloop import AsyncBackend +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + if TYPE_CHECKING: from trio_typing import TaskStatus @@ -134,7 +138,9 @@ def shield(self, value: bool) -> None: class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self._active = False - self._nursery_manager = trio.open_nursery(strict_exception_groups=True) + self._nursery_manager = trio.open_nursery( + strict_exception_groups=True # type: ignore[call-arg] + ) self.cancel_scope = None # type: ignore[assignment] async def __aenter__(self) -> TaskGroup: @@ -154,7 +160,7 @@ async def __aexit__( except BaseExceptionGroup as excgrp: matched, unmatched = excgrp.split(trio.Cancelled) if not unmatched: - raise trio.Cancelled._create() from None + raise trio.Cancelled._create() from None # type: ignore[attr-defined] elif unmatched: raise unmatched from None else: @@ -727,7 +733,9 @@ def __init__(self, **options: Any) -> None: async def _trio_main(self) -> None: self._stop_event = trio.Event() - async with trio.open_nursery(strict_exception_groups=True) as self._nursery: + async with trio.open_nursery( + strict_exception_groups=True # type: ignore[call-arg] + ) as self._nursery: await self._stop_event.wait() async def _call_func( From 4662371fa685441ecd0445dfda8486aaff161e17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Mon, 14 Nov 2022 09:00:46 +0200 Subject: [PATCH 05/11] Fixed pre-commit errors --- pyproject.toml | 2 +- src/anyio/_backends/_asyncio.py | 2 -- tests/streams/test_memory.py | 5 +++++ tests/test_from_thread.py | 3 +++ 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 48fb8d04..c1e22706 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] requires-python = ">= 3.7" dependencies = [ - "exceptiongroup; python_version < '3.11'", + "exceptiongroup >= 1.0.2; python_version < '3.11'", "idna >= 2.8", "sniffio >= 1.1", "typing_extensions; python_version < '3.8'", diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 6637401b..7292fbd3 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -263,7 +263,6 @@ def __exit__( self._deliver_cancellation_to_parent() if exc_val is not None and isinstance(exc_val, CancelledError): - breakpoint() if not self._parent_cancelled(): if self._timeout_expired: return True @@ -474,7 +473,6 @@ async def __aexit__( self._active = False if self._exceptions: - # if self._exceptions and (len(self._exceptions) != 1 or self._exceptions[0] is exc_val): exc: BaseException | None group = BaseExceptionGroup("multiple tasks failed", self._exceptions) matched, unmatched = group.split(CancelledError) diff --git a/tests/streams/test_memory.py b/tests/streams/test_memory.py index b9d39320..b83169e4 100644 --- a/tests/streams/test_memory.py +++ b/tests/streams/test_memory.py @@ -1,5 +1,7 @@ from __future__ import annotations +import sys + import pytest from anyio import ( @@ -16,6 +18,9 @@ from anyio.abc import ObjectReceiveStream, ObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + pytestmark = pytest.mark.anyio diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index 21a20220..0c56011c 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -27,6 +27,9 @@ from anyio.from_thread import BlockingPortal, start_blocking_portal from anyio.lowlevel import checkpoint +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + if sys.version_info >= (3, 8): from typing import Literal else: From cc9a32861fa47747862ac278a5dd1121cc90dd67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Mon, 14 Nov 2022 09:24:04 +0200 Subject: [PATCH 06/11] Fixed mypy errors --- tests/test_taskgroups.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 05f43912..8b110902 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -237,7 +237,7 @@ def task_fn(*, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: async with anyio.create_task_group() as tg: with pytest.raises(TypeError, match="to be synchronous$"): - await tg.start(task_fn) # type: ignore[arg-type] + await tg.start(task_fn) async def test_host_exception() -> None: @@ -862,6 +862,7 @@ async def fn() -> None: assert len(exc.value.exceptions) == 2 assert str(exc.value.exceptions[0]) == "parent task failed" + assert isinstance(exc.value.exceptions[1], ExceptionGroup) assert str(exc.value.exceptions[1].exceptions[0]) == "child task failed" From 77b74b94c7e4817303c4a57985d6bac744088311 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 16 Nov 2022 23:01:27 +0200 Subject: [PATCH 07/11] WIP fixes --- src/anyio/_backends/_asyncio.py | 220 ++++++++++++------ tests/test_from_thread.py | 7 +- tests/test_taskgroups.py | 397 +++++++++++++++----------------- 3 files changed, 345 insertions(+), 279 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 7292fbd3..1c71814e 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -30,13 +30,14 @@ from signal import Signals from socket import AddressFamily, SocketKind from threading import Thread -from types import TracebackType +from types import MethodType, TracebackType from typing import ( IO, Any, AsyncGenerator, Awaitable, Callable, + ClassVar, Collection, ContextManager, Coroutine, @@ -118,7 +119,7 @@ def find_root_task() -> asyncio.Task: # Look up the topmost task in the AnyIO task tree, if possible task = cast(asyncio.Task, current_task()) - state = _task_states.get(task) + state = TaskState.get(task) if state: cancel_scope = state.cancel_scope while cancel_scope and cancel_scope._parent_scope is not None: @@ -206,17 +207,9 @@ def __enter__(self) -> CancelScope: ) self._host_task = host_task = cast(asyncio.Task, current_task()) - self._tasks.add(host_task) - try: - task_state = _task_states[host_task] - except KeyError: - task_name = host_task.get_name() if _native_task_names else None - task_state = TaskState(None, task_name, self) - _task_states[host_task] = task_state - else: - self._parent_scope = task_state.cancel_scope - task_state.cancel_scope = self - + task_state = TaskState.get(self._host_task) + self._parent_scope = task_state.cancel_scope if task_state else None + self._add_task(host_task) self._timeout() self._active = True @@ -240,37 +233,31 @@ def __exit__( "entered in" ) - assert self._host_task is not None - host_task_state = _task_states.get(self._host_task) - if host_task_state is None or host_task_state.cancel_scope is not self: + task_state = TaskState.get(self._host_task) + if task_state is None or task_state.cancel_scope is not self: raise RuntimeError( "Attempted to exit a cancel scope that isn't the current tasks's " "current cancel scope" ) + ignore_exc = isinstance(exc_val, CancelledError) and task_state.uncancel(self) + + # TaskGroup removes the host task from _tasks, so it wouldn't try to wait on + # itself + if self._host_task in self._tasks: + self._remove_task(self._host_task) + self._active = False if self._timeout_handle: self._timeout_handle.cancel() self._timeout_handle = None - self._tasks.remove(self._host_task) - - host_task_state.cancel_scope = self._parent_scope - # Restart the cancellation effort in the farthest directly cancelled parent # scope if this one was shielded if self._shield: self._deliver_cancellation_to_parent() - if exc_val is not None and isinstance(exc_val, CancelledError): - if not self._parent_cancelled(): - if self._timeout_expired: - return True - elif not self._cancel_called: - # Task was cancelled natively - return None - - return None + return ignore_exc def _timeout(self) -> None: if self._deadline != math.inf: @@ -296,7 +283,8 @@ def _deliver_cancellation(self) -> None: # The task is eligible for cancellation if it has started and is not in a # cancel scope shielded from this one - cancel_scope = _task_states[task].cancel_scope + task_state = TaskState.get(task) + cancel_scope = task_state.cancel_scope if task_state else None while cancel_scope is not self: if cancel_scope is None or cancel_scope._shield: break @@ -307,7 +295,7 @@ def _deliver_cancellation(self) -> None: if task is not current and ( task is self._host_task or _task_started(task) ): - task.cancel() + task.cancel(scope=self) # type: ignore[call-arg] # Schedule another callback if there are still tasks left if should_retry: @@ -345,6 +333,16 @@ def _parent_cancelled(self) -> bool: return False + def _add_task( + self, task: asyncio.Task, parent_task: asyncio.Task | None = None + ) -> None: + self._tasks.add(task) + TaskState.add(task, scope=self, parent_task=parent_task) + + def _remove_task(self, task: asyncio.Task) -> None: + self._tasks.remove(task) + TaskState.remove(task, self) + def cancel(self) -> None: if not self._cancel_called: if self._timeout_handle: @@ -396,17 +394,109 @@ class TaskState: itself because there are no guarantees about its implementation. """ - __slots__ = "parent_id", "name", "cancel_scope" + __slots__ = ( + "parent_id", + "name", + "cancel_scope", + "_cancelled_by_scopes", + "_native_cancellations", + ) + + _task_states: ClassVar[ + WeakKeyDictionary[asyncio.Task, TaskState] + ] = WeakKeyDictionary() def __init__( - self, parent_id: int | None, name: str | None, cancel_scope: CancelScope | None + self, + parent_task: asyncio.Task | None, + name: str | None, + cancel_scope: CancelScope | None, ): - self.parent_id = parent_id + self.parent_id = id(parent_task) if parent_task else None self.name = name self.cancel_scope = cancel_scope + self._cancelled_by_scopes: set[CancelScope] = set() + self._native_cancellations = 0 + + @classmethod + def add( + cls, + task: asyncio.Task, + *, + scope: CancelScope | None = None, + parent_task: asyncio.Task | None = None, + name: object = None, + ) -> TaskState: + state = cls._task_states.get(task) + if state is None: + task_name: str | None = None + if name is not None: + task_name = str(name) + elif _native_task_names: + task_name = task.get_name() + + state = cls(parent_task, task_name, scope) + cls._task_states[task] = state + task.cancel = MethodType( # type: ignore[assignment] + partial(TaskState._patched_cancel, original=task.cancel), task + ) + if sys.version_info >= (3, 11): + task.uncancel = MethodType( + partial(TaskState._patched_uncancel, original=task.uncancel), task + ) + else: + state.cancel_scope = scope + + assert state.cancel_scope is scope + return state + + @classmethod + def get(cls, task: asyncio.Task) -> TaskState | None: + return cls._task_states.get(task) + + @classmethod + def remove(cls, task: asyncio.Task, scope: CancelScope | None) -> None: + state = cls._task_states[task] + if state.cancel_scope is None: + patched_method = cast(partial, task.cancel) + task.cancel = patched_method.keywords[ # type: ignore[assignment] + "original" + ] + del cls._task_states[task] + else: + state.cancel_scope = state.cancel_scope._parent_scope + + @staticmethod + def _patched_cancel( + self: asyncio.Task, + *args: object, + original: Callable, + scope: CancelScope | None = None, + ) -> None: + state = TaskState._task_states.get(self) + if state is not None: + if scope is None: + state._native_cancellations += 1 + else: + state._cancelled_by_scopes.add(scope) + original(*args) -_task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState] + @staticmethod + def _patched_uncancel(self: asyncio.Task, original: Callable) -> int: + state = TaskState._task_states.get(self) + if state is not None: + state._native_cancellations -= 1 + + return original() + + def uncancel(self, scope: CancelScope) -> bool: + self._cancelled_by_scopes.discard(scope) + return not self.cancelled + + @property + def cancelled(self) -> bool: + return bool(self._native_cancellations) or bool(self._cancelled_by_scopes) # @@ -428,11 +518,7 @@ def started(self, value: object = None) -> None: ) from None task = cast(asyncio.Task, current_task()) - _task_states[task].parent_id = self._parent_id - - -def is_anyio_cancelled_exc(exc: BaseException) -> bool: - return isinstance(exc, CancelledError) and not exc.args + TaskState.get(task).parent_id = self._parent_id class TaskGroup(abc.TaskGroup): @@ -474,11 +560,18 @@ async def __aexit__( self._active = False if self._exceptions: exc: BaseException | None - group = BaseExceptionGroup("multiple tasks failed", self._exceptions) + group = BaseExceptionGroup( + "one or more errors occurred in a task group", self._exceptions + ) matched, unmatched = group.split(CancelledError) if unmatched: + # If there are exceptions other than CancelledError, always raise those raise unmatched - elif self._exceptions[0] is not exc_val or not ignore_exception: + + # If the host task was natively cancelled, or a parent cancel scope was + # cancelled, raise a new CancelledError + task_state = TaskState.get(self.cancel_scope._host_task) + if task_state.cancelled: raise CancelledError from None return ignore_exception @@ -504,9 +597,7 @@ async def _run_wrapped_task( RuntimeError("Child exited without calling task_status.started()") ) finally: - if task in self.cancel_scope._tasks: - self.cancel_scope._tasks.remove(task) - del _task_states[task] + self.cancel_scope._remove_task(task) def _spawn( self, @@ -517,10 +608,7 @@ def _spawn( ) -> asyncio.Task: def task_done(_task: asyncio.Task) -> None: # This is the code path for Python 3.8+ - assert _task in self.cancel_scope._tasks - self.cancel_scope._tasks.remove(_task) - del _task_states[_task] - + self.cancel_scope._remove_task(task) try: exc = _task.exception() except CancelledError as e: @@ -552,12 +640,12 @@ def task_done(_task: asyncio.Task) -> None: kwargs = {} if task_status_future: - parent_id = id(current_task()) + parent_task = cast(asyncio.Task, current_task()) kwargs["task_status"] = _AsyncioTaskStatus( task_status_future, id(self.cancel_scope._host_task) ) else: - parent_id = id(self.cancel_scope._host_task) + parent_task = self.cancel_scope._host_task coro = func(*args, **kwargs) if not asyncio.iscoroutine(coro): @@ -574,10 +662,8 @@ def task_done(_task: asyncio.Task) -> None: task.add_done_callback(task_done) # Make the spawned task inherit the task group's cancel scope - _task_states[task] = TaskState( - parent_id=parent_id, name=name, cancel_scope=self.cancel_scope - ) - self.cancel_scope._tasks.add(task) + TaskState.add(task, parent_task=parent_task, name=name) + self.cancel_scope._add_task(task, parent_task) return task def start_soon( @@ -1627,7 +1713,7 @@ async def __anext__(self) -> Signals: def _create_task_info(task: asyncio.Task) -> TaskInfo: - task_state = _task_states.get(task) + task_state = TaskState.get(task) if task_state is None: name = task.get_name() if _native_task_names else None parent_id = None @@ -1763,15 +1849,14 @@ def run( @wraps(func) async def wrapper() -> T_Retval: task = cast(asyncio.Task, current_task()) - task_state = TaskState(None, get_callable_name(func), None) - _task_states[task] = task_state + task_state = TaskState.add(task, name=get_callable_name(func)) if _native_task_names: task.set_name(task_state.name) try: return await func(*args) finally: - del _task_states[task] + TaskState.remove(task, None) debug = options.get("debug", False) policy = options.get("policy", None) @@ -1801,11 +1886,8 @@ async def checkpoint_if_cancelled(cls) -> None: if task is None: return - try: - cancel_scope = _task_states[task].cancel_scope - except KeyError: - return - + task_state = TaskState.get(task) + cancel_scope = task_state.cancel_scope if task_state else None while cancel_scope: if cancel_scope.cancel_called: await sleep(0) @@ -1831,13 +1913,9 @@ def create_cancel_scope( @classmethod def current_effective_deadline(cls) -> float: - try: - cancel_scope = _task_states[ - current_task() # type: ignore[index] - ].cancel_scope - except KeyError: - return math.inf - + task = cast(asyncio.Task, current_task()) + task_state = TaskState.get(task) + cancel_scope = task_state.cancel_scope if task_state else None deadline = math.inf while cancel_scope: deadline = min(deadline, cancel_scope.deadline) diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index 0c56011c..4906cb70 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -28,7 +28,7 @@ from anyio.lowlevel import checkpoint if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup + from exceptiongroup import BaseExceptionGroup, ExceptionGroup if sys.version_info >= (3, 8): from typing import Literal @@ -522,7 +522,7 @@ def test_raise_baseexception_from_task( async def raise_baseexception() -> None: raise BaseException("fatal error") - with pytest.raises(BaseException, match="fatal error"): + with pytest.raises(BaseExceptionGroup) as outer_exc: with start_blocking_portal( anyio_backend_name, anyio_backend_options ) as portal: @@ -530,3 +530,6 @@ async def raise_baseexception() -> None: portal.call(raise_baseexception) assert exc.value.__context__ is None + + assert len(outer_exc.value.exceptions) == 1 + assert str(outer_exc.value.exceptions[0]) == "fatal error" diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 8b110902..b3aad88d 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -183,54 +183,6 @@ async def taskfunc(*, task_status: TaskStatus) -> None: assert not finished -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) -async def test_start_native_host_cancelled() -> None: - started = finished = False - - async def taskfunc(*, task_status: TaskStatus) -> None: - nonlocal started, finished - started = True - await sleep(2) - finished = True - - async def start_another() -> None: - async with create_task_group() as tg: - await tg.start(taskfunc) - - task = asyncio.get_running_loop().create_task(start_another()) - await wait_all_tasks_blocked() - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - assert started - assert not finished - - -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) -async def test_start_native_child_cancelled() -> None: - task = None - finished = False - - async def taskfunc(*, task_status: TaskStatus) -> None: - nonlocal task, finished - task = asyncio.current_task() - await sleep(2) - finished = True - - async def start_another() -> None: - async with create_task_group() as tg2: - await tg2.start(taskfunc) - - async with create_task_group() as tg: - tg.start_soon(start_another) - await wait_all_tasks_blocked() - assert task is not None - task.cancel() - - assert not finished - - async def test_start_exception_delivery() -> None: def task_fn(*, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: task_status.started("hello") @@ -248,12 +200,13 @@ async def set_result(value: str) -> None: await sleep(3) result = value - with pytest.raises(Exception) as exc: + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: tg.start_soon(set_result, "a") raise Exception("dummy error") - exc.match("dummy error") + assert len(exc.value.exceptions) == 1 + assert str(exc.value.exceptions[0]) == "dummy error" assert result is None @@ -282,13 +235,14 @@ async def child() -> NoReturn: raise Exception("foo") sleep_completed = False - with pytest.raises(Exception) as exc: + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: tg.start_soon(child) await sleep(0.5) sleep_completed = True - exc.match("foo") + assert len(exc.value.exceptions) == 1 + assert str(exc.value.exceptions[0]) == "foo" assert not sleep_completed @@ -300,13 +254,14 @@ async def child() -> None: await sleep(1) sleep_completed = True - with pytest.raises(Exception) as exc: + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: tg.start_soon(child) await wait_all_tasks_blocked() raise Exception("foo") - exc.match("foo") + assert len(exc.value.exceptions) == 1 + assert str(exc.value.exceptions[0]) == "foo" assert not sleep_completed @@ -369,7 +324,7 @@ async def waiter() -> None: nonlocal cancel_received try: await sleep(5) - finally: + except get_cancelled_exc_class(): cancel_received = True async def subgroup() -> None: @@ -429,8 +384,11 @@ async def test_cancel_outer_scope_one_task() -> None: outer_scope.cancel() except get_cancelled_exc_class(): pass - except BaseException: - pytest.fail("should have raised an exception group") + except BaseException as exc: + pytest.fail( + f"should have raised a cancellation error instead of " + f"{exc.__class__}" + ) except BaseException: pytest.fail("the cancel scope should have swallowed the exceptions") @@ -588,40 +546,6 @@ async def test_cancel_shielded_scope() -> None: await sleep(0) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) -async def test_cancel_host_asyncgen() -> None: - done = False - - async def host_task() -> None: - nonlocal done - async with create_task_group() as tg: - with CancelScope(shield=True) as inner_scope: - assert inner_scope.shield - tg.cancel_scope.cancel() - - with pytest.raises(get_cancelled_exc_class()): - await sleep(0) - - with pytest.raises(get_cancelled_exc_class()): - await sleep(0) - - done = True - - async def host_agen_fn() -> AsyncGenerator[None, None]: - await host_task() - yield - pytest.fail("host_agen_fn should only be __anext__ed once") - - host_agen = host_agen_fn() - try: - loop = asyncio.get_running_loop() - await loop.create_task(host_agen.__anext__()) # type: ignore[arg-type] - finally: - await host_agen.aclose() - - assert done - - async def test_shielding_immediate_scope_cancelled() -> None: async def cancel_when_ready() -> None: await wait_all_tasks_blocked() @@ -693,13 +617,14 @@ async def child(fail: bool) -> None: await sleep(1) sleep_completed = True - with pytest.raises(Exception) as exc: + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: tg.start_soon(child, False) await wait_all_tasks_blocked() tg.start_soon(child, True) - exc.match("foo") + assert len(exc.value.exceptions) == 1 + assert str(exc.value.exceptions[0]) == "foo" assert not sleep_completed @@ -892,74 +817,6 @@ async def test_escaping_cancelled_error_from_cancelled_task() -> None: scope.cancel() -@pytest.mark.skipif( - sys.version_info >= (3, 11), - reason="Generator based coroutines have been removed in Python 3.11", -) -@pytest.mark.filterwarnings( - 'ignore:"@coroutine" decorator is deprecated:DeprecationWarning' -) -def test_cancel_generator_based_task() -> None: - async def native_coro_part() -> None: - with CancelScope() as scope: - asyncio.get_running_loop().call_soon(scope.cancel) - await asyncio.sleep(1) - pytest.fail("Execution should not have reached this line") - - @asyncio.coroutine - def generator_part() -> Generator[object, BaseException, None]: - yield from native_coro_part() - - anyio.run(generator_part, backend="asyncio") - - -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) -async def test_cancel_native_future_tasks() -> None: - async def wait_native_future() -> None: - loop = asyncio.get_running_loop() - await loop.create_future() - - async with anyio.create_task_group() as tg: - tg.start_soon(wait_native_future) - tg.cancel_scope.cancel() - - -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) -async def test_cancel_native_future_tasks_cancel_scope() -> None: - async def wait_native_future() -> None: - with anyio.CancelScope(): - loop = asyncio.get_running_loop() - await loop.create_future() - - async with anyio.create_task_group() as tg: - tg.start_soon(wait_native_future) - tg.cancel_scope.cancel() - - -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) -async def test_cancel_completed_task() -> None: - loop = asyncio.get_running_loop() - old_exception_handler = loop.get_exception_handler() - exceptions = [] - - def exception_handler(*args: object, **kwargs: object) -> None: - exceptions.append((args, kwargs)) - - loop.set_exception_handler(exception_handler) - try: - - async def noop() -> None: - pass - - async with anyio.create_task_group() as tg: - tg.start_soon(noop) - tg.cancel_scope.cancel() - - assert exceptions == [] - finally: - loop.set_exception_handler(old_exception_handler) - - async def test_task_in_sync_spawn_callback() -> None: outer_task_id = anyio.get_current_task().id inner_task_id = None @@ -1045,54 +902,12 @@ async def exit_scope(scope: CancelScope) -> None: async with create_task_group() as tg: tg.start_soon(enter_scope, scope) - with pytest.raises(RuntimeError): + with pytest.raises(ExceptionGroup) as exc: async with create_task_group() as tg: tg.start_soon(exit_scope, scope) - -def test_unhandled_exception_group(caplog: pytest.LogCaptureFixture) -> None: - def crash() -> NoReturn: - raise KeyboardInterrupt - - async def nested() -> None: - async with anyio.create_task_group() as tg: - tg.start_soon(anyio.sleep, 5) - await anyio.sleep(5) - - async def main() -> NoReturn: - async with anyio.create_task_group() as tg: - tg.start_soon(nested) - await wait_all_tasks_blocked() - asyncio.get_running_loop().call_soon(crash) - await anyio.sleep(5) - - pytest.fail("Execution should never reach this point") - - with pytest.raises(KeyboardInterrupt): - anyio.run(main, backend="asyncio") - - assert not caplog.messages - - -@pytest.mark.skipif( - sys.version_info < (3, 9), - sys.version_info >= (3, 11), - reason="Cancel messages are only supported on Python 3.9 and 3.10", -) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) -async def test_cancellederror_combination_with_message() -> None: - async def taskfunc(*, task_status: TaskStatus) -> NoReturn: - task_status.started(asyncio.current_task()) - await sleep(5) - pytest.fail("Execution should never reach this point") - - with pytest.raises(asyncio.CancelledError, match="blah"): - async with create_task_group() as tg: - task = await tg.start(taskfunc) - tg.start_soon(sleep, 5) - await wait_all_tasks_blocked() - assert isinstance(task, asyncio.Task) - task.cancel("blah") + assert len(exc.value.exceptions) == 1 + assert isinstance(exc.value.exceptions[0], RuntimeError) async def test_start_soon_parent_id() -> None: @@ -1135,3 +950,173 @@ async def starter_task() -> None: assert initial_parent_id != permanent_parent_id assert initial_parent_id == starter_task_id assert permanent_parent_id == root_task_id + + +class TestAsyncio: + """Contains asyncio specific cancel scope/task group tests.""" + + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + async def test_cancel_host_asyncgen(self) -> None: + done = False + + async def host_task() -> None: + nonlocal done + async with create_task_group() as tg: + with CancelScope(shield=True) as inner_scope: + assert inner_scope.shield + tg.cancel_scope.cancel() + + with pytest.raises(get_cancelled_exc_class()): + await sleep(0) + + with pytest.raises(get_cancelled_exc_class()): + await sleep(0) + + done = True + + async def host_agen_fn() -> AsyncGenerator[None, None]: + await host_task() + yield + pytest.fail("host_agen_fn should only be __anext__ed once") + + host_agen = host_agen_fn() + try: + loop = asyncio.get_running_loop() + await loop.create_task(host_agen.__anext__()) # type: ignore[arg-type] + finally: + await host_agen.aclose() + + assert done + + def test_unhandled_exception_group(self, caplog: pytest.LogCaptureFixture) -> None: + def crash() -> NoReturn: + raise KeyboardInterrupt + + async def nested() -> None: + async with anyio.create_task_group() as tg: + tg.start_soon(anyio.sleep, 5) + await anyio.sleep(5) + + async def main() -> NoReturn: + async with anyio.create_task_group() as tg: + tg.start_soon(nested) + await wait_all_tasks_blocked() + asyncio.get_running_loop().call_soon(crash) + await anyio.sleep(5) + + pytest.fail("Execution should never reach this point") + + with pytest.raises(KeyboardInterrupt): + anyio.run(main, backend="asyncio") + + assert not caplog.messages + + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + async def test_start_native_host_cancelled(self) -> None: + started = finished = False + + async def taskfunc(*, task_status: TaskStatus) -> None: + nonlocal started, finished + started = True + await sleep(2) + finished = True + + async def start_another() -> None: + async with create_task_group() as tg: + await tg.start(taskfunc) + + task = asyncio.get_running_loop().create_task(start_another()) + await wait_all_tasks_blocked() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert started + assert not finished + + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + async def test_start_native_child_cancelled(self) -> None: + task = None + finished = False + + async def taskfunc(*, task_status: TaskStatus) -> None: + nonlocal task, finished + task = asyncio.current_task() + await sleep(2) + finished = True + + async def start_another() -> None: + async with create_task_group() as tg2: + await tg2.start(taskfunc) + + async with create_task_group() as tg: + tg.start_soon(start_another) + await wait_all_tasks_blocked() + assert task is not None + task.cancel() + + assert not finished + + @pytest.mark.skipif( + sys.version_info >= (3, 11), + reason="Generator based coroutines have been removed in Python 3.11", + ) + @pytest.mark.filterwarnings( + 'ignore:"@coroutine" decorator is deprecated:DeprecationWarning' + ) + def test_cancel_generator_based_task(self) -> None: + async def native_coro_part() -> None: + with CancelScope() as scope: + asyncio.get_running_loop().call_soon(scope.cancel) + await asyncio.sleep(1) + pytest.fail("Execution should not have reached this line") + + @asyncio.coroutine + def generator_part() -> Generator[object, BaseException, None]: + yield from native_coro_part() + + anyio.run(generator_part, backend="asyncio") + + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + async def test_cancel_native_future_tasks(self) -> None: + async def wait_native_future() -> None: + loop = asyncio.get_running_loop() + await loop.create_future() + + async with anyio.create_task_group() as tg: + tg.start_soon(wait_native_future) + tg.cancel_scope.cancel() + + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + async def test_cancel_native_future_tasks_cancel_scope(self) -> None: + async def wait_native_future() -> None: + with anyio.CancelScope(): + loop = asyncio.get_running_loop() + await loop.create_future() + + async with anyio.create_task_group() as tg: + tg.start_soon(wait_native_future) + tg.cancel_scope.cancel() + + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + async def test_cancel_completed_task(self) -> None: + loop = asyncio.get_running_loop() + old_exception_handler = loop.get_exception_handler() + exceptions = [] + + def exception_handler(*args: object, **kwargs: object) -> None: + exceptions.append((args, kwargs)) + + loop.set_exception_handler(exception_handler) + try: + + async def noop() -> None: + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(noop) + tg.cancel_scope.cancel() + + assert exceptions == [] + finally: + loop.set_exception_handler(old_exception_handler) From 2b6492ad1153146bd19c346412c10892126bef14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 20 Nov 2022 16:13:01 +0200 Subject: [PATCH 08/11] Fixed mypy errors --- src/anyio/_backends/_asyncio.py | 29 ++++++++++++++++++++++------- tests/test_sockets.py | 4 ++-- tests/test_taskgroups.py | 2 +- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 1c71814e..053430ea 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -206,7 +206,11 @@ def __enter__(self) -> CancelScope: "Each CancelScope may only be used for a single 'with' block" ) - self._host_task = host_task = cast(asyncio.Task, current_task()) + host_task = current_task() + if host_task is None: + raise RuntimeError("Must be in a task to enter a cancel scope") + + self._host_task = host_task task_state = TaskState.get(self._host_task) self._parent_scope = task_state.cancel_scope if task_state else None self._add_task(host_task) @@ -227,13 +231,17 @@ def __exit__( ) -> bool | None: if not self._active: raise RuntimeError("This cancel scope is not active") - if current_task() is not self._host_task: + + host_task = current_task() + if host_task is None: + raise RuntimeError("Must be in a task to exit a cancel scope") + elif host_task is not self._host_task: raise RuntimeError( "Attempted to exit cancel scope in a different task than it was " "entered in" ) - task_state = TaskState.get(self._host_task) + task_state = TaskState.get(host_task) if task_state is None or task_state.cancel_scope is not self: raise RuntimeError( "Attempted to exit a cancel scope that isn't the current tasks's " @@ -517,8 +525,13 @@ def started(self, value: object = None) -> None: "called 'started' twice on the same task status" ) from None - task = cast(asyncio.Task, current_task()) - TaskState.get(task).parent_id = self._parent_id + task = current_task() + if not task: + raise RuntimeError("called 'started' outside of a task") + + task_state = TaskState.get(task) + if task_state: + task_state.parent_id = self._parent_id class TaskGroup(abc.TaskGroup): @@ -538,6 +551,7 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: + host_task = cast(asyncio.Task, self.cancel_scope._host_task) ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb) if exc_val is not None: self.cancel_scope.cancel() @@ -570,8 +584,8 @@ async def __aexit__( # If the host task was natively cancelled, or a parent cancel scope was # cancelled, raise a new CancelledError - task_state = TaskState.get(self.cancel_scope._host_task) - if task_state.cancelled: + task_state = TaskState.get(host_task) + if task_state and task_state.cancelled: raise CancelledError from None return ignore_exception @@ -645,6 +659,7 @@ def task_done(_task: asyncio.Task) -> None: task_status_future, id(self.cancel_scope._host_task) ) else: + assert self.cancel_scope._host_task parent_task = self.cancel_scope._host_task coro = func(*args, **kwargs) diff --git a/tests/test_sockets.py b/tests/test_sockets.py index ded3dc29..072d626d 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -1256,8 +1256,8 @@ async def test_send_receive(self, family: AnyIPAddressFamily) -> None: async with await create_connected_udp_socket( host, port, local_host="localhost", family=family ) as udp2: - host, port = udp2.extra( - SocketAttribute.local_address # type: ignore[misc] + host, port = udp2.extra( # type: ignore[misc] + SocketAttribute.local_address ) await udp2.send(b"blah") request = await udp1.receive() diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index b3aad88d..9d3dcffb 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -189,7 +189,7 @@ def task_fn(*, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: async with anyio.create_task_group() as tg: with pytest.raises(TypeError, match="to be synchronous$"): - await tg.start(task_fn) + await tg.start(task_fn) # type: ignore[arg-type] async def test_host_exception() -> None: From da38ca96eb0bcf3ee5d312973b215cf3fbd556ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 20 Nov 2022 16:34:46 +0200 Subject: [PATCH 09/11] Added debugging code --- tests/streams/test_tls.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/streams/test_tls.py b/tests/streams/test_tls.py index 695cf235..9f8c5bf3 100644 --- a/tests/streams/test_tls.py +++ b/tests/streams/test_tls.py @@ -230,7 +230,11 @@ def serve_sync() -> None: ) with client_cm: assert await wrapper.receive() == b"hello" - await wrapper.aclose() + try: + await wrapper.aclose() + except ssl.SSLError: + print("OpenSSL version:", ssl.OPENSSL_VERSION) + raise server_thread.join() server_sock.close() From 221e927e319706cf7548b0bbe828dfdb639d4315 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 20 Nov 2022 23:21:55 +0200 Subject: [PATCH 10/11] Added a comparison with asyncio.TaskGroup --- docs/tasks.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/tasks.rst b/docs/tasks.rst index 9833acaa..0f8b0126 100644 --- a/docs/tasks.rst +++ b/docs/tasks.rst @@ -142,3 +142,21 @@ host task that will be copied, but the context of the task that calls support for this only landed in v3.7. .. _context: https://docs.python.org/3/library/contextvars.html + +Differences with asyncio.TaskGroup +---------------------------------- + +The :class:`asyncio.TaskGroup` class, added in Python 3.11, is very similar in design to +the AnyIO :class:`~TaskGroup` class. The asyncio counterpart has some important +differences in its semantics, however: + +* Tasks are spawned solely through :meth:`~asyncio.TaskGroup.create_task`; there is no + ``start()`` or ``start_soon()`` method +* The :meth:`~asyncio.TaskGroup.create_task` method returns a task object which can be + awaited on (or cancelled) +* Tasks can only be cancelled individually (there is no ``cancel()`` method or similar + in the task group) +* When a task is cancelled before its coroutine has started running, it will not get a + chance to handle the cancellation exception +* New tasks cannot be started after an exception in one of the tasks has triggered a + shutdown From 538e875a761a7e9b9747e13b9e5c040a7229aaad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Tue, 22 Nov 2022 01:01:52 +0200 Subject: [PATCH 11/11] Fixed cancel scopes so that they un-cancel the task when necessary on py3.11 --- docs/versionhistory.rst | 3 +++ src/anyio/_backends/_asyncio.py | 9 +++++++-- tests/test_taskgroups.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 90af4b89..4e8a4b37 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -13,6 +13,9 @@ This library adheres to `Semantic Versioning 2.0 `_. ``ExceptionGroup`` (or ``BaseExceptionGroup`` if one or more ``BaseException`` were included), except when all the exceptions are cancellation exceptions. In that case, a single cancellation exception is raised instead. + - ``CancelScope`` now un-cancels its host task on Python 3.11 + asyncio when + appropriate, for compatibility with ``asyncio.timeout`` and other context managers + that swallow exceptions - Bumped minimum version of trio to v0.22 - Added ``create_unix_datagram_socket`` and ``create_connected_unix_datagram_socket`` to create UNIX datagram sockets (PR by Jean Hominal) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 053430ea..c4bb6a5a 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -491,15 +491,20 @@ def _patched_cancel( original(*args) @staticmethod - def _patched_uncancel(self: asyncio.Task, original: Callable) -> int: + def _patched_uncancel( + self: asyncio.Task, *, scope: CancelScope | None = None, original: Callable + ) -> int: state = TaskState._task_states.get(self) - if state is not None: + if scope is None and state is not None and state._native_cancellations: state._native_cancellations -= 1 return original() def uncancel(self, scope: CancelScope) -> bool: self._cancelled_by_scopes.discard(scope) + if sys.version_info >= (3, 11) and scope._host_task is not None: + scope._host_task.uncancel(scope=scope) + return not self.cancelled @property diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 9d3dcffb..96d8a4d1 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -1120,3 +1120,15 @@ async def noop() -> None: assert exceptions == [] finally: loop.set_exception_handler(old_exception_handler) + + @pytest.mark.skipif(sys.version_info < (3, 11), reason="Requires Python >= 3.11") + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + async def test_asyncio_timeout(self) -> None: + """Test that CancelScope.__exit__ un-cancels the task.""" + with CancelScope() as scope: + scope.cancel() + await sleep(2) + pytest.fail("Execution should not reach this point") + + task = asyncio.current_task() + assert not task.cancelling()