diff --git a/changes/44.feature.md b/changes/44.feature.md new file mode 100644 index 0000000..fcb31a5 --- /dev/null +++ b/changes/44.feature.md @@ -0,0 +1 @@ +Propagate task results and exceptions via separate future instances if they are `await`-ed by the caller of `create_task()` in `PersistentTaskGroup`, in addition to invocation of task group exception handler. Note that `await`-ing those futures hangs indefinitely in Python 3.6 but we don't fix it since Python 3.6 is EoL as of December 2021. diff --git a/docs/aiotools.taskgroup.rst b/docs/aiotools.taskgroup.rst index 536307d..2af3c9b 100644 --- a/docs/aiotools.taskgroup.rst +++ b/docs/aiotools.taskgroup.rst @@ -98,8 +98,21 @@ Task Group .. method:: create_task(coro, *, name=None) Spawns a new task inside the taskgroup and returns the reference to - the task. Setting the name of tasks is supported in Python 3.8 or - later only and ignored in older versions. + a :class:`future ` describing the task result. + Setting the name of tasks is supported in Python 3.8 or later only + and ignored in older versions. + + You may ``await`` the retuned future to take the task's return value + or get notified with the exception from it, while the exception + handler is still invoked. Since it is just a *secondary* future, + you cannot cancel the task explicitly using it. To cancel the + task(s), use :meth:`shutdown()` or exit the task group context. + + .. warning:: + + In Python 3.6, ``await``-ing the returned future hangs + indefinitely. We do not fix this issue because Python 3.6 is now + EoL (end-of-life) as of December 2021. .. method:: get_name() diff --git a/src/aiotools/taskgroup/persistent.py b/src/aiotools/taskgroup/persistent.py index dbdffbc..9c960a6 100644 --- a/src/aiotools/taskgroup/persistent.py +++ b/src/aiotools/taskgroup/persistent.py @@ -7,6 +7,7 @@ from types import TracebackType from typing import ( Any, + Awaitable, Callable, Coroutine, List, @@ -85,7 +86,7 @@ def create_task( coro: Coroutine[Any, Any, Any], *, name: str = None, - ) -> "asyncio.Task": + ) -> Awaitable[Any]: if not self._entered: # When used as object attribute, auto-enter. self._entered = True @@ -99,14 +100,18 @@ def _create_task_with_name( *, name: str = None, cb: Callable[[asyncio.Task], Any], - ) -> "asyncio.Task": + ) -> Awaitable[Any]: loop = compat.get_running_loop() - child_task = loop.create_task(self._task_wrapper(coro), name=name) + result_future = loop.create_future() + child_task = loop.create_task( + self._task_wrapper(coro, weakref.ref(result_future)), + name=name, + ) _log.debug("%r is spawned in %r.", child_task, self) self._unfinished_tasks += 1 child_task.add_done_callback(cb) self._tasks.add(child_task) - return child_task + return result_future def _is_base_error(self, exc: BaseException) -> bool: assert isinstance(exc, BaseException) @@ -141,12 +146,24 @@ async def shutdown(self) -> None: self._trigger_shutdown() await self._wait_completion() - async def _task_wrapper(self, coro: Coroutine) -> Any: + async def _task_wrapper( + self, + coro: Coroutine, + result_future: weakref.ref[asyncio.Future], + ) -> Any: loop = compat.get_running_loop() task = compat.current_task() + fut = result_future() try: - return await coro - except Exception: + ret = await coro + if fut is not None: + fut.set_result(ret) + return ret + except asyncio.CancelledError: + if fut is not None: + fut.cancel() + raise + except Exception as e: # Swallow unhandled exceptions by our own and # prevent abortion of the task group bu them. # Wrapping corotuines directly has advantage for @@ -154,6 +171,8 @@ async def _task_wrapper(self, coro: Coroutine) -> Any: # and there is no need to implement separate # mechanism to wait for exception handler tasks. try: + if fut is not None: + fut.set_exception(e) await self._exc_handler(*sys.exc_info()) except Exception as exc: # If there are exceptions inside the exception handler @@ -166,6 +185,8 @@ async def _task_wrapper(self, coro: Coroutine) -> Any: 'exception': exc, 'task': task, }) + finally: + del fut def _on_task_done(self, task: asyncio.Task) -> None: self._unfinished_tasks -= 1 diff --git a/src/aiotools/taskgroup/persistent_compat.py b/src/aiotools/taskgroup/persistent_compat.py index 2592799..357580f 100644 --- a/src/aiotools/taskgroup/persistent_compat.py +++ b/src/aiotools/taskgroup/persistent_compat.py @@ -11,6 +11,7 @@ from types import TracebackType from typing import ( Any, + Awaitable, Callable, Coroutine, List, @@ -90,7 +91,7 @@ def create_task( coro: Coroutine[Any, Any, Any], *, name: str = None, - ) -> "asyncio.Task": + ) -> Awaitable[Any]: if not self._entered: # When used as object attribute, auto-enter. self._entered = True @@ -104,13 +105,18 @@ def _create_task_with_name( *, name: str = None, cb: Callable[[asyncio.Task], Any], - ) -> "asyncio.Task": - child_task = create_task_with_name(self._task_wrapper(coro), name=name) + ) -> Awaitable[Any]: + loop = compat.get_running_loop() + result_future = loop.create_future() + child_task = create_task_with_name( + self._task_wrapper(coro, weakref.ref(result_future)), + name=name, + ) _log.debug("%r is spawned in %r.", child_task, self) self._unfinished_tasks += 1 child_task.add_done_callback(cb) self._tasks.add(child_task) - return child_task + return result_future def _is_base_error(self, exc: BaseException) -> bool: assert isinstance(exc, BaseException) @@ -145,12 +151,24 @@ async def shutdown(self) -> None: self._trigger_shutdown() await self._wait_completion() - async def _task_wrapper(self, coro: Coroutine) -> Any: + async def _task_wrapper( + self, + coro: Coroutine, + result_future: "weakref.ref[asyncio.Future]", + ) -> Any: loop = compat.get_running_loop() task = compat.current_task() + fut = result_future() try: - return await coro - except Exception: + ret = await coro + if fut is not None: + fut.set_result(ret) + return ret + except asyncio.CancelledError: + if fut is not None: + fut.cancel() + raise + except Exception as e: # Swallow unhandled exceptions by our own and # prevent abortion of the task group bu them. # Wrapping corotuines directly has advantage for @@ -158,6 +176,8 @@ async def _task_wrapper(self, coro: Coroutine) -> Any: # and there is no need to implement separate # mechanism to wait for exception handler tasks. try: + if fut is not None: + fut.set_exception(e) await self._exc_handler(*sys.exc_info()) except Exception as exc: # If there are exceptions inside the exception handler @@ -170,6 +190,8 @@ async def _task_wrapper(self, coro: Coroutine) -> Any: 'exception': exc, 'task': task, }) + finally: + del fut def _on_task_done(self, task: asyncio.Task) -> None: self._unfinished_tasks -= 1 diff --git a/tests/test_ptaskgroup.py b/tests/test_ptaskgroup.py index 8d89eb5..5f5bf38 100644 --- a/tests/test_ptaskgroup.py +++ b/tests/test_ptaskgroup.py @@ -42,7 +42,6 @@ async def subtask(): async with aiotools.PersistentTaskGroup() as tg: for idx in range(10): tg.create_task(subtask()) - assert len(tg._tasks) == 10 assert tg._unfinished_tasks == 10 # wait until all is done await asyncio.sleep(0.2) @@ -84,7 +83,6 @@ async def aclose(self): obj = LongLivedObject() for idx in range(10): await obj.work() - assert len(obj.tg._tasks) == 10 assert obj.tg._unfinished_tasks == 10 # shutdown after all done @@ -99,7 +97,6 @@ async def aclose(self): obj = LongLivedObject() for idx in range(10): await obj.work() - assert len(obj.tg._tasks) == 10 assert obj.tg._unfinished_tasks == 10 # shutdown immediately @@ -175,7 +172,6 @@ async def subtask(): for _ in range(10): tg.create_task(subtask()) await asyncio.sleep(0) - assert len(tg._tasks) == 10 # shutdown after exit (all done) is no-op. assert done_count == 10 @@ -202,7 +198,6 @@ async def subtask(): async with aiotools.PersistentTaskGroup() as tg: for _ in range(10): tg.create_task(subtask()) - assert len(tg._tasks) == 10 # let's abort immediately. await tg.shutdown() @@ -210,6 +205,111 @@ async def subtask(): assert len(tg._tasks) == 0 +@pytest.mark.skipif( + sys.version_info < (3, 7, 0), + reason='Requires Python 3.7 or higher', + # In Python 3.6, this test hangs indefinitely. + # We don't fix this -- 3.6 is EoL as of December 2021. +) +@pytest.mark.asyncio +async def test_ptaskgroup_await_result(): + + done_count = 0 + + async def subtask(): + nonlocal done_count + await asyncio.sleep(0.1) + done_count += 1 + return "a" + + vclock = aiotools.VirtualClock() + with vclock.patch_loop(): + + results = [] + + async with aiotools.PersistentTaskGroup() as tg: + + ret = await tg.create_task(subtask()) + results.append(ret) + + ret = await asyncio.shield(tg.create_task(subtask())) + results.append(ret) + + a = tg.create_task(subtask()) + try: + ret = await a + results.append(ret) + finally: + del a + + a = asyncio.shield(tg.create_task(subtask())) + try: + ret = await a + results.append(ret) + finally: + del a + + assert results == ["a", "a", "a", "a"] + assert done_count == 4 + assert tg._unfinished_tasks == 0 + assert len(tg._tasks) == 0 + + +@pytest.mark.skipif( + sys.version_info < (3, 7, 0), + reason='Requires Python 3.7 or higher', + # In Python 3.6, this test hangs indefinitely. + # We don't fix this -- 3.6 is EoL as of December 2021. +) +@pytest.mark.asyncio +async def test_ptaskgroup_await_exception(): + + done_count = 0 + error_count = 0 + + async def subtask(): + nonlocal done_count + await asyncio.sleep(0.1) + 1 / 0 + done_count += 1 + + async def handler(exc_type, exc_obj, exc_tb): + nonlocal error_count + assert issubclass(exc_type, ZeroDivisionError) + error_count += 1 + + vclock = aiotools.VirtualClock() + with vclock.patch_loop(): + + async with aiotools.PersistentTaskGroup(exception_handler=handler) as tg: + + with pytest.raises(ZeroDivisionError): + await tg.create_task(subtask()) + + with pytest.raises(ZeroDivisionError): + await asyncio.shield(tg.create_task(subtask())) + + with pytest.raises(ZeroDivisionError): + a = tg.create_task(subtask()) + try: + await a + finally: + del a + + with pytest.raises(ZeroDivisionError): + # WARNING: This pattern leaks the reference to the task. + a = asyncio.shield(tg.create_task(subtask())) + try: + await a + finally: + del a + + assert done_count == 0 + assert error_count == 4 + assert tg._unfinished_tasks == 0 + assert len(tg._tasks) == 1 + + @pytest.mark.asyncio async def test_ptaskgroup_exc_handler_swallow(): @@ -234,7 +334,6 @@ async def handler(exc_type, exc_obj, exc_tb): async with aiotools.PersistentTaskGroup(exception_handler=handler) as tg: for _ in range(10): tg.create_task(subtask()) - assert len(tg._tasks) == 10 except ExceptionGroup as eg: # All non-base exceptions must be swallowed by # our exception handler. @@ -348,7 +447,6 @@ async def subtask(): async with aiotools.PersistentTaskGroup() as tg: for _ in range(10): tg.create_task(subtask()) - assert len(tg._tasks) == 10 # Shutdown just after starting child tasks. # Even in this case, awaits in the tasks' cancellation blocks # should be executed until their completion.