From 94503b5ef56a115779c2e4de9de3325ad2cd34cf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 1 Nov 2021 18:05:56 +0000 Subject: [PATCH 1/3] ObservableDeferred: run observers in order --- changelog.d/11229.misc | 1 + synapse/util/async_helpers.py | 34 ++++----- tests/util/caches/test_deferred_cache.py | 4 +- ...t_async_utils.py => test_async_helpers.py} | 69 ++++++++++++++++++- 4 files changed, 88 insertions(+), 20 deletions(-) create mode 100644 changelog.d/11229.misc rename tests/util/{test_async_utils.py => test_async_helpers.py} (64%) diff --git a/changelog.d/11229.misc b/changelog.d/11229.misc new file mode 100644 index 000000000000..7bb01cf0796e --- /dev/null +++ b/changelog.d/11229.misc @@ -0,0 +1 @@ +`ObservableDeferred`: run registered observers in order. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 5df80ea8e7b4..c68b1dad3e9e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -17,6 +17,7 @@ import inspect import itertools import logging +from collections import Collection from contextlib import contextmanager from typing import ( Any, @@ -26,7 +27,6 @@ Generic, Hashable, Iterable, - List, Optional, Set, TypeVar, @@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]): def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) - object.__setattr__(self, "_observers", set()) + object.__setattr__(self, "_observers", list()) def callback(r): object.__setattr__(self, "_result", (True, r)) - while self._observers: - observer = self._observers.pop() + + # once we have set _result, no more entries will be added to _observers, + # so it's safe to replace it with the empty tuple. + observers = self._observers + object.__setattr__(self, "_observers", tuple()) + + for observer in observers: try: observer.callback(r) except Exception as e: @@ -95,12 +100,16 @@ def callback(r): def errback(f): object.__setattr__(self, "_result", (False, f)) - while self._observers: + + # once we have set _result, no more entries will be added to _observers, + # so it's safe to replace it with the empty tuple. + observers = self._observers + object.__setattr__(self, "_observers", tuple()) + + for observer in observers: # This is a little bit of magic to correctly propagate stack # traces when we `await` on one of the observer deferreds. f.value.__failure__ = f - - observer = self._observers.pop() try: observer.errback(f) except Exception as e: @@ -127,20 +136,13 @@ def observe(self) -> "defer.Deferred[_T]": """ if not self._result: d: "defer.Deferred[_T]" = defer.Deferred() - - def remove(r): - self._observers.discard(d) - return r - - d.addBoth(remove) - - self._observers.add(d) + self._observers.append(d) return d else: success, res = self._result return defer.succeed(res) if success else defer.fail(res) - def observers(self) -> "List[defer.Deferred[_T]]": + def observers(self) -> "Collection[defer.Deferred[_T]]": return self._observers def has_called(self) -> bool: diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 54a88a83255b..c613ce3f1055 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -47,9 +47,7 @@ def check1(r): self.assertTrue(set_d.called) return r - # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8. - # maybe we should fix that? - # get_d.addCallback(check1) + get_d.addCallback(check1) # now fire off all the deferreds origin_d.callback(99) diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_helpers.py similarity index 64% rename from tests/util/test_async_utils.py rename to tests/util/test_async_helpers.py index 069f875962f5..ab89cab81256 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_helpers.py @@ -21,11 +21,78 @@ PreserveLoggingContext, current_context, ) -from synapse.util.async_helpers import timeout_deferred +from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from tests.unittest import TestCase +class ObservableDeferredTest(TestCase): + def test_succeed(self): + origin_d = Deferred() + observable = ObservableDeferred(origin_d) + + observer1 = observable.observe() + observer2 = observable.observe() + + self.assertFalse(observer1.called) + self.assertFalse(observer2.called) + + # check the first observer is called first + def check_called_first(res): + self.assertFalse(observer2.called) + return res + + observer1.addBoth(check_called_first) + + # store the results + results = [None, None] + + def check_val(res, idx): + results[idx] = res + return res + + observer1.addCallback(check_val, 0) + observer2.addCallback(check_val, 1) + + origin_d.callback(123) + self.assertEqual(results[0], 123, "observer 1 callback result") + self.assertEqual(results[1], 123, "observer 2 callback result") + + def test_failure(self): + origin_d = Deferred() + observable = ObservableDeferred(origin_d, consumeErrors=True) + + observer1 = observable.observe() + observer2 = observable.observe() + + self.assertFalse(observer1.called) + self.assertFalse(observer2.called) + + # check the first observer is called first + def check_called_first(res): + self.assertFalse(observer2.called) + return res + + observer1.addBoth(check_called_first) + + # store the results + results = [None, None] + + def check_val(res, idx): + results[idx] = res + return None + + observer1.addErrback(check_val, 0) + observer2.addErrback(check_val, 1) + + try: + raise Exception("gah!") + except Exception as e: + origin_d.errback(e) + self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result") + self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") + + class TimeoutDeferredTest(TestCase): def setUp(self): self.clock = Clock() From 20956da2b2ec54a024311c45a86ffe96895b6b3e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 1 Nov 2021 18:19:41 +0000 Subject: [PATCH 2/3] fix imports --- synapse/util/async_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index c68b1dad3e9e..fcfa92884931 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -17,12 +17,12 @@ import inspect import itertools import logging -from collections import Collection from contextlib import contextmanager from typing import ( Any, Awaitable, Callable, + Collection, Dict, Generic, Hashable, From a3330726402ff4c28d97f6bda1aa93d891d39577 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 1 Nov 2021 18:40:04 +0000 Subject: [PATCH 3/3] fix lint --- synapse/util/async_helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index fcfa92884931..96efc5f3e38e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -76,7 +76,7 @@ class ObservableDeferred(Generic[_T]): def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) - object.__setattr__(self, "_observers", list()) + object.__setattr__(self, "_observers", []) def callback(r): object.__setattr__(self, "_result", (True, r)) @@ -84,7 +84,7 @@ def callback(r): # once we have set _result, no more entries will be added to _observers, # so it's safe to replace it with the empty tuple. observers = self._observers - object.__setattr__(self, "_observers", tuple()) + object.__setattr__(self, "_observers", ()) for observer in observers: try: @@ -104,7 +104,7 @@ def errback(f): # once we have set _result, no more entries will be added to _observers, # so it's safe to replace it with the empty tuple. observers = self._observers - object.__setattr__(self, "_observers", tuple()) + object.__setattr__(self, "_observers", ()) for observer in observers: # This is a little bit of magic to correctly propagate stack