Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

ObservableDeferred: run observers in order #11229

Merged
merged 3 commits into from
Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11229.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`ObservableDeferred`: run registered observers in order.
34 changes: 18 additions & 16 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Generic,
Hashable,
Iterable,
List,
Optional,
Set,
TypeVar,
Expand Down Expand Up @@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]):
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
object.__setattr__(self, "_observers", [])

def callback(r):
object.__setattr__(self, "_result", (True, r))
while self._observers:
observer = self._observers.pop()

# once we have set _result, no more entries will be added to _observers,
# so it's safe to replace it with the empty tuple.
observers = self._observers
object.__setattr__(self, "_observers", ())

for observer in observers:
try:
observer.callback(r)
except Exception as e:
Expand All @@ -95,12 +100,16 @@ def callback(r):

def errback(f):
object.__setattr__(self, "_result", (False, f))
while self._observers:

# once we have set _result, no more entries will be added to _observers,
# so it's safe to replace it with the empty tuple.
observers = self._observers
object.__setattr__(self, "_observers", ())

for observer in observers:
# This is a little bit of magic to correctly propagate stack
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f

observer = self._observers.pop()
try:
observer.errback(f)
except Exception as e:
Expand All @@ -127,20 +136,13 @@ def observe(self) -> "defer.Deferred[_T]":
"""
if not self._result:
d: "defer.Deferred[_T]" = defer.Deferred()

def remove(r):
self._observers.discard(d)
return r

d.addBoth(remove)
Comment on lines -130 to -135
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code removed the observer from self._observers when it was run. I think it was redundant before, because the observer was pop()ed from self._observers anyway, but it's doubly-redundant now, since the whole of self._observers is thrown away.

(it was added in #190 - no real clues there as to why.)


self._observers.add(d)
self._observers.append(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)

def observers(self) -> "List[defer.Deferred[_T]]":
def observers(self) -> "Collection[defer.Deferred[_T]]":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit surprised that we're leaking internal state here!

Could only see one use though. Probably fine as it is.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup. Thought about changing it to a method that just returns the number of observers. Ran out of enthusiasm.

return self._observers

def has_called(self) -> bool:
Expand Down
4 changes: 1 addition & 3 deletions tests/util/caches/test_deferred_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def check1(r):
self.assertTrue(set_d.called)
return r

# TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
# maybe we should fix that?
# get_d.addCallback(check1)
Comment on lines -50 to -52
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the comment isn't terribly clear, but I think this is what we're fixing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_d.addCallback(check1)

# now fire off all the deferreds
origin_d.callback(99)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,78 @@
PreserveLoggingContext,
current_context,
)
from synapse.util.async_helpers import timeout_deferred
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred

from tests.unittest import TestCase


class ObservableDeferredTest(TestCase):
def test_succeed(self):
origin_d = Deferred()
observable = ObservableDeferred(origin_d)

observer1 = observable.observe()
observer2 = observable.observe()

self.assertFalse(observer1.called)
self.assertFalse(observer2.called)

# check the first observer is called first
def check_called_first(res):
self.assertFalse(observer2.called)
return res

observer1.addBoth(check_called_first)

# store the results
results = [None, None]

def check_val(res, idx):
results[idx] = res
return res

observer1.addCallback(check_val, 0)
observer2.addCallback(check_val, 1)

origin_d.callback(123)
self.assertEqual(results[0], 123, "observer 1 callback result")
self.assertEqual(results[1], 123, "observer 2 callback result")

def test_failure(self):
origin_d = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)

observer1 = observable.observe()
observer2 = observable.observe()

self.assertFalse(observer1.called)
self.assertFalse(observer2.called)

# check the first observer is called first
def check_called_first(res):
self.assertFalse(observer2.called)
return res

observer1.addBoth(check_called_first)

# store the results
results = [None, None]

def check_val(res, idx):
results[idx] = res
return None

observer1.addErrback(check_val, 0)
observer2.addErrback(check_val, 1)

try:
raise Exception("gah!")
except Exception as e:
origin_d.errback(e)
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")


class TimeoutDeferredTest(TestCase):
def setUp(self):
self.clock = Clock()
Expand Down