From 4d4d0c5f5eb8c9dd37d755659c4d324acb1d7792 Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Mon, 7 Mar 2022 00:44:29 +0900 Subject: [PATCH] fix: Missing TaskGroup naming support in Python 3.11 (#39) --- changes/39.fix.md | 1 + src/aiotools/taskgroup/base.py | 14 ++++++++++++++ src/aiotools/taskgroup/base_compat.py | 2 +- tests/test_ptaskgroup.py | 7 ++----- tests/test_taskgroup.py | 13 +++++++++++++ 5 files changed, 31 insertions(+), 6 deletions(-) create mode 100644 changes/39.fix.md diff --git a/changes/39.fix.md b/changes/39.fix.md new file mode 100644 index 0000000..4beb3b8 --- /dev/null +++ b/changes/39.fix.md @@ -0,0 +1 @@ +Fix missing naming support of `TaskGroup` in Python 3.11 diff --git a/src/aiotools/taskgroup/base.py b/src/aiotools/taskgroup/base.py index 0ccbc08..31c197c 100644 --- a/src/aiotools/taskgroup/base.py +++ b/src/aiotools/taskgroup/base.py @@ -1,5 +1,6 @@ import asyncio from contextvars import ContextVar +import itertools from .types import TaskGroupError @@ -13,6 +14,16 @@ class TaskGroup(asyncio.TaskGroup): + def __init__(self, *, name=None): + super().__init__() + if name is None: + self._name = f"tg-{_name_counter()}" + else: + self._name = str(name) + + def get_name(self): + return self._name + async def __aenter__(self): self._current_taskgroup_token = current_taskgroup.set(self) return await super().__aenter__() @@ -29,3 +40,6 @@ async def __aexit__(self, et, exc, tb): raise TaskGroupError(eg.message, eg.exceptions) from None finally: current_taskgroup.reset(self._current_taskgroup_token) + + +_name_counter = itertools.count(1).__next__ diff --git a/src/aiotools/taskgroup/base_compat.py b/src/aiotools/taskgroup/base_compat.py index d1fe707..84277b8 100644 --- a/src/aiotools/taskgroup/base_compat.py +++ b/src/aiotools/taskgroup/base_compat.py @@ -50,7 +50,7 @@ class TaskGroup: def __init__(self, *, name=None): if name is None: - self._name = f'tg-{_name_counter()}' + self._name = f"tg-{_name_counter()}" else: self._name = str(name) diff --git a/tests/test_ptaskgroup.py b/tests/test_ptaskgroup.py index a35f4b3..8d89eb5 100644 --- a/tests/test_ptaskgroup.py +++ b/tests/test_ptaskgroup.py @@ -13,10 +13,6 @@ # being logged explicitly by pytest. -@pytest.mark.skipif( - sys.version_info < (3, 8, 0), - reason='Requires Python 3.8 or higher', -) @pytest.mark.asyncio async def test_ptaskgroup_naming(): @@ -26,7 +22,8 @@ async def subtask(): async with aiotools.PersistentTaskGroup(name="XYZ") as tg: t = tg.create_task(subtask(), name="ABC") assert tg.get_name() == "XYZ" - assert t.get_name() == "ABC" + if hasattr(t, 'get_name'): + assert t.get_name() == "ABC" @pytest.mark.asyncio diff --git a/tests/test_taskgroup.py b/tests/test_taskgroup.py index 94a306f..2371f76 100644 --- a/tests/test_taskgroup.py +++ b/tests/test_taskgroup.py @@ -11,6 +11,19 @@ ) +@pytest.mark.asyncio +async def test_taskgroup_naming(): + + async def subtask(): + pass + + async with TaskGroup(name="XYZ") as tg: + t = tg.create_task(subtask(), name="ABC") + assert tg.get_name() == "XYZ" + if hasattr(t, 'get_name'): + assert t.get_name() == "ABC" + + @pytest.mark.asyncio async def test_delayed_subtasks(): with VirtualClock().patch_loop():