diff --git a/README.md b/README.md index d145284..e0cc87b 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,8 @@ async def my_handler(): # res_1 and res_2 may be instances of exceptions. ``` -The differences to `asyncio.gather` are: +The differences to `asyncio.gather()` are: +- If a child task fails other unfinished tasks will be cancelled, just like in a TaskGroup. - `quattro.gather()` only accepts coroutines and not futures and generators, just like a TaskGroup. - When `return_exceptions` is false (the default), an exception in a child task will cause an ExceptionGroup to bubble out of the top-level `gather()` call, just like in a TaskGroup. - Results are returned as a tuple, not a list. diff --git a/src/quattro/_gather.py b/src/quattro/_gather.py index 4431dd9..8b84d8b 100644 --- a/src/quattro/_gather.py +++ b/src/quattro/_gather.py @@ -194,6 +194,8 @@ async def gather( # type: ignore[misc] Notable differences are: + * If a child task fails other unfinished tasks will be cancelled, just like + in a TaskGroup. * `quattro.gather` only accepts coroutines and not futures and generators, just like a TaskGroup. * When `return_exceptions` is false (the default), an exception in a child task @@ -202,6 +204,8 @@ async def gather( # type: ignore[misc] * Results are returned as a tuple, not a list. (See https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) + + .. versionadded:: 23.1.0 """ if not coros: return () diff --git a/tests/test_gather.py b/tests/test_gather.py index 2fec8c1..7ab0b24 100644 --- a/tests/test_gather.py +++ b/tests/test_gather.py @@ -1,34 +1,39 @@ -from asyncio import CancelledError, sleep +from asyncio import CancelledError, current_task, get_running_loop, sleep +from asyncio import gather as asyncio_gather -from pytest import raises +from pytest import mark, raises -from quattro import gather, move_on_after +from quattro import gather from quattro.taskgroup import ExceptionGroup -async def test_empty(): +@mark.parametrize("gather", [gather, asyncio_gather]) +async def test_empty(gather): """An empty gather works.""" - assert await gather() == () + # asyncio gather returns a list + assert tuple(await gather()) == () -async def test_simple_gather(): +@mark.parametrize("gather", [gather, asyncio_gather]) +async def test_simple_gather(gather): """Simple gather works.""" async def test() -> int: await sleep(0.01) return 1 - assert await gather(test(), test()) == (1, 1) + assert tuple(await gather(test(), test())) == (1, 1) -async def test_gather_with_error(): +@mark.parametrize("gather", [gather, asyncio_gather]) +async def test_gather_with_error(gather): """Gather works if there's an error.""" cancelled = 0 async def test() -> int: nonlocal cancelled try: - await sleep(0.01) + await sleep(0.1) except CancelledError: cancelled += 1 return 1 @@ -37,23 +42,32 @@ async def error() -> None: await sleep(0.005) raise ValueError() - with raises(ExceptionGroup) as exc_info: + with raises((ExceptionGroup, ValueError)) as exc_info: await gather(test(), test(), error()) - assert repr(exc_info.value.exceptions[0]) == "ValueError()" + if gather == asyncio_gather: + assert isinstance(exc_info.value, ValueError) + # default asyncio behavior + assert cancelled == 0 + else: + assert isinstance(exc_info.value, ExceptionGroup) + assert repr(exc_info.value.exceptions[0]) == "ValueError()" + assert cancelled == 2 -async def test_simple_gather_exceptions(): +@mark.parametrize("gather", [gather, asyncio_gather]) +async def test_simple_gather_exceptions(gather): """Simple gather works when collecting exceptions.""" async def test() -> int: await sleep(0.01) return 1 - assert await gather(test(), test(), return_exceptions=True) == (1, 1) + assert tuple(await gather(test(), test(), return_exceptions=True)) == (1, 1) -async def test_with_error_return_excs(): +@mark.parametrize("gather", [gather, asyncio_gather]) +async def test_with_error_return_excs(gather): """Gather works if there's an error and exceptions are returned.""" async def test() -> int: @@ -68,10 +82,11 @@ async def error() -> None: res = await gather(test(), test(), error(), return_exceptions=True) - assert res == (1, 1, err) + assert tuple(res) == (1, 1, err) -async def test_parent_cancelled(): +@mark.parametrize("gather", [gather, asyncio_gather]) +async def test_parent_cancelled(gather): """When the parent is cancelled, the children are also cancelled.""" cancelled = 0 @@ -86,14 +101,20 @@ async def test() -> int: res = None - with move_on_after(0.001): + # We cannot use `move_on` here since asyncio.gather doesn't + # work with it on some versions of 3.9 and 3.10. + current = current_task() + get_running_loop().call_later(0.001, lambda: current.cancel()) + + with raises(CancelledError): res = await gather(test(), test()) assert res is None assert cancelled == 2 -async def test_parent_cancelled_return_excs(): +@mark.parametrize("gather", [gather, asyncio_gather]) +async def test_parent_cancelled_return_excs(gather): """When the parent is cancelled, the children are also cancelled.""" cancelled = 0 @@ -108,7 +129,13 @@ async def test() -> int: res = None - with move_on_after(0.001): + # We cannot use + # `move_on` here since asyncio.gather doesn't + # work with it on some versions of 3.9 and 3.10. + current = current_task() + get_running_loop().call_later(0.001, lambda: current.cancel()) + + with raises(CancelledError): res = await gather(test(), test(), return_exceptions=True) assert res is None