Skip to content

Commit edbee56

Browse files
authored
Taskgroup tweaks (GH-31559)
Now uses .cancel()/.uncancel(), for even fewer broken edge cases.
1 parent 41ddcd3 commit edbee56

File tree

2 files changed

+42
-34
lines changed

2 files changed

+42
-34
lines changed

Lib/asyncio/taskgroups.py

+23-27
Original file line numberDiff line numberDiff line change
@@ -66,31 +66,28 @@ async def __aexit__(self, et, exc, tb):
6666
self._base_error is None):
6767
self._base_error = exc
6868

69-
if et is exceptions.CancelledError:
70-
if self._parent_cancel_requested:
71-
# Only if we did request task to cancel ourselves
72-
# we mark it as no longer cancelled.
73-
self._parent_task.uncancel()
74-
else:
75-
propagate_cancellation_error = et
76-
77-
if et is not None and not self._aborting:
78-
# Our parent task is being cancelled:
79-
#
80-
# async with TaskGroup() as g:
81-
# g.create_task(...)
82-
# await ... # <- CancelledError
83-
#
69+
if et is not None:
8470
if et is exceptions.CancelledError:
85-
propagate_cancellation_error = et
86-
87-
# or there's an exception in "async with":
88-
#
89-
# async with TaskGroup() as g:
90-
# g.create_task(...)
91-
# 1 / 0
92-
#
93-
self._abort()
71+
if self._parent_cancel_requested and not self._parent_task.uncancel():
72+
# Do nothing, i.e. swallow the error.
73+
pass
74+
else:
75+
propagate_cancellation_error = exc
76+
77+
if not self._aborting:
78+
# Our parent task is being cancelled:
79+
#
80+
# async with TaskGroup() as g:
81+
# g.create_task(...)
82+
# await ... # <- CancelledError
83+
#
84+
# or there's an exception in "async with":
85+
#
86+
# async with TaskGroup() as g:
87+
# g.create_task(...)
88+
# 1 / 0
89+
#
90+
self._abort()
9491

9592
# We use while-loop here because "self._on_completed_fut"
9693
# can be cancelled multiple times if our parent task
@@ -118,7 +115,6 @@ async def __aexit__(self, et, exc, tb):
118115
self._on_completed_fut = None
119116

120117
assert self._unfinished_tasks == 0
121-
self._on_completed_fut = None # no longer needed
122118

123119
if self._base_error is not None:
124120
raise self._base_error
@@ -199,8 +195,7 @@ def _on_task_done(self, task):
199195
})
200196
return
201197

202-
self._abort()
203-
if not self._parent_task.cancelling():
198+
if not self._aborting and not self._parent_cancel_requested:
204199
# If parent task *is not* being cancelled, it means that we want
205200
# to manually cancel it to abort whatever is being run right now
206201
# in the TaskGroup. But we want to mark parent task as
@@ -219,5 +214,6 @@ def _on_task_done(self, task):
219214
# pass
220215
# await something_else # this line has to be called
221216
# # after TaskGroup is finished.
217+
self._abort()
222218
self._parent_cancel_requested = True
223219
self._parent_task.cancel()

Lib/test/test_asyncio/test_taskgroups.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ async def runner():
120120
self.assertTrue(t2_cancel)
121121
self.assertTrue(t2.cancelled())
122122

123-
async def test_taskgroup_05(self):
123+
async def test_cancel_children_on_child_error(self):
124+
"""
125+
When a child task raises an error, the rest of the children
126+
are cancelled and the errors are gathered into an EG.
127+
"""
124128

125129
NUM = 0
126130
t2_cancel = False
@@ -165,7 +169,7 @@ async def runner():
165169
self.assertTrue(t2_cancel)
166170
self.assertTrue(runner_cancel)
167171

168-
async def test_taskgroup_06(self):
172+
async def test_cancellation(self):
169173

170174
NUM = 0
171175

@@ -186,10 +190,12 @@ async def runner():
186190
await asyncio.sleep(0.1)
187191

188192
self.assertFalse(r.done())
189-
r.cancel()
190-
with self.assertRaises(asyncio.CancelledError):
193+
r.cancel("test")
194+
with self.assertRaises(asyncio.CancelledError) as cm:
191195
await r
192196

197+
self.assertEqual(cm.exception.args, ('test',))
198+
193199
self.assertEqual(NUM, 5)
194200

195201
async def test_taskgroup_07(self):
@@ -226,7 +232,7 @@ async def runner():
226232

227233
self.assertEqual(NUM, 15)
228234

229-
async def test_taskgroup_08(self):
235+
async def test_cancellation_in_body(self):
230236

231237
async def foo():
232238
await asyncio.sleep(0.1)
@@ -246,10 +252,12 @@ async def runner():
246252
await asyncio.sleep(0.1)
247253

248254
self.assertFalse(r.done())
249-
r.cancel()
250-
with self.assertRaises(asyncio.CancelledError):
255+
r.cancel("test")
256+
with self.assertRaises(asyncio.CancelledError) as cm:
251257
await r
252258

259+
self.assertEqual(cm.exception.args, ('test',))
260+
253261
async def test_taskgroup_09(self):
254262

255263
t1 = t2 = None
@@ -699,3 +707,7 @@ async def coro():
699707
async with taskgroups.TaskGroup() as g:
700708
t = g.create_task(coro(), name="yolo")
701709
self.assertEqual(t.get_name(), "yolo")
710+
711+
712+
if __name__ == "__main__":
713+
unittest.main()

0 commit comments

Comments
 (0)