From 9f941060e57d30d17e29f19cb6dab0059dd424fc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 22 Aug 2022 16:22:26 +0100 Subject: [PATCH 1/9] Refactor DeferredCache This is mostly around making `CacheEntry` a generic class and moving some of the callbacks out into methods so that we can reuse them. --- synapse/util/caches/deferred_cache.py | 195 ++++++++++++++------------ synapse/util/caches/treecache.py | 3 + 2 files changed, 109 insertions(+), 89 deletions(-) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 1d6ec22191a0..cdfb4797c76f 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -14,14 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import enum import threading from typing import ( + Any, Callable, + Collection, + Dict, Generic, Iterable, MutableMapping, Optional, + Set, Sized, TypeVar, Union, @@ -31,7 +36,6 @@ from prometheus_client import Gauge from twisted.internet import defer -from twisted.python import failure from twisted.python.failure import Failure from synapse.util.async_helpers import ObservableDeferred @@ -159,15 +163,16 @@ def get( Raises: KeyError if the key is not found in the cache """ - callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) if val is not _Sentinel.sentinel: - val.callbacks.update(callbacks) + val.add_callback(key, callback) if update_metrics: m = self.cache.metrics assert m # we always have a name, so should always have metrics m.inc_hits() - return val.deferred.observe() + return val.deferred(key) + + callbacks = (callback,) if callback else () val2 = self.cache.get( key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics @@ -218,84 +223,70 @@ def set( value: a deferred which will complete with a result to add to the cache callback: An optional callback to be called when the entry is invalidated """ - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] self.check_thread() - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() + self._pending_deferred_cache.pop(key, None) # XXX: why don't we invalidate the entry in `self.cache` yet? - # we can save a whole load of effort if the deferred is ready. - if value.called: - result = value.result - if not isinstance(result, failure.Failure): - self.cache.set(key, cast(VT, result), callbacks) - return value - # otherwise, we'll add an entry to the _pending_deferred_cache for now, # and add callbacks to add it to the cache properly later. + entry = CacheEntrySingle[KT, VT](value) + entry.add_callback(key, callback) + self._pending_deferred_cache[key] = entry + deferred = entry.deferred(key).addCallbacks( + self._set_completed_callback, + self._error_callback, + callbackArgs=(entry, key), + errbackArgs=(entry, key), + ) - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) + # we return a new Deferred which will be called before any subsequent observers. + return deferred + + def _set_completed_callback( + self, value: VT, entry: "CacheEntry[KT, VT]", key: KT + ) -> VT: + """Called when a deferred is completed.""" + # We check if the current entry matches the entry associated with the + # deferred. If they don't match then it got invalidated. + current_entry = self._pending_deferred_cache.pop(key, None) + if current_entry is not entry: + if current_entry: + self._pending_deferred_cache[key] = current_entry + return value - self._pending_deferred_cache[key] = entry + self.cache.set(key, value, entry.get_callbacks(key)) - def compare_and_pop() -> bool: - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result: VT) -> None: - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail: Failure) -> None: - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) + return value - # we return a new Deferred which will be called before any subsequent observers. - return observable.observe() + def _error_callback( + self, + failure: Failure, + entry: "CacheEntry[KT, VT]", + key: KT, + ) -> Failure: + """Called when a deferred errors.""" + + # We check if the current entry matches the entry associated with the + # deferred. If they don't match then it got invalidated. + current_entry = self._pending_deferred_cache.pop(key, None) + if current_entry is not entry: + if current_entry: + self._pending_deferred_cache[key] = current_entry + return failure + + for cb in entry.get_callbacks(key): + cb() + + return failure def prefill( self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None ) -> None: - callbacks = [callback] if callback else [] + callbacks = (callback,) if callback else () self.cache.set(key, value, callbacks=callbacks) + self._pending_deferred_cache.pop(key, None) def invalidate(self, key: KT) -> None: """Delete a key, or tree of entries @@ -311,41 +302,67 @@ def invalidate(self, key: KT) -> None: self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry + # _pending_deferred_cache, which will (a) stop it being returned for + # future queries and (b) stop it being persisted as a proper entry # in self.cache. entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. if entry: # _pending_deferred_cache.pop should either return a CacheEntry, or, in the # case of a TreeCache, a dict of keys to cache entries. Either way calling # iterate_tree_cache_entry on it will do the right thing. for entry in iterate_tree_cache_entry(entry): - entry.invalidate() + for cb in entry.get_callbacks(key): + cb() def invalidate_all(self) -> None: self.check_thread() self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() + for key, entry in self._pending_deferred_cache.items(): + for cb in entry.get_callbacks(key): + cb() + self._pending_deferred_cache.clear() -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] +class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta): + """Abstract class for entries in `DeferredCache[KT, VT]`""" - def __init__( - self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] - ): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self) -> None: - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() + @abc.abstractmethod + def deferred(self, key: KT) -> "defer.Deferred[VT]": + """Get a deferred that a caller can wait on to get the value at the + given key""" + ... + + @abc.abstractmethod + def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None: + """Add an invalidation callback""" + ... + + @abc.abstractmethod + def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + """Get all invalidation callbacks""" + ... + + +class CacheEntrySingle(CacheEntry[KT, VT]): + """An implementation of `CacheEntry` wrapping a deferred that results in a + single cache entry. + """ + + __slots__ = ["_deferred", "_callbacks"] + + def __init__(self, deferred: "defer.Deferred[VT]") -> None: + self._deferred = ObservableDeferred(deferred, consumeErrors=True) + self._callbacks: Set[Callable[[], None]] = set() + + def deferred(self, key: KT) -> "defer.Deferred[VT]": + return self._deferred.observe() + + def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None: + if callback is None: + return + + self._callbacks.add(callback) + + def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + return self._callbacks diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c1b8ec0c73eb..fec31da2b6d4 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -135,6 +135,9 @@ def pop(self, key, default=None): def values(self): return iterate_tree_cache_entry(self.root) + def items(self): + return iterate_tree_cache_items((), self.root) + def __len__(self) -> int: return self.size From 445613401b00af6ed3da4c762d3da0ad28d1ddab Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 22 Aug 2022 16:41:31 +0100 Subject: [PATCH 2/9] Bulk set API. This allows us to avoid create a deferred per-key up front when doing bulk fetch operations from the DB. --- synapse/util/caches/deferred_cache.py | 74 +++++++++++++++++++++++++++ synapse/util/caches/descriptors.py | 42 +++++---------- 2 files changed, 87 insertions(+), 29 deletions(-) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index cdfb4797c76f..28aec20eb1cd 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -38,6 +38,7 @@ from twisted.internet import defer from twisted.python.failure import Failure +from synapse.logging.context import PreserveLoggingContext from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -244,6 +245,25 @@ def set( # we return a new Deferred which will be called before any subsequent observers. return deferred + def set_bulk( + self, + keys: Collection[KT], + callback: Optional[Callable[[], None]] = None, + ) -> "CacheMultipleEntries[KT, VT]": + """Bulk set API for use when fetching multiple keys at once from the DB. + + Called *before* starting the fetch from the DB, and the caller *must* + call either `complete_bulk(..)` or `error_bulk(..)` on the return value. + """ + + entry = CacheMultipleEntries[KT, VT]() + entry.add_global_callback(callback) + + for key in keys: + self._pending_deferred_cache[key] = entry + + return entry + def _set_completed_callback( self, value: VT, entry: "CacheEntry[KT, VT]", key: KT ) -> VT: @@ -366,3 +386,57 @@ def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None: def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]: return self._callbacks + + +class CacheMultipleEntries(CacheEntry[KT, VT]): + """Cache entry that is used for bulk lookups and insertions.""" + + __slots__ = ["_deferred", "_callbacks", "_global_callbacks"] + + def __init__(self) -> None: + self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None + self._callbacks: Dict[KT, Set[Callable[[], None]]] = {} + self._global_callbacks: Set[Callable[[], None]] = set() + + def deferred(self, key: KT) -> "defer.Deferred[VT]": + if not self._deferred: + self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + return self._deferred.observe().addCallback(lambda res: res.get(key)) + + def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None: + if callback is None: + return + + self._callbacks.setdefault(key, set()).add(callback) + + def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + return self._callbacks.get(key, set()) | self._global_callbacks + + def add_global_callback(self, callback: Optional[Callable[[], None]]) -> None: + """Add a callback for when any keys get invalidated.""" + if callback is None: + return + + self._global_callbacks.add(callback) + + def complete_bulk( + self, + cache: DeferredCache[KT, VT], + result: Dict[KT, VT], + ) -> None: + """Called when there is a result""" + for key, value in result.items(): + cache._set_completed_callback(value, self, key) + + if self._deferred: + self._deferred.callback(result) + + def error_bulk( + self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure + ) -> None: + """Called when bulk lookup failed.""" + for key in keys: + cache._error_callback(failure, self, key) + + if self._deferred: + self._deferred.errback(failure) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 867f315b2ace..653a7e15fe62 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -471,46 +471,30 @@ def arg_to_cache_key(arg: Hashable) -> Hashable: missing.add(arg) if missing: - # we need a deferred for each entry in the list, - # which we put in the cache. Each deferred resolves with the - # relevant result for that key. - deferreds_map = {} - for arg in missing: - deferred: "defer.Deferred[Any]" = defer.Deferred() - deferreds_map[arg] = deferred - key = arg_to_cache_key(arg) - cached_defers.append( - cache.set(key, deferred, callback=invalidate_callback) - ) + cache_keys = [arg_to_cache_key(key) for key in missing] + cache_entry = cache.set_bulk(cache_keys, callback=invalidate_callback) def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a dict. - # We can now update our own result map, and then resolve the - # observable deferreds in the cache. - for e, d1 in deferreds_map.items(): - val = res.get(e, None) - # make sure we update the results map before running the - # deferreds, because as soon as we run the last deferred, the - # gatherResults() below will complete and return the result - # dict to our caller. - results[e] = val - d1.callback(val) + missing_results = {} + for key in missing: + val = res.get(key, None) + + results[key] = val + missing_results[arg_to_cache_key(key)] = val + + cache_entry.complete_bulk(cache, missing_results) def errback_all(f: Failure) -> None: - # the wrapped function has failed. Propagate the failure into - # the cache, which will invalidate the entry, and cause the - # relevant cached_deferreds to fail, which will propagate the - # failure to our caller. - for d1 in deferreds_map.values(): - d1.errback(f) + cache_entry.error_bulk(cache, cache_keys, f) args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing # dispatch the call, and attach the two handlers - defer.maybeDeferred( + missing_d = defer.maybeDeferred( preserve_fn(self.orig), **args_to_call ).addCallbacks(complete_all, errback_all) + cached_defers.append(missing_d) if cached_defers: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( From 1626c62180aca9105dea3e51b0c98852ed33488d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 22 Aug 2022 17:19:37 +0100 Subject: [PATCH 3/9] Add bulk lookup API to DeferredCache. --- synapse/util/caches/deferred_cache.py | 73 +++++++++++++++++++++++++-- synapse/util/caches/descriptors.py | 59 ++++++++++++---------- 2 files changed, 101 insertions(+), 31 deletions(-) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 28aec20eb1cd..061bab841b68 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -18,16 +18,15 @@ import enum import threading from typing import ( - Any, Callable, Collection, Dict, Generic, - Iterable, MutableMapping, Optional, Set, Sized, + Tuple, TypeVar, Union, cast, @@ -38,7 +37,6 @@ from twisted.internet import defer from twisted.python.failure import Failure -from synapse.logging.context import PreserveLoggingContext from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -99,7 +97,7 @@ def __init__( # _pending_deferred_cache maps from the key value to a `CacheEntry` object. self._pending_deferred_cache: Union[ - TreeCache, "MutableMapping[KT, CacheEntry]" + TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]" ] = cache_type() def metrics_cb() -> None: @@ -183,6 +181,73 @@ def get( else: return defer.succeed(val2) + def get_bulk( + self, + keys: Collection[KT], + callback: Optional[Callable[[], None]] = None, + ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]: + """Bulk lookup of items in the cache. + + Returns: + A 3-tuple of: + 1. a dict of key/value of items already cached; + 2. a deferred that resolves to a dict of key/value of items + we're already fetching; and + 3. a collection of keys that don't appear in the previous two. + """ + + # The cached results + cached = {} + + # List of pending deferreds + pending = [] + + # Dict that gets filled out when the pending deferreds complete + pending_results = {} + + # List of keys that aren't in either cache + missing = [] + + callbacks = (callback,) if callback else () + + for key in keys: + # Check if its in the main cache. + immediate_value = self.cache.get( + key, + _Sentinel.sentinel, + callbacks=callbacks, + ) + if immediate_value is not _Sentinel.sentinel: + cached[key] = immediate_value + continue + + # Check if its in the pending cache + pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if pending_value is not _Sentinel.sentinel: + pending_value.add_callback(key, callback) + + def completed_cb(value: VT, key: KT) -> VT: + pending_results[key] = value + return value + + # Add a callback to fill out `pending_results` when that completes + d = pending_value.deferred(key).addCallback(completed_cb, key) + pending.append(d) + continue + + # Not in either cache + missing.append(key) + + # If we've got pending deferreds, squash them into a single one that + # returns `pending_results`. + pending_deferred = None + if pending: + pending_deferred = defer.gatherResults( + pending, consumeErrors=True + ).addCallback(lambda _: pending_results) + + return (cached, pending_deferred, missing) + def get_immediate( self, key: KT, default: T, update_metrics: bool = True ) -> Union[VT, T]: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 653a7e15fe62..010abf20a2ed 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,6 +25,7 @@ Generic, Hashable, Iterable, + List, Mapping, Optional, Sequence, @@ -435,16 +436,6 @@ def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] - results = {} - - def update_results_dict(res: Any, arg: Hashable) -> None: - results[arg] = res - - # list of deferreds to wait for - cached_defers = [] - - missing = set() - # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: @@ -452,6 +443,9 @@ def update_results_dict(res: Any, arg: Hashable) -> None: def arg_to_cache_key(arg: Hashable) -> Hashable: return arg + def cache_key_to_arg(key: tuple) -> Hashable: + return key + else: keylist = list(keyargs) @@ -459,36 +453,47 @@ def arg_to_cache_key(arg: Hashable) -> Hashable: keylist[self.list_pos] = arg return tuple(keylist) - for arg in list_args: - try: - res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) - if not res.called: - res.addCallback(update_results_dict, arg) - cached_defers.append(res) - else: - results[arg] = res.result - except KeyError: - missing.add(arg) + def cache_key_to_arg(key: tuple) -> Hashable: + return key[self.list_pos] + + cache_keys = [arg_to_cache_key(arg) for arg in list_args] + immediate_results, pending_deferred, missing = cache.get_bulk( + cache_keys, callback=invalidate_callback + ) + + results = {cache_key_to_arg(key): v for key, v in immediate_results.items()} + + cached_defers: List["defer.Deferred[Any]"] = [] + if pending_deferred: + + def update_results(r: Dict) -> None: + for k, v in r.items(): + results[cache_key_to_arg(k)] = v + + pending_deferred.addCallback(update_results) + cached_defers.append(pending_deferred) if missing: - cache_keys = [arg_to_cache_key(key) for key in missing] - cache_entry = cache.set_bulk(cache_keys, callback=invalidate_callback) + cache_entry = cache.set_bulk(missing, invalidate_callback) def complete_all(res: Dict[Hashable, Any]) -> None: missing_results = {} for key in missing: - val = res.get(key, None) + arg = cache_key_to_arg(key) + val = res.get(arg, None) - results[key] = val - missing_results[arg_to_cache_key(key)] = val + results[arg] = val + missing_results[key] = val cache_entry.complete_bulk(cache, missing_results) def errback_all(f: Failure) -> None: - cache_entry.error_bulk(cache, cache_keys, f) + cache_entry.error_bulk(cache, missing, f) args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing + args_to_call[self.list_name] = { + cache_key_to_arg(key) for key in missing + } # dispatch the call, and attach the two handlers missing_d = defer.maybeDeferred( From ab035d23c75031806cbd2115c43941f049c3d439 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Aug 2022 08:43:21 +0100 Subject: [PATCH 4/9] Newsfile --- changelog.d/13591.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/13591.misc diff --git a/changelog.d/13591.misc b/changelog.d/13591.misc new file mode 100644 index 000000000000..080e865e5575 --- /dev/null +++ b/changelog.d/13591.misc @@ -0,0 +1 @@ +Improve performance of `@cachedList`. From 7f3b52bc995a1c39e5df8365aa6b833df3ff8c9e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Aug 2022 13:22:10 +0100 Subject: [PATCH 5/9] Rename get_callbacks --- synapse/util/caches/deferred_cache.py | 30 ++++++++++++++++----------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 061bab841b68..19f3b9ce5819 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -164,7 +164,7 @@ def get( """ val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) if val is not _Sentinel.sentinel: - val.add_callback(key, callback) + val.add_invalidation_callback(key, callback) if update_metrics: m = self.cache.metrics assert m # we always have a name, so should always have metrics @@ -224,7 +224,7 @@ def get_bulk( # Check if its in the pending cache pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel) if pending_value is not _Sentinel.sentinel: - pending_value.add_callback(key, callback) + pending_value.add_invalidation_callback(key, callback) def completed_cb(value: VT, key: KT) -> VT: pending_results[key] = value @@ -298,7 +298,7 @@ def set( # otherwise, we'll add an entry to the _pending_deferred_cache for now, # and add callbacks to add it to the cache properly later. entry = CacheEntrySingle[KT, VT](value) - entry.add_callback(key, callback) + entry.add_invalidation_callback(key, callback) self._pending_deferred_cache[key] = entry deferred = entry.deferred(key).addCallbacks( self._set_completed_callback, @@ -341,7 +341,7 @@ def _set_completed_callback( self._pending_deferred_cache[key] = current_entry return value - self.cache.set(key, value, entry.get_callbacks(key)) + self.cache.set(key, value, entry.get_invalidation_callbacks(key)) return value @@ -361,7 +361,7 @@ def _error_callback( self._pending_deferred_cache[key] = current_entry return failure - for cb in entry.get_callbacks(key): + for cb in entry.get_invalidation_callbacks(key): cb() return failure @@ -403,7 +403,7 @@ def invalidate_all(self) -> None: self.check_thread() self.cache.clear() for key, entry in self._pending_deferred_cache.items(): - for cb in entry.get_callbacks(key): + for cb in entry.get_invalidation_callbacks(key): cb() self._pending_deferred_cache.clear() @@ -419,12 +419,14 @@ def deferred(self, key: KT) -> "defer.Deferred[VT]": ... @abc.abstractmethod - def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None: + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: """Add an invalidation callback""" ... @abc.abstractmethod - def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: """Get all invalidation callbacks""" ... @@ -443,13 +445,15 @@ def __init__(self, deferred: "defer.Deferred[VT]") -> None: def deferred(self, key: KT) -> "defer.Deferred[VT]": return self._deferred.observe() - def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None: + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: if callback is None: return self._callbacks.add(callback) - def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: return self._callbacks @@ -468,13 +472,15 @@ def deferred(self, key: KT) -> "defer.Deferred[VT]": self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) return self._deferred.observe().addCallback(lambda res: res.get(key)) - def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None: + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: if callback is None: return self._callbacks.setdefault(key, set()).add(callback) - def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: return self._callbacks.get(key, set()) | self._global_callbacks def add_global_callback(self, callback: Optional[Callable[[], None]]) -> None: From 4222bfa3e8460af955edac1884ea280db10cac66 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Aug 2022 13:23:04 +0100 Subject: [PATCH 6/9] Rename _set_completed_callback --- synapse/util/caches/deferred_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 19f3b9ce5819..dbae25165a78 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -301,7 +301,7 @@ def set( entry.add_invalidation_callback(key, callback) self._pending_deferred_cache[key] = entry deferred = entry.deferred(key).addCallbacks( - self._set_completed_callback, + self._completed_callback, self._error_callback, callbackArgs=(entry, key), errbackArgs=(entry, key), @@ -329,7 +329,7 @@ def set_bulk( return entry - def _set_completed_callback( + def _completed_callback( self, value: VT, entry: "CacheEntry[KT, VT]", key: KT ) -> VT: """Called when a deferred is completed.""" @@ -497,7 +497,7 @@ def complete_bulk( ) -> None: """Called when there is a result""" for key, value in result.items(): - cache._set_completed_callback(value, self, key) + cache._completed_callback(value, self, key) if self._deferred: self._deferred.callback(result) From 45ab2f1d8aa4b59fb1e8f2d0eb08b17f04f8c8c1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Aug 2022 13:23:54 +0100 Subject: [PATCH 7/9] fixup! Rename get_callbacks --- synapse/util/caches/deferred_cache.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index dbae25165a78..80dab0ea2d83 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -322,7 +322,7 @@ def set_bulk( """ entry = CacheMultipleEntries[KT, VT]() - entry.add_global_callback(callback) + entry.add_global_invalidation_callback(callback) for key in keys: self._pending_deferred_cache[key] = entry @@ -483,7 +483,9 @@ def add_invalidation_callback( def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: return self._callbacks.get(key, set()) | self._global_callbacks - def add_global_callback(self, callback: Optional[Callable[[], None]]) -> None: + def add_global_invalidation_callback( + self, callback: Optional[Callable[[], None]] + ) -> None: """Add a callback for when any keys get invalidated.""" if callback is None: return From 486b0b1c3c85819620f499931b9265f3f47021d5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Aug 2022 13:24:44 +0100 Subject: [PATCH 8/9] Rename set_bulk --- synapse/util/caches/deferred_cache.py | 2 +- synapse/util/caches/descriptors.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 80dab0ea2d83..0fcea16407dd 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -310,7 +310,7 @@ def set( # we return a new Deferred which will be called before any subsequent observers. return deferred - def set_bulk( + def start_bulk_input( self, keys: Collection[KT], callback: Optional[Callable[[], None]] = None, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 010abf20a2ed..33ccc86de6c1 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -474,7 +474,7 @@ def update_results(r: Dict) -> None: cached_defers.append(pending_deferred) if missing: - cache_entry = cache.set_bulk(missing, invalidate_callback) + cache_entry = cache.start_bulk_input(missing, invalidate_callback) def complete_all(res: Dict[Hashable, Any]) -> None: missing_results = {} From 6913f4c37f6e2be9df42bbbaa568bdec82274eff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Aug 2022 14:37:12 +0100 Subject: [PATCH 9/9] Fix method missed during rename --- synapse/util/caches/deferred_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 0fcea16407dd..6425f851eaa4 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -396,7 +396,7 @@ def invalidate(self, key: KT) -> None: # case of a TreeCache, a dict of keys to cache entries. Either way calling # iterate_tree_cache_entry on it will do the right thing. for entry in iterate_tree_cache_entry(entry): - for cb in entry.get_callbacks(key): + for cb in entry.get_invalidation_callbacks(key): cb() def invalidate_all(self) -> None: