diff --git a/docs/tasks.rst b/docs/tasks.rst
index 9833acaa..0f8b0126 100644
--- a/docs/tasks.rst
+++ b/docs/tasks.rst
@@ -142,3 +142,21 @@ host task that will be copied, but the context of the task that calls
support for this only landed in v3.7.
.. _context: https://docs.python.org/3/library/contextvars.html
+
+Differences with asyncio.TaskGroup
+----------------------------------
+
+The :class:`asyncio.TaskGroup` class, added in Python 3.11, is very similar in design to
+the AnyIO :class:`~TaskGroup` class. The asyncio counterpart has some important
+differences in its semantics, however:
+
+* Tasks are spawned solely through :meth:`~asyncio.TaskGroup.create_task`; there is no
+ ``start()`` or ``start_soon()`` method
+* The :meth:`~asyncio.TaskGroup.create_task` method returns a task object which can be
+ awaited on (or cancelled)
+* Tasks can only be cancelled individually (there is no ``cancel()`` method or similar
+ in the task group)
+* When a task is cancelled before its coroutine has started running, it will not get a
+ chance to handle the cancellation exception
+* New tasks cannot be started after an exception in one of the tasks has triggered a
+ shutdown
diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index dccd1b5b..4e8a4b37 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -7,6 +7,15 @@ This library adheres to `Semantic Versioning 2.0 `_.
- **BACKWARDS INCOMPATIBLE** Replaced AnyIO's own ``ExceptionGroup`` class with the PEP
654 ``BaseExceptionGroup`` and ``ExceptionGroup``
+- **BACKWARDS INCOMPATIBLE** Changes to cancellation semantics:
+
+ - Any exceptions raising out of a task groups are now nested inside an
+ ``ExceptionGroup`` (or ``BaseExceptionGroup`` if one or more ``BaseException`` were
+ included), except when all the exceptions are cancellation exceptions. In that case,
+ a single cancellation exception is raised instead.
+ - ``CancelScope`` now un-cancels its host task on Python 3.11 + asyncio when
+ appropriate, for compatibility with ``asyncio.timeout`` and other context managers
+ that swallow exceptions
- Bumped minimum version of trio to v0.22
- Added ``create_unix_datagram_socket`` and ``create_connected_unix_datagram_socket`` to
create UNIX datagram sockets (PR by Jean Hominal)
@@ -42,6 +51,8 @@ This library adheres to `Semantic Versioning 2.0 `_.
the event loop to be closed
- Fixed ``current_effective_deadline()`` not returning ``-inf`` on asyncio when the
currently active cancel scope has been cancelled (PR by Ganden Schaffner)
+- Fixed task group not raising a cancellation exception on asyncio at exit if no child
+ tasks were spawned and an outer cancellation scope had been cancelled before
**3.6.1**
diff --git a/pyproject.toml b/pyproject.toml
index 48fb8d04..c1e22706 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ classifiers = [
]
requires-python = ">= 3.7"
dependencies = [
- "exceptiongroup; python_version < '3.11'",
+ "exceptiongroup >= 1.0.2; python_version < '3.11'",
"idna >= 2.8",
"sniffio >= 1.1",
"typing_extensions; python_version < '3.8'",
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index 1c085277..c4bb6a5a 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -30,19 +30,19 @@
from signal import Signals
from socket import AddressFamily, SocketKind
from threading import Thread
-from types import TracebackType
+from types import MethodType, TracebackType
from typing import (
IO,
Any,
AsyncGenerator,
Awaitable,
Callable,
+ ClassVar,
Collection,
ContextManager,
Coroutine,
Deque,
Generator,
- Iterator,
Mapping,
Optional,
Sequence,
@@ -78,7 +78,7 @@
from ..lowlevel import RunVar
if sys.version_info < (3, 11):
- from exceptiongroup import BaseExceptionGroup, ExceptionGroup
+ from exceptiongroup import BaseExceptionGroup
if sys.version_info >= (3, 8):
@@ -119,7 +119,7 @@ def find_root_task() -> asyncio.Task:
# Look up the topmost task in the AnyIO task tree, if possible
task = cast(asyncio.Task, current_task())
- state = _task_states.get(task)
+ state = TaskState.get(task)
if state:
cancel_scope = state.cancel_scope
while cancel_scope and cancel_scope._parent_scope is not None:
@@ -206,18 +206,14 @@ def __enter__(self) -> CancelScope:
"Each CancelScope may only be used for a single 'with' block"
)
- self._host_task = host_task = cast(asyncio.Task, current_task())
- self._tasks.add(host_task)
- try:
- task_state = _task_states[host_task]
- except KeyError:
- task_name = host_task.get_name() if _native_task_names else None
- task_state = TaskState(None, task_name, self)
- _task_states[host_task] = task_state
- else:
- self._parent_scope = task_state.cancel_scope
- task_state.cancel_scope = self
+ host_task = current_task()
+ if host_task is None:
+ raise RuntimeError("Must be in a task to enter a cancel scope")
+ self._host_task = host_task
+ task_state = TaskState.get(self._host_task)
+ self._parent_scope = task_state.cancel_scope if task_state else None
+ self._add_task(host_task)
self._timeout()
self._active = True
@@ -235,49 +231,41 @@ def __exit__(
) -> bool | None:
if not self._active:
raise RuntimeError("This cancel scope is not active")
- if current_task() is not self._host_task:
+
+ host_task = current_task()
+ if host_task is None:
+ raise RuntimeError("Must be in a task to exit a cancel scope")
+ elif host_task is not self._host_task:
raise RuntimeError(
"Attempted to exit cancel scope in a different task than it was "
"entered in"
)
- assert self._host_task is not None
- host_task_state = _task_states.get(self._host_task)
- if host_task_state is None or host_task_state.cancel_scope is not self:
+ task_state = TaskState.get(host_task)
+ if task_state is None or task_state.cancel_scope is not self:
raise RuntimeError(
"Attempted to exit a cancel scope that isn't the current tasks's "
"current cancel scope"
)
+ ignore_exc = isinstance(exc_val, CancelledError) and task_state.uncancel(self)
+
+ # TaskGroup removes the host task from _tasks, so it wouldn't try to wait on
+ # itself
+ if self._host_task in self._tasks:
+ self._remove_task(self._host_task)
+
self._active = False
if self._timeout_handle:
self._timeout_handle.cancel()
self._timeout_handle = None
- self._tasks.remove(self._host_task)
-
- host_task_state.cancel_scope = self._parent_scope
-
# Restart the cancellation effort in the farthest directly cancelled parent
# scope if this one was shielded
if self._shield:
self._deliver_cancellation_to_parent()
- if exc_val is not None:
- exceptions = (
- exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val]
- )
- if all(isinstance(exc, CancelledError) for exc in exceptions):
- if self._timeout_expired:
- return True
- elif not self._cancel_called:
- # Task was cancelled natively
- return None
- elif not self._parent_cancelled():
- # This scope was directly cancelled
- return True
-
- return None
+ return ignore_exc
def _timeout(self) -> None:
if self._deadline != math.inf:
@@ -303,7 +291,8 @@ def _deliver_cancellation(self) -> None:
# The task is eligible for cancellation if it has started and is not in a
# cancel scope shielded from this one
- cancel_scope = _task_states[task].cancel_scope
+ task_state = TaskState.get(task)
+ cancel_scope = task_state.cancel_scope if task_state else None
while cancel_scope is not self:
if cancel_scope is None or cancel_scope._shield:
break
@@ -314,7 +303,7 @@ def _deliver_cancellation(self) -> None:
if task is not current and (
task is self._host_task or _task_started(task)
):
- task.cancel()
+ task.cancel(scope=self) # type: ignore[call-arg]
# Schedule another callback if there are still tasks left
if should_retry:
@@ -352,6 +341,16 @@ def _parent_cancelled(self) -> bool:
return False
+ def _add_task(
+ self, task: asyncio.Task, parent_task: asyncio.Task | None = None
+ ) -> None:
+ self._tasks.add(task)
+ TaskState.add(task, scope=self, parent_task=parent_task)
+
+ def _remove_task(self, task: asyncio.Task) -> None:
+ self._tasks.remove(task)
+ TaskState.remove(task, self)
+
def cancel(self) -> None:
if not self._cancel_called:
if self._timeout_handle:
@@ -403,17 +402,114 @@ class TaskState:
itself because there are no guarantees about its implementation.
"""
- __slots__ = "parent_id", "name", "cancel_scope"
+ __slots__ = (
+ "parent_id",
+ "name",
+ "cancel_scope",
+ "_cancelled_by_scopes",
+ "_native_cancellations",
+ )
+
+ _task_states: ClassVar[
+ WeakKeyDictionary[asyncio.Task, TaskState]
+ ] = WeakKeyDictionary()
def __init__(
- self, parent_id: int | None, name: str | None, cancel_scope: CancelScope | None
+ self,
+ parent_task: asyncio.Task | None,
+ name: str | None,
+ cancel_scope: CancelScope | None,
):
- self.parent_id = parent_id
+ self.parent_id = id(parent_task) if parent_task else None
self.name = name
self.cancel_scope = cancel_scope
+ self._cancelled_by_scopes: set[CancelScope] = set()
+ self._native_cancellations = 0
+
+ @classmethod
+ def add(
+ cls,
+ task: asyncio.Task,
+ *,
+ scope: CancelScope | None = None,
+ parent_task: asyncio.Task | None = None,
+ name: object = None,
+ ) -> TaskState:
+ state = cls._task_states.get(task)
+ if state is None:
+ task_name: str | None = None
+ if name is not None:
+ task_name = str(name)
+ elif _native_task_names:
+ task_name = task.get_name()
+
+ state = cls(parent_task, task_name, scope)
+ cls._task_states[task] = state
+ task.cancel = MethodType( # type: ignore[assignment]
+ partial(TaskState._patched_cancel, original=task.cancel), task
+ )
+ if sys.version_info >= (3, 11):
+ task.uncancel = MethodType(
+ partial(TaskState._patched_uncancel, original=task.uncancel), task
+ )
+ else:
+ state.cancel_scope = scope
+
+ assert state.cancel_scope is scope
+ return state
+
+ @classmethod
+ def get(cls, task: asyncio.Task) -> TaskState | None:
+ return cls._task_states.get(task)
+
+ @classmethod
+ def remove(cls, task: asyncio.Task, scope: CancelScope | None) -> None:
+ state = cls._task_states[task]
+ if state.cancel_scope is None:
+ patched_method = cast(partial, task.cancel)
+ task.cancel = patched_method.keywords[ # type: ignore[assignment]
+ "original"
+ ]
+ del cls._task_states[task]
+ else:
+ state.cancel_scope = state.cancel_scope._parent_scope
+
+ @staticmethod
+ def _patched_cancel(
+ self: asyncio.Task,
+ *args: object,
+ original: Callable,
+ scope: CancelScope | None = None,
+ ) -> None:
+ state = TaskState._task_states.get(self)
+ if state is not None:
+ if scope is None:
+ state._native_cancellations += 1
+ else:
+ state._cancelled_by_scopes.add(scope)
+
+ original(*args)
+
+ @staticmethod
+ def _patched_uncancel(
+ self: asyncio.Task, *, scope: CancelScope | None = None, original: Callable
+ ) -> int:
+ state = TaskState._task_states.get(self)
+ if scope is None and state is not None and state._native_cancellations:
+ state._native_cancellations -= 1
+
+ return original()
+ def uncancel(self, scope: CancelScope) -> bool:
+ self._cancelled_by_scopes.discard(scope)
+ if sys.version_info >= (3, 11) and scope._host_task is not None:
+ scope._host_task.uncancel(scope=scope)
-_task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState]
+ return not self.cancelled
+
+ @property
+ def cancelled(self) -> bool:
+ return bool(self._native_cancellations) or bool(self._cancelled_by_scopes)
#
@@ -434,38 +530,13 @@ def started(self, value: object = None) -> None:
"called 'started' twice on the same task status"
) from None
- task = cast(asyncio.Task, current_task())
- _task_states[task].parent_id = self._parent_id
-
-
-def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException:
- exceptions = list(excgroup.exceptions)
- modified = False
- for i, exc in enumerate(exceptions):
- if isinstance(exc, BaseExceptionGroup):
- new_exc = collapse_exception_group(exc)
- if new_exc is not exc:
- modified = True
- exceptions[i] = new_exc
-
- if len(exceptions) == 1:
- return exceptions[0]
- elif modified:
- return excgroup.derive(exceptions)
- else:
- return excgroup
-
-
-def walk_exception_group(excgroup: BaseExceptionGroup) -> Iterator[BaseException]:
- for exc in excgroup.exceptions:
- if isinstance(exc, BaseExceptionGroup):
- yield from walk_exception_group(exc)
- else:
- yield exc
-
+ task = current_task()
+ if not task:
+ raise RuntimeError("called 'started' outside of a task")
-def is_anyio_cancelled_exc(exc: BaseException) -> bool:
- return isinstance(exc, CancelledError) and not exc.args
+ task_state = TaskState.get(task)
+ if task_state:
+ task_state.parent_id = self._parent_id
class TaskGroup(abc.TaskGroup):
@@ -485,36 +556,42 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
+ host_task = cast(asyncio.Task, self.cancel_scope._host_task)
ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
if exc_val is not None:
self.cancel_scope.cancel()
self._exceptions.append(exc_val)
- while self.cancel_scope._tasks:
+ if self.cancel_scope._tasks:
+ while self.cancel_scope._tasks:
+ try:
+ await asyncio.wait(self.cancel_scope._tasks)
+ except asyncio.CancelledError as e:
+ self.cancel_scope.cancel()
+ self._exceptions.append(e)
+ else:
+ # No tasks to wait on, but we still need to check for cancellation here
try:
- await asyncio.wait(self.cancel_scope._tasks)
- except asyncio.CancelledError:
- self.cancel_scope.cancel()
+ await AsyncIOBackend.checkpoint()
+ except CancelledError as e:
+ self._exceptions.append(e)
self._active = False
if self._exceptions:
exc: BaseException | None
- group = BaseExceptionGroup("multiple tasks failed", self._exceptions)
- if not self.cancel_scope._parent_cancelled():
- # If any exceptions other than AnyIO cancellation exceptions have been
- # received, raise those
- _, exc = group.split(is_anyio_cancelled_exc)
- elif all(is_anyio_cancelled_exc(e) for e in walk_exception_group(group)):
- # All tasks were cancelled by AnyIO
- exc = CancelledError()
- else:
- exc = group
-
- if isinstance(exc, BaseExceptionGroup):
- exc = collapse_exception_group(exc)
+ group = BaseExceptionGroup(
+ "one or more errors occurred in a task group", self._exceptions
+ )
+ matched, unmatched = group.split(CancelledError)
+ if unmatched:
+ # If there are exceptions other than CancelledError, always raise those
+ raise unmatched
- if exc is not None and exc is not exc_val:
- raise exc
+ # If the host task was natively cancelled, or a parent cancel scope was
+ # cancelled, raise a new CancelledError
+ task_state = TaskState.get(host_task)
+ if task_state and task_state.cancelled:
+ raise CancelledError from None
return ignore_exception
@@ -539,9 +616,7 @@ async def _run_wrapped_task(
RuntimeError("Child exited without calling task_status.started()")
)
finally:
- if task in self.cancel_scope._tasks:
- self.cancel_scope._tasks.remove(task)
- del _task_states[task]
+ self.cancel_scope._remove_task(task)
def _spawn(
self,
@@ -552,10 +627,7 @@ def _spawn(
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
# This is the code path for Python 3.8+
- assert _task in self.cancel_scope._tasks
- self.cancel_scope._tasks.remove(_task)
- del _task_states[_task]
-
+ self.cancel_scope._remove_task(task)
try:
exc = _task.exception()
except CancelledError as e:
@@ -587,12 +659,13 @@ def task_done(_task: asyncio.Task) -> None:
kwargs = {}
if task_status_future:
- parent_id = id(current_task())
+ parent_task = cast(asyncio.Task, current_task())
kwargs["task_status"] = _AsyncioTaskStatus(
task_status_future, id(self.cancel_scope._host_task)
)
else:
- parent_id = id(self.cancel_scope._host_task)
+ assert self.cancel_scope._host_task
+ parent_task = self.cancel_scope._host_task
coro = func(*args, **kwargs)
if not asyncio.iscoroutine(coro):
@@ -609,10 +682,8 @@ def task_done(_task: asyncio.Task) -> None:
task.add_done_callback(task_done)
# Make the spawned task inherit the task group's cancel scope
- _task_states[task] = TaskState(
- parent_id=parent_id, name=name, cancel_scope=self.cancel_scope
- )
- self.cancel_scope._tasks.add(task)
+ TaskState.add(task, parent_task=parent_task, name=name)
+ self.cancel_scope._add_task(task, parent_task)
return task
def start_soon(
@@ -1662,7 +1733,7 @@ async def __anext__(self) -> Signals:
def _create_task_info(task: asyncio.Task) -> TaskInfo:
- task_state = _task_states.get(task)
+ task_state = TaskState.get(task)
if task_state is None:
name = task.get_name() if _native_task_names else None
parent_id = None
@@ -1798,15 +1869,14 @@ def run(
@wraps(func)
async def wrapper() -> T_Retval:
task = cast(asyncio.Task, current_task())
- task_state = TaskState(None, get_callable_name(func), None)
- _task_states[task] = task_state
+ task_state = TaskState.add(task, name=get_callable_name(func))
if _native_task_names:
task.set_name(task_state.name)
try:
return await func(*args)
finally:
- del _task_states[task]
+ TaskState.remove(task, None)
debug = options.get("debug", False)
policy = options.get("policy", None)
@@ -1836,11 +1906,8 @@ async def checkpoint_if_cancelled(cls) -> None:
if task is None:
return
- try:
- cancel_scope = _task_states[task].cancel_scope
- except KeyError:
- return
-
+ task_state = TaskState.get(task)
+ cancel_scope = task_state.cancel_scope if task_state else None
while cancel_scope:
if cancel_scope.cancel_called:
await sleep(0)
@@ -1866,13 +1933,9 @@ def create_cancel_scope(
@classmethod
def current_effective_deadline(cls) -> float:
- try:
- cancel_scope = _task_states[
- current_task() # type: ignore[index]
- ].cancel_scope
- except KeyError:
- return math.inf
-
+ task = cast(asyncio.Task, current_task())
+ task_state = TaskState.get(task)
+ cancel_scope = task_state.cancel_scope if task_state else None
deadline = math.inf
while cancel_scope:
deadline = min(deadline, cancel_scope.deadline)
diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py
index a5a21c73..d9f82ce3 100644
--- a/src/anyio/_backends/_trio.py
+++ b/src/anyio/_backends/_trio.py
@@ -3,6 +3,7 @@
import array
import math
import socket
+import sys
from collections.abc import AsyncIterator, Iterable
from concurrent.futures import Future
from dataclasses import dataclass
@@ -61,6 +62,9 @@
from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType
from ..abc._eventloop import AsyncBackend
+if sys.version_info < (3, 11):
+ from exceptiongroup import BaseExceptionGroup
+
if TYPE_CHECKING:
from trio_typing import TaskStatus
@@ -134,7 +138,9 @@ def shield(self, value: bool) -> None:
class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self._active = False
- self._nursery_manager = trio.open_nursery()
+ self._nursery_manager = trio.open_nursery(
+ strict_exception_groups=True # type: ignore[call-arg]
+ )
self.cancel_scope = None # type: ignore[assignment]
async def __aenter__(self) -> TaskGroup:
@@ -151,6 +157,14 @@ async def __aexit__(
) -> bool | None:
try:
return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb)
+ except BaseExceptionGroup as excgrp:
+ matched, unmatched = excgrp.split(trio.Cancelled)
+ if not unmatched:
+ raise trio.Cancelled._create() from None # type: ignore[attr-defined]
+ elif unmatched:
+ raise unmatched from None
+ else:
+ raise
finally:
self._active = False
@@ -719,7 +733,9 @@ def __init__(self, **options: Any) -> None:
async def _trio_main(self) -> None:
self._stop_event = trio.Event()
- async with trio.open_nursery() as self._nursery:
+ async with trio.open_nursery(
+ strict_exception_groups=True # type: ignore[call-arg]
+ ) as self._nursery:
await self._stop_event.wait()
async def _call_func(
diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py
index f5ddc2e3..da8923e8 100644
--- a/src/anyio/from_thread.py
+++ b/src/anyio/from_thread.py
@@ -148,6 +148,10 @@ async def __aexit__(
await self.stop()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+ @property
+ def _running(self) -> bool:
+ return self._event_loop_thread_id is not None
+
def _check_running(self) -> None:
if self._event_loop_thread_id is None:
raise RuntimeError("This portal is not running")
@@ -202,8 +206,11 @@ def callback(f: Future) -> None:
if not future.cancelled():
future.set_exception(exc)
- # Let base exceptions fall through
+ # Let base exceptions fall through, but mark the portal as not running, so
+ # start_blocking_portal() won't try to stop it since BaseException will
+ # cause that anyway
if not isinstance(exc, Exception):
+ self._event_loop_thread_id = None
raise
else:
if not future.cancelled():
@@ -413,9 +420,7 @@ async def run_portal() -> None:
cancel_remaining_tasks = True
raise
finally:
- try:
+ if not run_future.done() and portal._running:
portal.call(portal.stop, cancel_remaining_tasks)
- except RuntimeError:
- pass
run_future.result()
diff --git a/tests/streams/test_memory.py b/tests/streams/test_memory.py
index f322d9c4..b83169e4 100644
--- a/tests/streams/test_memory.py
+++ b/tests/streams/test_memory.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import sys
+
import pytest
from anyio import (
@@ -16,6 +18,9 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
+if sys.version_info < (3, 11):
+ from exceptiongroup import ExceptionGroup
+
pytestmark = pytest.mark.anyio
@@ -170,21 +175,27 @@ async def test_clone_closed() -> None:
async def test_close_send_while_receiving() -> None:
send, receive = create_memory_object_stream(1)
- with pytest.raises(EndOfStream):
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
tg.start_soon(receive.receive)
await wait_all_tasks_blocked()
await send.aclose()
+ matched, unmatched = exc.value.split(EndOfStream)
+ assert not unmatched
+
async def test_close_receive_while_sending() -> None:
send, receive = create_memory_object_stream(0)
- with pytest.raises(BrokenResourceError):
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
tg.start_soon(send.send, "hello")
await wait_all_tasks_blocked()
await receive.aclose()
+ matched, unmatched = exc.value.split(BrokenResourceError)
+ assert not unmatched
+
async def test_receive_after_send_closed() -> None:
send, receive = create_memory_object_stream(1)
diff --git a/tests/streams/test_tls.py b/tests/streams/test_tls.py
index fed2b1da..bc4a63b7 100644
--- a/tests/streams/test_tls.py
+++ b/tests/streams/test_tls.py
@@ -235,7 +235,11 @@ def serve_sync() -> None:
)
with client_cm:
assert await wrapper.receive() == b"hello"
- await wrapper.aclose()
+ try:
+ await wrapper.aclose()
+ except ssl.SSLError:
+ print("OpenSSL version:", ssl.OPENSSL_VERSION)
+ raise
server_thread.join()
server_sock.close()
diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py
index 93886f53..4906cb70 100644
--- a/tests/test_from_thread.py
+++ b/tests/test_from_thread.py
@@ -27,6 +27,9 @@
from anyio.from_thread import BlockingPortal, start_blocking_portal
from anyio.lowlevel import checkpoint
+if sys.version_info < (3, 11):
+ from exceptiongroup import BaseExceptionGroup, ExceptionGroup
+
if sys.version_info >= (3, 8):
from typing import Literal
else:
@@ -390,10 +393,13 @@ async def run_in_context() -> AsyncGenerator[None, None]:
yield
with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
- with pytest.raises(ZeroDivisionError):
+ with pytest.raises(ExceptionGroup) as exc:
with portal.wrap_async_context_manager(run_in_context()):
pass
+ _, unmatched = exc.value.split(ZeroDivisionError)
+ assert not unmatched
+
def test_start_no_value(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
@@ -516,7 +522,7 @@ def test_raise_baseexception_from_task(
async def raise_baseexception() -> None:
raise BaseException("fatal error")
- with pytest.raises(BaseException, match="fatal error"):
+ with pytest.raises(BaseExceptionGroup) as outer_exc:
with start_blocking_portal(
anyio_backend_name, anyio_backend_options
) as portal:
@@ -524,3 +530,6 @@ async def raise_baseexception() -> None:
portal.call(raise_baseexception)
assert exc.value.__context__ is None
+
+ assert len(outer_exc.value.exceptions) == 1
+ assert str(outer_exc.value.exceptions[0]) == "fatal error"
diff --git a/tests/test_sockets.py b/tests/test_sockets.py
index acffe920..072d626d 100644
--- a/tests/test_sockets.py
+++ b/tests/test_sockets.py
@@ -657,13 +657,16 @@ async def test_reuse_port(self, family: AnyIPAddressFamily) -> None:
async def test_close_from_other_task(self, family: AnyIPAddressFamily) -> None:
listener = await create_tcp_listener(local_host="localhost", family=family)
- with pytest.raises(ClosedResourceError):
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
tg.start_soon(listener.serve, lambda stream: None)
await wait_all_tasks_blocked()
await listener.aclose()
tg.cancel_scope.cancel()
+ _, unmatched = exc.value.split(ClosedResourceError)
+ assert not unmatched
+
async def test_send_after_eof(self, family: AnyIPAddressFamily) -> None:
async def handle(stream: SocketStream) -> None:
async with stream:
@@ -1253,8 +1256,8 @@ async def test_send_receive(self, family: AnyIPAddressFamily) -> None:
async with await create_connected_udp_socket(
host, port, local_host="localhost", family=family
) as udp2:
- host, port = udp2.extra(
- SocketAttribute.local_address # type: ignore[misc]
+ host, port = udp2.extra( # type: ignore[misc]
+ SocketAttribute.local_address
)
await udp2.send(b"blah")
request = await udp1.receive()
diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py
index adf2e79d..96d8a4d1 100644
--- a/tests/test_taskgroups.py
+++ b/tests/test_taskgroups.py
@@ -27,7 +27,7 @@
from anyio.lowlevel import checkpoint
if sys.version_info < (3, 11):
- from exceptiongroup import BaseExceptionGroup
+ from exceptiongroup import BaseExceptionGroup, ExceptionGroup
pytestmark = pytest.mark.anyio
@@ -89,7 +89,7 @@ async def task_func() -> None:
async def test_start_soon_after_error() -> None:
- with pytest.raises(ZeroDivisionError):
+ with pytest.raises(ExceptionGroup):
async with create_task_group() as tg:
a = 1 / 0 # noqa: F841
@@ -147,11 +147,11 @@ async def taskfunc(*, task_status: TaskStatus) -> NoReturn:
task_status.started(2)
raise Exception("foo")
- with pytest.raises(Exception) as exc:
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
value = await tg.start(taskfunc)
- exc.match("foo")
+ assert str(exc.value.exceptions[0]) == "foo"
assert value == 2
@@ -183,54 +183,6 @@ async def taskfunc(*, task_status: TaskStatus) -> None:
assert not finished
-@pytest.mark.parametrize("anyio_backend", ["asyncio"])
-async def test_start_native_host_cancelled() -> None:
- started = finished = False
-
- async def taskfunc(*, task_status: TaskStatus) -> None:
- nonlocal started, finished
- started = True
- await sleep(2)
- finished = True
-
- async def start_another() -> None:
- async with create_task_group() as tg:
- await tg.start(taskfunc)
-
- task = asyncio.get_running_loop().create_task(start_another())
- await wait_all_tasks_blocked()
- task.cancel()
- with pytest.raises(asyncio.CancelledError):
- await task
-
- assert started
- assert not finished
-
-
-@pytest.mark.parametrize("anyio_backend", ["asyncio"])
-async def test_start_native_child_cancelled() -> None:
- task = None
- finished = False
-
- async def taskfunc(*, task_status: TaskStatus) -> None:
- nonlocal task, finished
- task = asyncio.current_task()
- await sleep(2)
- finished = True
-
- async def start_another() -> None:
- async with create_task_group() as tg2:
- await tg2.start(taskfunc)
-
- async with create_task_group() as tg:
- tg.start_soon(start_another)
- await wait_all_tasks_blocked()
- assert task is not None
- task.cancel()
-
- assert not finished
-
-
async def test_start_exception_delivery() -> None:
def task_fn(*, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
task_status.started("hello")
@@ -248,12 +200,13 @@ async def set_result(value: str) -> None:
await sleep(3)
result = value
- with pytest.raises(Exception) as exc:
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
tg.start_soon(set_result, "a")
raise Exception("dummy error")
- exc.match("dummy error")
+ assert len(exc.value.exceptions) == 1
+ assert str(exc.value.exceptions[0]) == "dummy error"
assert result is None
@@ -282,13 +235,14 @@ async def child() -> NoReturn:
raise Exception("foo")
sleep_completed = False
- with pytest.raises(Exception) as exc:
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
tg.start_soon(child)
await sleep(0.5)
sleep_completed = True
- exc.match("foo")
+ assert len(exc.value.exceptions) == 1
+ assert str(exc.value.exceptions[0]) == "foo"
assert not sleep_completed
@@ -300,13 +254,14 @@ async def child() -> None:
await sleep(1)
sleep_completed = True
- with pytest.raises(Exception) as exc:
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
tg.start_soon(child)
await wait_all_tasks_blocked()
raise Exception("foo")
- exc.match("foo")
+ assert len(exc.value.exceptions) == 1
+ assert str(exc.value.exceptions[0]) == "foo"
assert not sleep_completed
@@ -369,7 +324,7 @@ async def waiter() -> None:
nonlocal cancel_received
try:
await sleep(5)
- finally:
+ except get_cancelled_exc_class():
cancel_received = True
async def subgroup() -> None:
@@ -396,6 +351,48 @@ async def test_cancel_before_entering_scope() -> None:
pytest.fail("execution should not reach this point")
+async def test_cancel_outer_scope_no_tasks() -> None:
+ """
+ Test that a task group raises an exception group containing one cancellation error
+ from __aexit__() if the outer cancel scope was cancelled.
+
+ """
+ with CancelScope() as outer_scope:
+ try:
+ async with anyio.create_task_group():
+ outer_scope.cancel()
+ except BaseException as exc:
+ if not isinstance(exc, get_cancelled_exc_class()):
+ pytest.fail("should have raised a cancellation exception")
+
+ raise
+ else:
+ pytest.fail("should have raised an exception")
+
+
+async def test_cancel_outer_scope_one_task() -> None:
+ """
+ Test that a task group propagates a cancellation error (wrapped in an exception
+ group) from __aexit__() that was not intended for the task group's cancel scope.
+
+ """
+ try:
+ with CancelScope() as outer_scope:
+ try:
+ async with anyio.create_task_group() as tg:
+ tg.start_soon(sleep, 3)
+ outer_scope.cancel()
+ except get_cancelled_exc_class():
+ pass
+ except BaseException as exc:
+ pytest.fail(
+ f"should have raised a cancellation error instead of "
+ f"{exc.__class__}"
+ )
+ except BaseException:
+ pytest.fail("the cancel scope should have swallowed the exceptions")
+
+
async def test_exception_group_children() -> None:
with pytest.raises(BaseExceptionGroup) as exc:
async with create_task_group() as tg:
@@ -549,40 +546,6 @@ async def test_cancel_shielded_scope() -> None:
await sleep(0)
-@pytest.mark.parametrize("anyio_backend", ["asyncio"])
-async def test_cancel_host_asyncgen() -> None:
- done = False
-
- async def host_task() -> None:
- nonlocal done
- async with create_task_group() as tg:
- with CancelScope(shield=True) as inner_scope:
- assert inner_scope.shield
- tg.cancel_scope.cancel()
-
- with pytest.raises(get_cancelled_exc_class()):
- await sleep(0)
-
- with pytest.raises(get_cancelled_exc_class()):
- await sleep(0)
-
- done = True
-
- async def host_agen_fn() -> AsyncGenerator[None, None]:
- await host_task()
- yield
- pytest.fail("host_agen_fn should only be __anext__ed once")
-
- host_agen = host_agen_fn()
- try:
- loop = asyncio.get_running_loop()
- await loop.create_task(host_agen.__anext__()) # type: ignore[arg-type]
- finally:
- await host_agen.aclose()
-
- assert done
-
-
async def test_shielding_immediate_scope_cancelled() -> None:
async def cancel_when_ready() -> None:
await wait_all_tasks_blocked()
@@ -654,13 +617,14 @@ async def child(fail: bool) -> None:
await sleep(1)
sleep_completed = True
- with pytest.raises(Exception) as exc:
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
tg.start_soon(child, False)
await wait_all_tasks_blocked()
tg.start_soon(child, True)
- exc.match("foo")
+ assert len(exc.value.exceptions) == 1
+ assert str(exc.value.exceptions[0]) == "foo"
assert not sleep_completed
@@ -745,7 +709,7 @@ async def killer(scope: CancelScope) -> None:
await wait_all_tasks_blocked()
scope.cancel()
- with pytest.raises(TimeoutError):
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
with CancelScope() as scope:
with CancelScope(shield=True):
@@ -753,6 +717,9 @@ async def killer(scope: CancelScope) -> None:
with fail_after(0.2):
await sleep(2)
+ _, unmatched = exc.value.split(TimeoutError)
+ assert not unmatched
+
async def test_triple_nested_shield() -> None:
"""Regression test for #370."""
@@ -820,7 +787,8 @@ async def fn() -> None:
assert len(exc.value.exceptions) == 2
assert str(exc.value.exceptions[0]) == "parent task failed"
- assert str(exc.value.exceptions[1]) == "child task failed"
+ assert isinstance(exc.value.exceptions[1], ExceptionGroup)
+ assert str(exc.value.exceptions[1].exceptions[0]) == "child task failed"
async def test_cancel_propagation_with_inner_spawn() -> None:
@@ -849,74 +817,6 @@ async def test_escaping_cancelled_error_from_cancelled_task() -> None:
scope.cancel()
-@pytest.mark.skipif(
- sys.version_info >= (3, 11),
- reason="Generator based coroutines have been removed in Python 3.11",
-)
-@pytest.mark.filterwarnings(
- 'ignore:"@coroutine" decorator is deprecated:DeprecationWarning'
-)
-def test_cancel_generator_based_task() -> None:
- async def native_coro_part() -> None:
- with CancelScope() as scope:
- asyncio.get_running_loop().call_soon(scope.cancel)
- await asyncio.sleep(1)
- pytest.fail("Execution should not have reached this line")
-
- @asyncio.coroutine
- def generator_part() -> Generator[object, BaseException, None]:
- yield from native_coro_part()
-
- anyio.run(generator_part, backend="asyncio")
-
-
-@pytest.mark.parametrize("anyio_backend", ["asyncio"])
-async def test_cancel_native_future_tasks() -> None:
- async def wait_native_future() -> None:
- loop = asyncio.get_running_loop()
- await loop.create_future()
-
- async with anyio.create_task_group() as tg:
- tg.start_soon(wait_native_future)
- tg.cancel_scope.cancel()
-
-
-@pytest.mark.parametrize("anyio_backend", ["asyncio"])
-async def test_cancel_native_future_tasks_cancel_scope() -> None:
- async def wait_native_future() -> None:
- with anyio.CancelScope():
- loop = asyncio.get_running_loop()
- await loop.create_future()
-
- async with anyio.create_task_group() as tg:
- tg.start_soon(wait_native_future)
- tg.cancel_scope.cancel()
-
-
-@pytest.mark.parametrize("anyio_backend", ["asyncio"])
-async def test_cancel_completed_task() -> None:
- loop = asyncio.get_running_loop()
- old_exception_handler = loop.get_exception_handler()
- exceptions = []
-
- def exception_handler(*args: object, **kwargs: object) -> None:
- exceptions.append((args, kwargs))
-
- loop.set_exception_handler(exception_handler)
- try:
-
- async def noop() -> None:
- pass
-
- async with anyio.create_task_group() as tg:
- tg.start_soon(noop)
- tg.cancel_scope.cancel()
-
- assert exceptions == []
- finally:
- loop.set_exception_handler(old_exception_handler)
-
-
async def test_task_in_sync_spawn_callback() -> None:
outer_task_id = anyio.get_current_task().id
inner_task_id = None
@@ -1002,54 +902,12 @@ async def exit_scope(scope: CancelScope) -> None:
async with create_task_group() as tg:
tg.start_soon(enter_scope, scope)
- with pytest.raises(RuntimeError):
+ with pytest.raises(ExceptionGroup) as exc:
async with create_task_group() as tg:
tg.start_soon(exit_scope, scope)
-
-def test_unhandled_exception_group(caplog: pytest.LogCaptureFixture) -> None:
- def crash() -> NoReturn:
- raise KeyboardInterrupt
-
- async def nested() -> None:
- async with anyio.create_task_group() as tg:
- tg.start_soon(anyio.sleep, 5)
- await anyio.sleep(5)
-
- async def main() -> NoReturn:
- async with anyio.create_task_group() as tg:
- tg.start_soon(nested)
- await wait_all_tasks_blocked()
- asyncio.get_running_loop().call_soon(crash)
- await anyio.sleep(5)
-
- pytest.fail("Execution should never reach this point")
-
- with pytest.raises(KeyboardInterrupt):
- anyio.run(main, backend="asyncio")
-
- assert not caplog.messages
-
-
-@pytest.mark.skipif(
- sys.version_info < (3, 9),
- sys.version_info >= (3, 11),
- reason="Cancel messages are only supported on Python 3.9 and 3.10",
-)
-@pytest.mark.parametrize("anyio_backend", ["asyncio"])
-async def test_cancellederror_combination_with_message() -> None:
- async def taskfunc(*, task_status: TaskStatus) -> NoReturn:
- task_status.started(asyncio.current_task())
- await sleep(5)
- pytest.fail("Execution should never reach this point")
-
- with pytest.raises(asyncio.CancelledError, match="blah"):
- async with create_task_group() as tg:
- task = await tg.start(taskfunc)
- tg.start_soon(sleep, 5)
- await wait_all_tasks_blocked()
- assert isinstance(task, asyncio.Task)
- task.cancel("blah")
+ assert len(exc.value.exceptions) == 1
+ assert isinstance(exc.value.exceptions[0], RuntimeError)
async def test_start_soon_parent_id() -> None:
@@ -1092,3 +950,185 @@ async def starter_task() -> None:
assert initial_parent_id != permanent_parent_id
assert initial_parent_id == starter_task_id
assert permanent_parent_id == root_task_id
+
+
+class TestAsyncio:
+ """Contains asyncio specific cancel scope/task group tests."""
+
+ @pytest.mark.parametrize("anyio_backend", ["asyncio"])
+ async def test_cancel_host_asyncgen(self) -> None:
+ done = False
+
+ async def host_task() -> None:
+ nonlocal done
+ async with create_task_group() as tg:
+ with CancelScope(shield=True) as inner_scope:
+ assert inner_scope.shield
+ tg.cancel_scope.cancel()
+
+ with pytest.raises(get_cancelled_exc_class()):
+ await sleep(0)
+
+ with pytest.raises(get_cancelled_exc_class()):
+ await sleep(0)
+
+ done = True
+
+ async def host_agen_fn() -> AsyncGenerator[None, None]:
+ await host_task()
+ yield
+ pytest.fail("host_agen_fn should only be __anext__ed once")
+
+ host_agen = host_agen_fn()
+ try:
+ loop = asyncio.get_running_loop()
+ await loop.create_task(host_agen.__anext__()) # type: ignore[arg-type]
+ finally:
+ await host_agen.aclose()
+
+ assert done
+
+ def test_unhandled_exception_group(self, caplog: pytest.LogCaptureFixture) -> None:
+ def crash() -> NoReturn:
+ raise KeyboardInterrupt
+
+ async def nested() -> None:
+ async with anyio.create_task_group() as tg:
+ tg.start_soon(anyio.sleep, 5)
+ await anyio.sleep(5)
+
+ async def main() -> NoReturn:
+ async with anyio.create_task_group() as tg:
+ tg.start_soon(nested)
+ await wait_all_tasks_blocked()
+ asyncio.get_running_loop().call_soon(crash)
+ await anyio.sleep(5)
+
+ pytest.fail("Execution should never reach this point")
+
+ with pytest.raises(KeyboardInterrupt):
+ anyio.run(main, backend="asyncio")
+
+ assert not caplog.messages
+
+ @pytest.mark.parametrize("anyio_backend", ["asyncio"])
+ async def test_start_native_host_cancelled(self) -> None:
+ started = finished = False
+
+ async def taskfunc(*, task_status: TaskStatus) -> None:
+ nonlocal started, finished
+ started = True
+ await sleep(2)
+ finished = True
+
+ async def start_another() -> None:
+ async with create_task_group() as tg:
+ await tg.start(taskfunc)
+
+ task = asyncio.get_running_loop().create_task(start_another())
+ await wait_all_tasks_blocked()
+ task.cancel()
+ with pytest.raises(asyncio.CancelledError):
+ await task
+
+ assert started
+ assert not finished
+
+ @pytest.mark.parametrize("anyio_backend", ["asyncio"])
+ async def test_start_native_child_cancelled(self) -> None:
+ task = None
+ finished = False
+
+ async def taskfunc(*, task_status: TaskStatus) -> None:
+ nonlocal task, finished
+ task = asyncio.current_task()
+ await sleep(2)
+ finished = True
+
+ async def start_another() -> None:
+ async with create_task_group() as tg2:
+ await tg2.start(taskfunc)
+
+ async with create_task_group() as tg:
+ tg.start_soon(start_another)
+ await wait_all_tasks_blocked()
+ assert task is not None
+ task.cancel()
+
+ assert not finished
+
+ @pytest.mark.skipif(
+ sys.version_info >= (3, 11),
+ reason="Generator based coroutines have been removed in Python 3.11",
+ )
+ @pytest.mark.filterwarnings(
+ 'ignore:"@coroutine" decorator is deprecated:DeprecationWarning'
+ )
+ def test_cancel_generator_based_task(self) -> None:
+ async def native_coro_part() -> None:
+ with CancelScope() as scope:
+ asyncio.get_running_loop().call_soon(scope.cancel)
+ await asyncio.sleep(1)
+ pytest.fail("Execution should not have reached this line")
+
+ @asyncio.coroutine
+ def generator_part() -> Generator[object, BaseException, None]:
+ yield from native_coro_part()
+
+ anyio.run(generator_part, backend="asyncio")
+
+ @pytest.mark.parametrize("anyio_backend", ["asyncio"])
+ async def test_cancel_native_future_tasks(self) -> None:
+ async def wait_native_future() -> None:
+ loop = asyncio.get_running_loop()
+ await loop.create_future()
+
+ async with anyio.create_task_group() as tg:
+ tg.start_soon(wait_native_future)
+ tg.cancel_scope.cancel()
+
+ @pytest.mark.parametrize("anyio_backend", ["asyncio"])
+ async def test_cancel_native_future_tasks_cancel_scope(self) -> None:
+ async def wait_native_future() -> None:
+ with anyio.CancelScope():
+ loop = asyncio.get_running_loop()
+ await loop.create_future()
+
+ async with anyio.create_task_group() as tg:
+ tg.start_soon(wait_native_future)
+ tg.cancel_scope.cancel()
+
+ @pytest.mark.parametrize("anyio_backend", ["asyncio"])
+ async def test_cancel_completed_task(self) -> None:
+ loop = asyncio.get_running_loop()
+ old_exception_handler = loop.get_exception_handler()
+ exceptions = []
+
+ def exception_handler(*args: object, **kwargs: object) -> None:
+ exceptions.append((args, kwargs))
+
+ loop.set_exception_handler(exception_handler)
+ try:
+
+ async def noop() -> None:
+ pass
+
+ async with anyio.create_task_group() as tg:
+ tg.start_soon(noop)
+ tg.cancel_scope.cancel()
+
+ assert exceptions == []
+ finally:
+ loop.set_exception_handler(old_exception_handler)
+
+ @pytest.mark.skipif(sys.version_info < (3, 11), reason="Requires Python >= 3.11")
+ @pytest.mark.parametrize("anyio_backend", ["asyncio"])
+ async def test_asyncio_timeout(self) -> None:
+ """Test that CancelScope.__exit__ un-cancels the task."""
+ with CancelScope() as scope:
+ scope.cancel()
+ await sleep(2)
+ pytest.fail("Execution should not reach this point")
+
+ task = asyncio.current_task()
+ assert not task.cancelling()