From 7e34f6e9856f890b2b9db06364747dcb0c0d214e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 21 Jul 2021 19:56:16 +0100 Subject: [PATCH 1/2] improve typing annotations in CachedCall tighten up some of the typing in CachedCall, which is going to be needed when Twisted 21.7 brings better typing on Deferred. --- changelog.d/10450.misc | 1 + synapse/util/caches/cached_call.py | 34 ++++++++++++++++++++---------- 2 files changed, 24 insertions(+), 11 deletions(-) create mode 100644 changelog.d/10450.misc diff --git a/changelog.d/10450.misc b/changelog.d/10450.misc new file mode 100644 index 000000000000..aa646f0841c7 --- /dev/null +++ b/changelog.d/10450.misc @@ -0,0 +1 @@ + Update type annotations to work with forthcoming Twisted 21.7.0 release. diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py index 891bee0b33ae..26683b4513ff 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union +from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.python.failure import Failure @@ -22,6 +23,10 @@ TV = TypeVar("TV") +class _Sentinel(enum.Enum): + sentinel = object() + + class CachedCall(Generic[TV]): """A wrapper for asynchronous calls whose results should be shared @@ -65,7 +70,7 @@ def __init__(self, f: Callable[[], Awaitable[TV]]): """ self._callable: Optional[Callable[[], Awaitable[TV]]] = f self._deferred: Optional[Deferred] = None - self._result: Union[None, Failure, TV] = None + self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel async def get(self) -> TV: """Kick off the call if necessary, and return the result""" @@ -78,8 +83,9 @@ async def get(self) -> TV: self._callable = None # once the deferred completes, store the result. We cannot simply leave the - # result in the deferred, since if it's a Failure, GCing the deferred - # would then log a critical error about unhandled Failures. + # result in the deferred, since `awaiting` a deferred destroys its result. + # (Also, if it's a Failure, GCing the deferred would log a critical error + # about unhandled Failures) def got_result(r): self._result = r @@ -92,13 +98,19 @@ def got_result(r): # and any eventual exception may not be reported. # we can now await the deferred, and once it completes, return the result. - await make_deferred_yieldable(self._deferred) - - # I *think* this is the easiest way to correctly raise a Failure without having - # to gut-wrench into the implementation of Deferred. - d = Deferred() - d.callback(self._result) - return await d + if isinstance(self._result, _Sentinel): + await make_deferred_yieldable(self._deferred) + assert not isinstance(self._result, _Sentinel) + + if isinstance(self._result, Failure): + # I *think* awaiting a failed Deferred is the easiest way to correctly raise + # the right exception. + d = defer.fail(self._result) + await d + # the `await` should always raise, so this should be unreachable. + raise AssertionError("unexpected return from await on failure") + + return self._result class RetryOnExceptionCachedCall(Generic[TV]): From 79642961447727fb1808fc232b368759518d2449 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 26 Jul 2021 11:50:00 +0100 Subject: [PATCH 2/2] use Failure.raiseException --- synapse/util/caches/cached_call.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py index 26683b4513ff..e58dd91eda7b 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py @@ -14,7 +14,6 @@ import enum from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union -from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.python.failure import Failure @@ -103,12 +102,8 @@ def got_result(r): assert not isinstance(self._result, _Sentinel) if isinstance(self._result, Failure): - # I *think* awaiting a failed Deferred is the easiest way to correctly raise - # the right exception. - d = defer.fail(self._result) - await d - # the `await` should always raise, so this should be unreachable. - raise AssertionError("unexpected return from await on failure") + self._result.raiseException() + raise AssertionError("unexpected return from Failure.raiseException") return self._result