Skip to content

Commit

Permalink
fix: Missing TaskGroup naming support in Python 3.11 (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Mar 6, 2022
1 parent eb74cf3 commit 4d4d0c5
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 6 deletions.
1 change: 1 addition & 0 deletions changes/39.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix missing naming support of `TaskGroup` in Python 3.11
14 changes: 14 additions & 0 deletions src/aiotools/taskgroup/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from contextvars import ContextVar
import itertools

from .types import TaskGroupError

Expand All @@ -13,6 +14,16 @@

class TaskGroup(asyncio.TaskGroup):

def __init__(self, *, name=None):
super().__init__()
if name is None:
self._name = f"tg-{_name_counter()}"
else:
self._name = str(name)

def get_name(self):
return self._name

async def __aenter__(self):
self._current_taskgroup_token = current_taskgroup.set(self)
return await super().__aenter__()
Expand All @@ -29,3 +40,6 @@ async def __aexit__(self, et, exc, tb):
raise TaskGroupError(eg.message, eg.exceptions) from None
finally:
current_taskgroup.reset(self._current_taskgroup_token)


_name_counter = itertools.count(1).__next__
2 changes: 1 addition & 1 deletion src/aiotools/taskgroup/base_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class TaskGroup:

def __init__(self, *, name=None):
if name is None:
self._name = f'tg-{_name_counter()}'
self._name = f"tg-{_name_counter()}"
else:
self._name = str(name)

Expand Down
7 changes: 2 additions & 5 deletions tests/test_ptaskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
# being logged explicitly by pytest.


@pytest.mark.skipif(
sys.version_info < (3, 8, 0),
reason='Requires Python 3.8 or higher',
)
@pytest.mark.asyncio
async def test_ptaskgroup_naming():

Expand All @@ -26,7 +22,8 @@ async def subtask():
async with aiotools.PersistentTaskGroup(name="XYZ") as tg:
t = tg.create_task(subtask(), name="ABC")
assert tg.get_name() == "XYZ"
assert t.get_name() == "ABC"
if hasattr(t, 'get_name'):
assert t.get_name() == "ABC"


@pytest.mark.asyncio
Expand Down
13 changes: 13 additions & 0 deletions tests/test_taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@
)


@pytest.mark.asyncio
async def test_taskgroup_naming():

async def subtask():
pass

async with TaskGroup(name="XYZ") as tg:
t = tg.create_task(subtask(), name="ABC")
assert tg.get_name() == "XYZ"
if hasattr(t, 'get_name'):
assert t.get_name() == "ABC"


@pytest.mark.asyncio
async def test_delayed_subtasks():
with VirtualClock().patch_loop():
Expand Down

0 comments on commit 4d4d0c5

Please sign in to comment.