Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 90b2327

Browse files
authored
Add delay_cancellation utility function (#12180)
`delay_cancellation` behaves like `stop_cancellation`, except it delays `CancelledError`s until the original `Deferred` resolves. This is handy for unifying cleanup paths and ensuring that uncancelled coroutines don't use finished logcontexts. Signed-off-by: Sean Quah <seanq@element.io>
1 parent 54f674f commit 90b2327

File tree

3 files changed

+161
-12
lines changed

3 files changed

+161
-12
lines changed

changelog.d/12180.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`.

synapse/util/async_helpers.py

+42-6
Original file line numberDiff line numberDiff line change
@@ -686,12 +686,48 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
686686
Synapse logcontext rules.
687687
688688
Returns:
689-
A new `Deferred`, which will contain the result of the original `Deferred`,
690-
but will not propagate cancellation through to the original. When cancelled,
691-
the new `Deferred` will fail with a `CancelledError` and will not follow the
692-
Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap
693-
the new `Deferred`.
689+
A new `Deferred`, which will contain the result of the original `Deferred`.
690+
The new `Deferred` will not propagate cancellation through to the original.
691+
When cancelled, the new `Deferred` will fail with a `CancelledError`.
692+
693+
The new `Deferred` will not follow the Synapse logcontext rules and should be
694+
wrapped with `make_deferred_yieldable`.
695+
"""
696+
new_deferred: "defer.Deferred[T]" = defer.Deferred()
697+
deferred.chainDeferred(new_deferred)
698+
return new_deferred
699+
700+
701+
def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
702+
"""Delay cancellation of a `Deferred` until it resolves.
703+
704+
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
705+
resolve with a `CancelledError` until the original `Deferred` resolves.
706+
707+
Args:
708+
deferred: The `Deferred` to protect against cancellation. May optionally follow
709+
the Synapse logcontext rules.
710+
711+
Returns:
712+
A new `Deferred`, which will contain the result of the original `Deferred`.
713+
The new `Deferred` will not propagate cancellation through to the original.
714+
When cancelled, the new `Deferred` will wait until the original `Deferred`
715+
resolves before failing with a `CancelledError`.
716+
717+
The new `Deferred` will follow the Synapse logcontext rules if `deferred`
718+
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
719+
wrapped with `make_deferred_yieldable`.
694720
"""
695-
new_deferred: defer.Deferred[T] = defer.Deferred()
721+
722+
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
723+
# before the new deferred is cancelled, we `pause` it to stop the cancellation
724+
# propagating. we then `unpause` it once the wrapped deferred completes, to
725+
# propagate the exception.
726+
new_deferred.pause()
727+
new_deferred.errback(Failure(CancelledError()))
728+
729+
deferred.addBoth(lambda _: new_deferred.unpause())
730+
731+
new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel)
696732
deferred.chainDeferred(new_deferred)
697733
return new_deferred

tests/util/test_async_helpers.py

+118-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
import traceback
1515

16+
from parameterized import parameterized_class
17+
1618
from twisted.internet import defer
1719
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
1820
from twisted.internet.task import Clock
@@ -23,10 +25,12 @@
2325
LoggingContext,
2426
PreserveLoggingContext,
2527
current_context,
28+
make_deferred_yieldable,
2629
)
2730
from synapse.util.async_helpers import (
2831
ObservableDeferred,
2932
concurrently_execute,
33+
delay_cancellation,
3034
stop_cancellation,
3135
timeout_deferred,
3236
)
@@ -313,13 +317,27 @@ async def caller():
313317
self.successResultOf(d2)
314318

315319

316-
class StopCancellationTests(TestCase):
317-
"""Tests for the `stop_cancellation` function."""
320+
@parameterized_class(
321+
("wrapper",),
322+
[("stop_cancellation",), ("delay_cancellation",)],
323+
)
324+
class CancellationWrapperTests(TestCase):
325+
"""Common tests for the `stop_cancellation` and `delay_cancellation` functions."""
326+
327+
wrapper: str
328+
329+
def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]":
330+
if self.wrapper == "stop_cancellation":
331+
return stop_cancellation(deferred)
332+
elif self.wrapper == "delay_cancellation":
333+
return delay_cancellation(deferred)
334+
else:
335+
raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
318336

319337
def test_succeed(self):
320338
"""Test that the new `Deferred` receives the result."""
321339
deferred: "Deferred[str]" = Deferred()
322-
wrapper_deferred = stop_cancellation(deferred)
340+
wrapper_deferred = self.wrap_deferred(deferred)
323341

324342
# Success should propagate through.
325343
deferred.callback("success")
@@ -329,14 +347,18 @@ def test_succeed(self):
329347
def test_failure(self):
330348
"""Test that the new `Deferred` receives the `Failure`."""
331349
deferred: "Deferred[str]" = Deferred()
332-
wrapper_deferred = stop_cancellation(deferred)
350+
wrapper_deferred = self.wrap_deferred(deferred)
333351

334352
# Failure should propagate through.
335353
deferred.errback(ValueError("abc"))
336354
self.assertTrue(wrapper_deferred.called)
337355
self.failureResultOf(wrapper_deferred, ValueError)
338356
self.assertIsNone(deferred.result, "`Failure` was not consumed")
339357

358+
359+
class StopCancellationTests(TestCase):
360+
"""Tests for the `stop_cancellation` function."""
361+
340362
def test_cancellation(self):
341363
"""Test that cancellation of the new `Deferred` leaves the original running."""
342364
deferred: "Deferred[str]" = Deferred()
@@ -347,11 +369,101 @@ def test_cancellation(self):
347369
self.assertTrue(wrapper_deferred.called)
348370
self.failureResultOf(wrapper_deferred, CancelledError)
349371
self.assertFalse(
350-
deferred.called, "Original `Deferred` was unexpectedly cancelled."
372+
deferred.called, "Original `Deferred` was unexpectedly cancelled"
373+
)
374+
375+
# Now make the original `Deferred` fail.
376+
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
377+
# in logs.
378+
deferred.errback(ValueError("abc"))
379+
self.assertIsNone(deferred.result, "`Failure` was not consumed")
380+
381+
382+
class DelayCancellationTests(TestCase):
383+
"""Tests for the `delay_cancellation` function."""
384+
385+
def test_cancellation(self):
386+
"""Test that cancellation of the new `Deferred` waits for the original."""
387+
deferred: "Deferred[str]" = Deferred()
388+
wrapper_deferred = delay_cancellation(deferred)
389+
390+
# Cancel the new `Deferred`.
391+
wrapper_deferred.cancel()
392+
self.assertNoResult(wrapper_deferred)
393+
self.assertFalse(
394+
deferred.called, "Original `Deferred` was unexpectedly cancelled"
395+
)
396+
397+
# Now make the original `Deferred` fail.
398+
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
399+
# in logs.
400+
deferred.errback(ValueError("abc"))
401+
self.assertIsNone(deferred.result, "`Failure` was not consumed")
402+
403+
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
404+
self.failureResultOf(wrapper_deferred, CancelledError)
405+
406+
def test_suppresses_second_cancellation(self):
407+
"""Test that a second cancellation is suppressed.
408+
409+
Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
410+
"""
411+
deferred: "Deferred[str]" = Deferred()
412+
wrapper_deferred = delay_cancellation(deferred)
413+
414+
# Cancel the new `Deferred`, twice.
415+
wrapper_deferred.cancel()
416+
wrapper_deferred.cancel()
417+
self.assertNoResult(wrapper_deferred)
418+
self.assertFalse(
419+
deferred.called, "Original `Deferred` was unexpectedly cancelled"
351420
)
352421

353-
# Now make the inner `Deferred` fail.
422+
# Now make the original `Deferred` fail.
354423
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
355424
# in logs.
356425
deferred.errback(ValueError("abc"))
357426
self.assertIsNone(deferred.result, "`Failure` was not consumed")
427+
428+
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
429+
self.failureResultOf(wrapper_deferred, CancelledError)
430+
431+
def test_propagates_cancelled_error(self):
432+
"""Test that a `CancelledError` from the original `Deferred` gets propagated."""
433+
deferred: "Deferred[str]" = Deferred()
434+
wrapper_deferred = delay_cancellation(deferred)
435+
436+
# Fail the original `Deferred` with a `CancelledError`.
437+
cancelled_error = CancelledError()
438+
deferred.errback(cancelled_error)
439+
440+
# The new `Deferred` should fail with exactly the same `CancelledError`.
441+
self.assertTrue(wrapper_deferred.called)
442+
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
443+
444+
def test_preserves_logcontext(self):
445+
"""Test that logging contexts are preserved."""
446+
blocking_d: "Deferred[None]" = Deferred()
447+
448+
async def inner():
449+
await make_deferred_yieldable(blocking_d)
450+
451+
async def outer():
452+
with LoggingContext("c") as c:
453+
try:
454+
await delay_cancellation(defer.ensureDeferred(inner()))
455+
self.fail("`CancelledError` was not raised")
456+
except CancelledError:
457+
self.assertEqual(c, current_context())
458+
# Succeed with no error, unless the logging context is wrong.
459+
460+
# Run and block inside `inner()`.
461+
d = defer.ensureDeferred(outer())
462+
self.assertEqual(SENTINEL_CONTEXT, current_context())
463+
464+
d.cancel()
465+
466+
# Now unblock. `outer()` will consume the `CancelledError` and check the
467+
# logging context.
468+
blocking_d.callback(None)
469+
self.successResultOf(d)

0 commit comments

Comments
 (0)