Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Propagate results and errors to task awaiters in ptaskgroup #44

Merged
merged 15 commits into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/44.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Propagate task results and exceptions via separate future instances if they are `await`-ed by the caller of `create_task()` in `PersistentTaskGroup`, in addition to invocation of task group exception handler. Note that `await`-ing those futures hangs indefinitely in Python 3.6 but we don't fix it since Python 3.6 is EoL as of December 2021.
17 changes: 15 additions & 2 deletions docs/aiotools.taskgroup.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,21 @@ Task Group
.. method:: create_task(coro, *, name=None)

Spawns a new task inside the taskgroup and returns the reference to
the task. Setting the name of tasks is supported in Python 3.8 or
later only and ignored in older versions.
a :class:`future <asyncio.Future>` describing the task result.
Setting the name of tasks is supported in Python 3.8 or later only
and ignored in older versions.

You may ``await`` the retuned future to take the task's return value
or get notified with the exception from it, while the exception
handler is still invoked. Since it is just a *secondary* future,
you cannot cancel the task explicitly using it. To cancel the
task(s), use :meth:`shutdown()` or exit the task group context.

.. warning::

In Python 3.6, ``await``-ing the returned future hangs
indefinitely. We do not fix this issue because Python 3.6 is now
EoL (end-of-life) as of December 2021.

.. method:: get_name()

Expand Down
35 changes: 28 additions & 7 deletions src/aiotools/taskgroup/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from types import TracebackType
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
List,
Expand Down Expand Up @@ -85,7 +86,7 @@ def create_task(
coro: Coroutine[Any, Any, Any],
*,
name: str = None,
) -> "asyncio.Task":
) -> Awaitable[Any]:
if not self._entered:
# When used as object attribute, auto-enter.
self._entered = True
Expand All @@ -99,14 +100,18 @@ def _create_task_with_name(
*,
name: str = None,
cb: Callable[[asyncio.Task], Any],
) -> "asyncio.Task":
) -> Awaitable[Any]:
loop = compat.get_running_loop()
child_task = loop.create_task(self._task_wrapper(coro), name=name)
result_future = loop.create_future()
child_task = loop.create_task(
self._task_wrapper(coro, weakref.ref(result_future)),
name=name,
)
_log.debug("%r is spawned in %r.", child_task, self)
self._unfinished_tasks += 1
child_task.add_done_callback(cb)
self._tasks.add(child_task)
return child_task
return result_future

def _is_base_error(self, exc: BaseException) -> bool:
assert isinstance(exc, BaseException)
Expand Down Expand Up @@ -141,19 +146,33 @@ async def shutdown(self) -> None:
self._trigger_shutdown()
await self._wait_completion()

async def _task_wrapper(self, coro: Coroutine) -> Any:
async def _task_wrapper(
self,
coro: Coroutine,
result_future: weakref.ref[asyncio.Future],
) -> Any:
loop = compat.get_running_loop()
task = compat.current_task()
fut = result_future()
try:
return await coro
except Exception:
ret = await coro
if fut is not None:
fut.set_result(ret)
return ret
except asyncio.CancelledError:
if fut is not None:
fut.cancel()
raise
except Exception as e:
# Swallow unhandled exceptions by our own and
# prevent abortion of the task group bu them.
# Wrapping corotuines directly has advantage for
# exception handlers to access full traceback
# and there is no need to implement separate
# mechanism to wait for exception handler tasks.
try:
if fut is not None:
fut.set_exception(e)
await self._exc_handler(*sys.exc_info())
except Exception as exc:
# If there are exceptions inside the exception handler
Expand All @@ -166,6 +185,8 @@ async def _task_wrapper(self, coro: Coroutine) -> Any:
'exception': exc,
'task': task,
})
finally:
del fut

def _on_task_done(self, task: asyncio.Task) -> None:
self._unfinished_tasks -= 1
Expand Down
36 changes: 29 additions & 7 deletions src/aiotools/taskgroup/persistent_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from types import TracebackType
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
List,
Expand Down Expand Up @@ -90,7 +91,7 @@ def create_task(
coro: Coroutine[Any, Any, Any],
*,
name: str = None,
) -> "asyncio.Task":
) -> Awaitable[Any]:
if not self._entered:
# When used as object attribute, auto-enter.
self._entered = True
Expand All @@ -104,13 +105,18 @@ def _create_task_with_name(
*,
name: str = None,
cb: Callable[[asyncio.Task], Any],
) -> "asyncio.Task":
child_task = create_task_with_name(self._task_wrapper(coro), name=name)
) -> Awaitable[Any]:
loop = compat.get_running_loop()
result_future = loop.create_future()
child_task = create_task_with_name(
self._task_wrapper(coro, weakref.ref(result_future)),
name=name,
)
_log.debug("%r is spawned in %r.", child_task, self)
self._unfinished_tasks += 1
child_task.add_done_callback(cb)
self._tasks.add(child_task)
return child_task
return result_future

def _is_base_error(self, exc: BaseException) -> bool:
assert isinstance(exc, BaseException)
Expand Down Expand Up @@ -145,19 +151,33 @@ async def shutdown(self) -> None:
self._trigger_shutdown()
await self._wait_completion()

