13
13
# limitations under the License.
14
14
import traceback
15
15
16
+ from parameterized import parameterized_class
17
+
16
18
from twisted .internet import defer
17
19
from twisted .internet .defer import CancelledError , Deferred , ensureDeferred
18
20
from twisted .internet .task import Clock
23
25
LoggingContext ,
24
26
PreserveLoggingContext ,
25
27
current_context ,
28
+ make_deferred_yieldable ,
26
29
)
27
30
from synapse .util .async_helpers import (
28
31
ObservableDeferred ,
29
32
concurrently_execute ,
33
+ delay_cancellation ,
30
34
stop_cancellation ,
31
35
timeout_deferred ,
32
36
)
@@ -313,13 +317,27 @@ async def caller():
313
317
self .successResultOf (d2 )
314
318
315
319
316
- class StopCancellationTests (TestCase ):
317
- """Tests for the `stop_cancellation` function."""
320
+ @parameterized_class (
321
+ ("wrapper" ,),
322
+ [("stop_cancellation" ,), ("delay_cancellation" ,)],
323
+ )
324
+ class CancellationWrapperTests (TestCase ):
325
+ """Common tests for the `stop_cancellation` and `delay_cancellation` functions."""
326
+
327
+ wrapper : str
328
+
329
+ def wrap_deferred (self , deferred : "Deferred[str]" ) -> "Deferred[str]" :
330
+ if self .wrapper == "stop_cancellation" :
331
+ return stop_cancellation (deferred )
332
+ elif self .wrapper == "delay_cancellation" :
333
+ return delay_cancellation (deferred )
334
+ else :
335
+ raise ValueError (f"Unsupported wrapper type: { self .wrapper } " )
318
336
319
337
def test_succeed (self ):
320
338
"""Test that the new `Deferred` receives the result."""
321
339
deferred : "Deferred[str]" = Deferred ()
322
- wrapper_deferred = stop_cancellation (deferred )
340
+ wrapper_deferred = self . wrap_deferred (deferred )
323
341
324
342
# Success should propagate through.
325
343
deferred .callback ("success" )
@@ -329,14 +347,18 @@ def test_succeed(self):
329
347
def test_failure (self ):
330
348
"""Test that the new `Deferred` receives the `Failure`."""
331
349
deferred : "Deferred[str]" = Deferred ()
332
- wrapper_deferred = stop_cancellation (deferred )
350
+ wrapper_deferred = self . wrap_deferred (deferred )
333
351
334
352
# Failure should propagate through.
335
353
deferred .errback (ValueError ("abc" ))
336
354
self .assertTrue (wrapper_deferred .called )
337
355
self .failureResultOf (wrapper_deferred , ValueError )
338
356
self .assertIsNone (deferred .result , "`Failure` was not consumed" )
339
357
358
+
359
+ class StopCancellationTests (TestCase ):
360
+ """Tests for the `stop_cancellation` function."""
361
+
340
362
def test_cancellation (self ):
341
363
"""Test that cancellation of the new `Deferred` leaves the original running."""
342
364
deferred : "Deferred[str]" = Deferred ()
@@ -347,11 +369,101 @@ def test_cancellation(self):
347
369
self .assertTrue (wrapper_deferred .called )
348
370
self .failureResultOf (wrapper_deferred , CancelledError )
349
371
self .assertFalse (
350
- deferred .called , "Original `Deferred` was unexpectedly cancelled."
372
+ deferred .called , "Original `Deferred` was unexpectedly cancelled"
373
+ )
374
+
375
+ # Now make the original `Deferred` fail.
376
+ # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
377
+ # in logs.
378
+ deferred .errback (ValueError ("abc" ))
379
+ self .assertIsNone (deferred .result , "`Failure` was not consumed" )
380
+
381
+
382
+ class DelayCancellationTests (TestCase ):
383
+ """Tests for the `delay_cancellation` function."""
384
+
385
+ def test_cancellation (self ):
386
+ """Test that cancellation of the new `Deferred` waits for the original."""
387
+ deferred : "Deferred[str]" = Deferred ()
388
+ wrapper_deferred = delay_cancellation (deferred )
389
+
390
+ # Cancel the new `Deferred`.
391
+ wrapper_deferred .cancel ()
392
+ self .assertNoResult (wrapper_deferred )
393
+ self .assertFalse (
394
+ deferred .called , "Original `Deferred` was unexpectedly cancelled"
395
+ )
396
+
397
+ # Now make the original `Deferred` fail.
398
+ # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
399
+ # in logs.
400
+ deferred .errback (ValueError ("abc" ))
401
+ self .assertIsNone (deferred .result , "`Failure` was not consumed" )
402
+
403
+ # Now that the original `Deferred` has failed, we should get a `CancelledError`.
404
+ self .failureResultOf (wrapper_deferred , CancelledError )
405
+
406
+ def test_suppresses_second_cancellation (self ):
407
+ """Test that a second cancellation is suppressed.
408
+
409
+ Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
410
+ """
411
+ deferred : "Deferred[str]" = Deferred ()
412
+ wrapper_deferred = delay_cancellation (deferred )
413
+
414
+ # Cancel the new `Deferred`, twice.
415
+ wrapper_deferred .cancel ()
416
+ wrapper_deferred .cancel ()
417
+ self .assertNoResult (wrapper_deferred )
418
+ self .assertFalse (
419
+ deferred .called , "Original `Deferred` was unexpectedly cancelled"
351
420
)
352
421
353
- # Now make the inner `Deferred` fail.
422
+ # Now make the original `Deferred` fail.
354
423
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
355
424
# in logs.
356
425
deferred .errback (ValueError ("abc" ))
357
426
self .assertIsNone (deferred .result , "`Failure` was not consumed" )
427
+
428
+ # Now that the original `Deferred` has failed, we should get a `CancelledError`.
429
+ self .failureResultOf (wrapper_deferred , CancelledError )
430
+
431
+ def test_propagates_cancelled_error (self ):
432
+ """Test that a `CancelledError` from the original `Deferred` gets propagated."""
433
+ deferred : "Deferred[str]" = Deferred ()
434
+ wrapper_deferred = delay_cancellation (deferred )
435
+
436
+ # Fail the original `Deferred` with a `CancelledError`.
437
+ cancelled_error = CancelledError ()
438
+ deferred .errback (cancelled_error )
439
+
440
+ # The new `Deferred` should fail with exactly the same `CancelledError`.
441
+ self .assertTrue (wrapper_deferred .called )
442
+ self .assertIs (cancelled_error , self .failureResultOf (wrapper_deferred ).value )
443
+
444
+ def test_preserves_logcontext (self ):
445
+ """Test that logging contexts are preserved."""
446
+ blocking_d : "Deferred[None]" = Deferred ()
447
+
448
+ async def inner ():
449
+ await make_deferred_yieldable (blocking_d )
450
+
451
+ async def outer ():
452
+ with LoggingContext ("c" ) as c :
453
+ try :
454
+ await delay_cancellation (defer .ensureDeferred (inner ()))
455
+ self .fail ("`CancelledError` was not raised" )
456
+ except CancelledError :
457
+ self .assertEqual (c , current_context ())
458
+ # Succeed with no error, unless the logging context is wrong.
459
+
460
+ # Run and block inside `inner()`.
461
+ d = defer .ensureDeferred (outer ())
462
+ self .assertEqual (SENTINEL_CONTEXT , current_context ())
463
+
464
+ d .cancel ()
465
+
466
+ # Now unblock. `outer()` will consume the `CancelledError` and check the
467
+ # logging context.
468
+ blocking_d .callback (None )
469
+ self .successResultOf (d )
0 commit comments