From 7943ae191744ab2b9abb75a4069ab654327c2677 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 9 Jun 2021 18:07:07 +0100 Subject: [PATCH 1/6] Make `ReplicationEndpoint._check_auth_and_handle` async We don't really need to do this, but it's the only place ResponseCache.wrap is called without an await, so we might as well clean it up. --- synapse/replication/http/_base.py | 6 +++--- synapse/replication/http/membership.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 2a13026e9a16..f13a7c23b4a6 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -285,7 +285,7 @@ def register(self, http_server): self.__class__.__name__, ) - def _check_auth_and_handle(self, request, **kwargs): + async def _check_auth_and_handle(self, request, **kwargs): """Called on new incoming requests when caching is enabled. Checks if there is a cached response for the request and returns that, otherwise calls `_handle_request` and caches its response. 
@@ -300,8 +300,8 @@ def _check_auth_and_handle(self, request, **kwargs): if self.CACHE: txn_id = kwargs.pop("txn_id") - return self.response_cache.wrap( + return await self.response_cache.wrap( txn_id, self._handle_request, request, **kwargs ) - return self._handle_request(request, **kwargs) + return await self._handle_request(request, **kwargs) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 289a397d6885..0990296b7044 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -206,7 +206,7 @@ async def _serialize_payload( # type: ignore return {} - def _handle_request( # type: ignore + async def _handle_request( # type: ignore self, request: Request, room_id: str, user_id: str, change: str ) -> Tuple[int, JsonDict]: logger.info("user membership change: %s in %s", user_id, room_id) From 9ea41d41d64fcb320d9e275d31afbeb85dd35bd6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 9 Jun 2021 18:20:42 +0100 Subject: [PATCH 2/6] Make `ResponseCache.wrap` a regular async function ... and add a few type annotations too. --- synapse/util/caches/response_cache.py | 26 +++++++++++-------- ...esponsecache.py => test_response_cache.py} | 25 +++++++++++++----- 2 files changed, 34 insertions(+), 17 deletions(-) rename tests/util/caches/{test_responsecache.py => test_response_cache.py} (86%) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 25ea1bcc915e..29a523541f34 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import Any, Callable, Dict, Generic, Optional, TypeVar +from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar from twisted.internet import defer @@ -23,10 +23,14 @@ logger = logging.getLogger(__name__) -T = TypeVar("T") +# the type of the key in the cache +KV = TypeVar("KV") +# the type of the result from the operation +RV = TypeVar("RV") -class ResponseCache(Generic[T]): + +class ResponseCache(Generic[KV]): """ This caches a deferred response. Until the deferred completes it will be returned from the cache. This means that if the client retries the request @@ -36,7 +40,7 @@ class ResponseCache(Generic[T]): def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): # Requests that haven't finished yet. - self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] + self.pending_result_cache = {} # type: Dict[KV, ObservableDeferred] self.clock = clock self.timeout_sec = timeout_ms / 1000.0 @@ -50,7 +54,7 @@ def size(self) -> int: def __len__(self) -> int: return self.size() - def get(self, key: T) -> Optional[defer.Deferred]: + def get(self, key: KV) -> Optional[defer.Deferred]: """Look up the given key. Can return either a new Deferred (which also doesn't follow the synapse @@ -76,7 +80,7 @@ def get(self, key: T) -> Optional[defer.Deferred]: self._metrics.inc_misses() return None - def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred: + def set(self, key: KV, deferred: defer.Deferred) -> defer.Deferred: """Set the entry for the given key to the given deferred. 
*deferred* should run its callbacks in the sentinel logcontext (ie, @@ -109,9 +113,9 @@ def remove(r): result.addBoth(remove) return result.observe() - def wrap( - self, key: T, callback: Callable[..., Any], *args: Any, **kwargs: Any - ) -> defer.Deferred: + async def wrap( + self, key: KV, callback: Callable[..., Awaitable[RV]], *args: Any, **kwargs: Any + ) -> RV: """Wrap together a *get* and *set* call, taking care of logcontexts First looks up the key in the cache, and if it is present makes it @@ -143,7 +147,7 @@ async def handle_request(request): **kwargs: named parameters to pass to the callback, if it is used Returns: - Deferred which resolves to the result + The result of the callback (from the cache, or otherwise) """ result = self.get(key) if not result: @@ -158,4 +162,4 @@ async def handle_request(request): logger.info( "[%s]: using incomplete cached result for [%s]", self._name, key ) - return make_deferred_yieldable(result) + return await make_deferred_yieldable(result) diff --git a/tests/util/caches/test_responsecache.py b/tests/util/caches/test_response_cache.py similarity index 86% rename from tests/util/caches/test_responsecache.py rename to tests/util/caches/test_response_cache.py index f9a187b8defc..d2f3c2c7fa82 100644 --- a/tests/util/caches/test_responsecache.py +++ b/tests/util/caches/test_response_cache.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.util.caches.response_cache import ResponseCache from tests.server import get_clock from tests.unittest import TestCase -class DeferredCacheTestCase(TestCase): +class ResponseCacheTestCase(TestCase): """ A TestCase class for ResponseCache. 
@@ -48,7 +50,9 @@ def test_cache_hit(self): expected_result = "howdy" - wrap_d = cache.wrap(0, self.instant_return, expected_result) + wrap_d = defer.ensureDeferred( + cache.wrap(0, self.instant_return, expected_result) + ) self.assertEqual( expected_result, @@ -66,7 +70,9 @@ def test_cache_miss(self): expected_result = "howdy" - wrap_d = cache.wrap(0, self.instant_return, expected_result) + wrap_d = defer.ensureDeferred( + cache.wrap(0, self.instant_return, expected_result) + ) self.assertEqual( expected_result, @@ -80,7 +86,9 @@ def test_cache_expire(self): expected_result = "howdy" - wrap_d = cache.wrap(0, self.instant_return, expected_result) + wrap_d = defer.ensureDeferred( + cache.wrap(0, self.instant_return, expected_result) + ) self.assertEqual(expected_result, self.successResultOf(wrap_d)) self.assertEqual( @@ -99,7 +107,10 @@ def test_cache_wait_hit(self): expected_result = "howdy" - wrap_d = cache.wrap(0, self.delayed_return, expected_result) + wrap_d = defer.ensureDeferred( + cache.wrap(0, self.delayed_return, expected_result) + ) + self.assertNoResult(wrap_d) # function wakes up, returns result @@ -112,7 +123,9 @@ def test_cache_wait_expire(self): expected_result = "howdy" - wrap_d = cache.wrap(0, self.delayed_return, expected_result) + wrap_d = defer.ensureDeferred( + cache.wrap(0, self.delayed_return, expected_result) + ) self.assertNoResult(wrap_d) # stop at 1 second to callback cache eviction callLater at that time, then another to set time at 2 From f48542e6ec45b190bafeb5035cd1ad96fcc2d4ba Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 9 Jun 2021 19:24:46 +0100 Subject: [PATCH 3/6] update some comments and docstrings Just fix a few misleading things --- synapse/util/caches/response_cache.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 29a523541f34..f820544d89ca 100644 --- 
a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -39,7 +39,9 @@ class ResponseCache(Generic[KV]): """ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): - # Requests that haven't finished yet. + # This is poorly-named: it includes both complete and incomplete results. + # We keep complete results rather than switching to absolute values because + # that makes it easier to cache Failure results. self.pending_result_cache = {} # type: Dict[KV, ObservableDeferred] self.clock = clock @@ -57,13 +59,10 @@ def __len__(self) -> int: def get(self, key: KV) -> Optional[defer.Deferred]: """Look up the given key. - Can return either a new Deferred (which also doesn't follow the synapse - logcontext rules), or, if the request has completed, the actual - result. You will probably want to make_deferred_yieldable the result. + Returns a new Deferred (which also doesn't follow the synapse + logcontext rules). You will probably want to make_deferred_yieldable the result. - If there is no entry for the key, returns None. It is worth noting that - this means there is no way to distinguish a completed result of None - from an absent cache entry. + If there is no entry for the key, returns None. Args: key: key to get/set in the cache @@ -87,9 +86,8 @@ def set(self, key: KV, deferred: defer.Deferred) -> defer.Deferred: you should wrap normal synapse deferreds with synapse.logging.context.run_in_background). - Can return either a new Deferred (which also doesn't follow the synapse - logcontext rules), or, if *deferred* was already complete, the actual - result. You will probably want to make_deferred_yieldable the result. + Returns a new Deferred (which also doesn't follow the synapse logcontext rules). + You will probably want to make_deferred_yieldable the result. 
Args: key: key to get/set in the cache @@ -101,7 +99,7 @@ def set(self, key: KV, deferred: defer.Deferred) -> defer.Deferred: result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result - def remove(r): + def on_complete(r): if self.timeout_sec: self.clock.call_later( self.timeout_sec, self.pending_result_cache.pop, key, None @@ -110,7 +108,10 @@ def remove(r): self.pending_result_cache.pop(key, None) return r - result.addBoth(remove) + # make sure we do this *after* adding the entry to pending_result_cache, + # in case the result is already complete (in which case flipping the order would + # leave us with a stuck entry in the cache). + result.addBoth(on_complete) return result.observe() async def wrap( From 63e6c2042b30d26194aa6afd021d6efbf2eabf36 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 9 Jun 2021 22:52:21 +0100 Subject: [PATCH 4/6] Extend ResponseCache to pass a `context` object to the callback ... allowing the callback to specify whether or not the result should be cached. 
--- synapse/util/caches/response_cache.py | 52 +++++++++++++++++++++--- tests/util/caches/test_response_cache.py | 49 +++++++++++++++++++++- 2 files changed, 95 insertions(+), 6 deletions(-) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index f820544d89ca..34c662c4dbd7 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -14,6 +14,8 @@ import logging from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar +import attr + from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -30,6 +32,28 @@ RV = TypeVar("RV") +@attr.s(auto_attribs=True) +class ResponseCacheContext(Generic[KV]): + """Information about a missed ResponseCache hit + + This object can be passed into the callback for additional feedback + """ + + cache_key: KV + """The cache key that caused the cache miss + + This should be considered read-only. + + TODO: in attrs 20.1, make it frozen with an on_setattr. + """ + + should_cache: bool = True + """Whether the result should be cached once the request completes. + + This can be modified by the callback if it decides its result should not be cached. + """ + + class ResponseCache(Generic[KV]): """ This caches a deferred response. Until the deferred completes it will be @@ -79,7 +103,9 @@ def get(self, key: KV) -> Optional[defer.Deferred]: self._metrics.inc_misses() return None - def set(self, key: KV, deferred: defer.Deferred) -> defer.Deferred: + def _set( + self, context: ResponseCacheContext[KV], deferred: defer.Deferred + ) -> defer.Deferred: """Set the entry for the given key to the given deferred. *deferred* should run its callbacks in the sentinel logcontext (ie, @@ -90,21 +116,26 @@ def set(self, key: KV, deferred: defer.Deferred) -> defer.Deferred: You will probably want to make_deferred_yieldable the result. 
Args: - key: key to get/set in the cache + context: Information about the cache miss deferred: The deferred which resolves to the result. Returns: A new deferred which resolves to the actual result. """ result = ObservableDeferred(deferred, consumeErrors=True) + key = context.cache_key self.pending_result_cache[key] = result def on_complete(r): - if self.timeout_sec: + # if this cache has a non-zero timeout, and the callback has not cleared + # the should_cache bit, we leave it in the cache for now and schedule + # its removal later. + if self.timeout_sec and context.should_cache: self.clock.call_later( self.timeout_sec, self.pending_result_cache.pop, key, None ) else: + # otherwise, remove the result immediately. self.pending_result_cache.pop(key, None) return r @@ -115,7 +146,12 @@ def on_complete(r): return result.observe() async def wrap( - self, key: KV, callback: Callable[..., Awaitable[RV]], *args: Any, **kwargs: Any + self, + key: KV, + callback: Callable[..., Awaitable[RV]], + *args: Any, + cache_context: bool = False, + **kwargs: Any, ) -> RV: """Wrap together a *get* and *set* call, taking care of logcontexts @@ -145,6 +181,9 @@ async def handle_request(request): *args: positional parameters to pass to the callback, if it is used + cache_context: if set, the callback will be given a `cache_context` kw arg, + which will be a ResponseCacheContext object. 
+ **kwargs: named parameters to pass to the callback, if it is used Returns: @@ -155,8 +194,11 @@ async def handle_request(request): logger.debug( "[%s]: no cached result for [%s], calculating new one", self._name, key ) + context = ResponseCacheContext(cache_key=key) + if cache_context: + kwargs["cache_context"] = context d = run_in_background(callback, *args, **kwargs) - result = self.set(key, d) + result = self._set(context, d) elif not isinstance(result, defer.Deferred) or result.called: logger.info("[%s]: using completed cached result for [%s]", self._name, key) else: diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py index d2f3c2c7fa82..f69419766fac 100644 --- a/tests/util/caches/test_response_cache.py +++ b/tests/util/caches/test_response_cache.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from parameterized import parameterized from twisted.internet import defer -from synapse.util.caches.response_cache import ResponseCache +from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext from tests.server import get_clock from tests.unittest import TestCase @@ -142,3 +143,49 @@ def test_cache_wait_expire(self): self.reactor.pump((2,)) self.assertIsNone(cache.get(0), "cache should not have the result now") + + @parameterized.expand([(True,), (False,)]) + def test_cache_context_nocache(self, should_cache: bool): + """If the callback clears the should_cache bit, the result should not be cached""" + cache = self.with_cache("medium_cache", ms=3000) + + expected_result = "howdy" + + call_count = [0] + + async def non_caching(o: str, cache_context: ResponseCacheContext[int]): + call_count[0] += 1 + await self.clock.sleep(1) + cache_context.should_cache = should_cache + return o + + wrap_d = defer.ensureDeferred( + cache.wrap(0, non_caching, expected_result, cache_context=True) + ) + # there should be no result to start with + self.assertNoResult(wrap_d) + + # a second call should also return a pending deferred + wrap2_d = defer.ensureDeferred( + cache.wrap(0, non_caching, expected_result, cache_context=True) + ) + self.assertNoResult(wrap2_d) + + # and there should have been exactly one call + self.assertEqual(call_count[0], 1) + + # let the call complete + self.reactor.advance(1) + + # both results should have completed + self.assertEqual(expected_result, self.successResultOf(wrap_d)) + self.assertEqual(expected_result, self.successResultOf(wrap2_d)) + + if should_cache: + self.assertEqual( + expected_result, + self.successResultOf(cache.get(0)), + "cache should still have the result", + ) + else: + self.assertIsNone(cache.get(0), "cache should not have the result") From d4798f7a70b71674b09d29d05cf91a6afc348559 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 9 Jun 2021 23:00:29 +0100 Subject: [PATCH 5/6] changelog --- 
changelog.d/10157.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/10157.misc diff --git a/changelog.d/10157.misc b/changelog.d/10157.misc new file mode 100644 index 000000000000..6c1d0e6e5933 --- /dev/null +++ b/changelog.d/10157.misc @@ -0,0 +1 @@ +Extend `ResponseCache` to pass a context object into the callback. From 117668ee5543956350f5e7208acde38277ad23e5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 10 Jun 2021 17:02:49 +0100 Subject: [PATCH 6/6] use `nonlocal` to avoid listy hack --- tests/util/caches/test_response_cache.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py index f69419766fac..1e83ef2f33d5 100644 --- a/tests/util/caches/test_response_cache.py +++ b/tests/util/caches/test_response_cache.py @@ -151,10 +151,11 @@ def test_cache_context_nocache(self, should_cache: bool): expected_result = "howdy" - call_count = [0] + call_count = 0 async def non_caching(o: str, cache_context: ResponseCacheContext[int]): - call_count[0] += 1 + nonlocal call_count + call_count += 1 await self.clock.sleep(1) cache_context.should_cache = should_cache return o @@ -172,7 +173,7 @@ async def non_caching(o: str, cache_context: ResponseCacheContext[int]): self.assertNoResult(wrap2_d) # and there should have been exactly one call - self.assertEqual(call_count[0], 1) + self.assertEqual(call_count, 1) # let the call complete self.reactor.advance(1)