async def _task_wrapper(self, coro: Coroutine) -> Any:
async def _task_wrapper(
self,
coro: Coroutine,
result_future: "weakref.ref[asyncio.Future]",
) -> Any:
loop = compat.get_running_loop()
task = compat.current_task()
fut = result_future()
try:
return await coro
except Exception:
ret = await coro
if fut is not None:
fut.set_result(ret)
return ret
except asyncio.CancelledError:
if fut is not None:
fut.cancel()
raise
except Exception as e:
# Swallow unhandled exceptions by our own and
# prevent abortion of the task group bu them.
# Wrapping corotuines directly has advantage for
# exception handlers to access full traceback
# and there is no need to implement separate
# mechanism to wait for exception handler tasks.
try:
if fut is not None:
fut.set_exception(e)
await self._exc_handler(*sys.exc_info())
except Exception as exc:
# If there are exceptions inside the exception handler
Expand All @@ -170,6 +190,8 @@ async def _task_wrapper(self, coro: Coroutine) -> Any:
'exception': exc,
'task': task,
})
finally:
del fut

def _on_task_done(self, task: asyncio.Task) -> None:
self._unfinished_tasks -= 1
Expand Down
112 changes: 105 additions & 7 deletions tests/test_ptaskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ async def subtask():
async with aiotools.PersistentTaskGroup() as tg:
for idx in range(10):
tg.create_task(subtask())
assert len(tg._tasks) == 10
assert tg._unfinished_tasks == 10
# wait until all is done
await asyncio.sleep(0.2)
Expand Down Expand Up @@ -84,7 +83,6 @@ async def aclose(self):
obj = LongLivedObject()
for idx in range(10):
await obj.work()
assert len(obj.tg._tasks) == 10
assert obj.tg._unfinished_tasks == 10

# shutdown after all done
Expand All @@ -99,7 +97,6 @@ async def aclose(self):
obj = LongLivedObject()
for idx in range(10):
await obj.work()
assert len(obj.tg._tasks) == 10
assert obj.tg._unfinished_tasks == 10

# shutdown immediately
Expand Down Expand Up @@ -175,7 +172,6 @@ async def subtask():
for _ in range(10):
tg.create_task(subtask())
await asyncio.sleep(0)
assert len(tg._tasks) == 10

# shutdown after exit (all done) is no-op.
assert done_count == 10
Expand All @@ -202,14 +198,118 @@ async def subtask():
async with aiotools.PersistentTaskGroup() as tg:
for _ in range(10):
tg.create_task(subtask())
assert len(tg._tasks) == 10
# let's abort immediately.
await tg.shutdown()

assert done_count == 0
assert len(tg._tasks) == 0


@pytest.mark.skipif(
sys.version_info < (3, 7, 0),
reason='Requires Python 3.7 or higher',
# In Python 3.6, this test hangs indefinitely.
# We don't fix this -- 3.6 is EoL as of December 2021.
)
@pytest.mark.asyncio
async def test_ptaskgroup_await_result():

done_count = 0

async def subtask():
nonlocal done_count
await asyncio.sleep(0.1)
done_count += 1
return "a"

vclock = aiotools.VirtualClock()
with vclock.patch_loop():

results = []

async with aiotools.PersistentTaskGroup() as tg:

ret = await tg.create_task(subtask())
results.append(ret)

ret = await asyncio.shield(tg.create_task(subtask()))
results.append(ret)

a = tg.create_task(subtask())
try:
ret = await a
results.append(ret)
finally:
del a

a = asyncio.shield(tg.create_task(subtask()))
try:
ret = await a
results.append(ret)
finally:
del a

assert results == ["a", "a", "a", "a"]
assert done_count == 4
assert tg._unfinished_tasks == 0
assert len(tg._tasks) == 0


@pytest.mark.skipif(
sys.version_info < (3, 7, 0),
reason='Requires Python 3.7 or higher',
# In Python 3.6, this test hangs indefinitely.
# We don't fix this -- 3.6 is EoL as of December 2021.
)
@pytest.mark.asyncio
async def test_ptaskgroup_await_exception():

done_count = 0
error_count = 0

async def subtask():
nonlocal done_count
await asyncio.sleep(0.1)
1 / 0
done_count += 1

async def handler(exc_type, exc_obj, exc_tb):
nonlocal error_count
assert issubclass(exc_type, ZeroDivisionError)
error_count += 1

vclock = aiotools.VirtualClock()
with vclock.patch_loop():

async with aiotools.PersistentTaskGroup(exception_handler=handler) as tg:

with pytest.raises(ZeroDivisionError):
await tg.create_task(subtask())

with pytest.raises(ZeroDivisionError):
await asyncio.shield(tg.create_task(subtask()))

with pytest.raises(ZeroDivisionError):
a = tg.create_task(subtask())
try:
await a
finally:
del a

with pytest.raises(ZeroDivisionError):
# WARNING: This pattern leaks the reference to the task.
a = asyncio.shield(tg.create_task(subtask()))
try:
await a
finally:
del a

assert done_count == 0
assert error_count == 4
assert tg._unfinished_tasks == 0
assert len(tg._tasks) == 1


@pytest.mark.asyncio
async def test_ptaskgroup_exc_handler_swallow():

Expand All @@ -234,7 +334,6 @@ async def handler(exc_type, exc_obj, exc_tb):
async with aiotools.PersistentTaskGroup(exception_handler=handler) as tg:
for _ in range(10):
tg.create_task(subtask())
assert len(tg._tasks) == 10
except ExceptionGroup as eg:
# All non-base exceptions must be swallowed by
# our exception handler.
Expand Down Expand Up @@ -348,7 +447,6 @@ async def subtask():
async with aiotools.PersistentTaskGroup() as tg:
for _ in range(10):
tg.create_task(subtask())
assert len(tg._tasks) == 10
# Shutdown just after starting child tasks.
# Even in this case, awaits in the tasks' cancellation blocks
# should be executed until their completion.
Expand Down