From c734388ef907fb8910fc922377203de31c9c3e73 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 9 Mar 2022 18:27:28 +0000 Subject: [PATCH 1/5] Add delay_cancellation utility function `delay_cancellation` behaves like `stop_cancellation`, except it delays `CancelledError`s until the original `Deferred` resolves. This is handy for unifying cleanup paths and ensuring that uncancelled coroutines don't use finished logcontexts. Signed-off-by: Sean Quah --- synapse/util/async_helpers.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a9f67dcbac6a..7e0e847dbd5f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -695,3 +695,31 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": new_deferred: defer.Deferred[T] = defer.Deferred() deferred.chainDeferred(new_deferred) return new_deferred + + +def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Delay cancellation of a `Deferred` until it resolves. + Has the same effect as `stop_cancellation`, but the returned `Deferred` will not + resolve with a `CancelledError` until the original `Deferred` resolves. + Args: + deferred: The `Deferred` to protect against cancellation. May optionally follow + the Synapse logcontext rules. + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will wait until the original `Deferred` + resolves before failing with a `CancelledError`. + The new `Deferred` will follow the Synapse logcontext rules if `deferred` + follows the Synapse logcontext rules. Otherwise the new `Deferred` should be + wrapped with `make_deferred_yieldable`. + """ + + def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: + new_deferred.pause() + new_deferred.errback(Failure(CancelledError())) + + deferred.addBoth(lambda _: new_deferred.unpause()) + + new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel) + deferred.chainDeferred(new_deferred) + return new_deferred From ea9fb47ef968bb71b23369bfd3ebcf63317a0983 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 9 Mar 2022 16:40:16 +0000 Subject: [PATCH 2/5] Add tests for database transaction callbacks Signed-off-by: Sean Quah --- changelog.d/12198.misc | 1 + tests/storage/test_database.py | 113 ++++++++++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12198.misc diff --git a/changelog.d/12198.misc b/changelog.d/12198.misc new file mode 100644 index 000000000000..6b184a9053f8 --- /dev/null +++ b/changelog.d/12198.misc @@ -0,0 +1 @@ +Add tests for database transaction callbacks. diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 85978675634b..a3821b41081b 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -12,7 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import make_tuple_comparison_clause +from typing import Callable, NoReturn, Tuple +from unittest.mock import Mock, call + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.util import Clock from tests import unittest @@ -22,3 +33,103 @@ def test_native_tuple_comparison(self): clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(args, [1, 2]) + + +class CallbacksTestCase(unittest.HomeserverTestCase): + """Tests for transaction callbacks.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def _run_interaction( + self, func: Callable[[LoggingTransaction], object] + ) -> Tuple[Mock, Mock]: + """Run the given function in a database transaction, with callbacks registered. + + Args: + func: The function to be run in a transaction. The transaction will be + retried if `func` raises an `OperationalError`. + + Returns: + Two mocks, which were registered as an `after_callback` and an + `exception_callback` respectively, on every transaction attempt. + """ + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + func(txn) + + try: + self.get_success_or_raise( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + except Exception: + pass + + return after_callback, exception_callback + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + after_callback, exception_callback = self._run_interaction(lambda txn: None) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + after_callback, exception_callback = self._run_interaction(lambda txn: 1 / 0) + + after_callback.assert_not_called() + exception_callback.assert_called_once_with(987, 654, extra=321) + + def test_failed_retry(self) -> None: + """Test that the exception callback is called for every failed attempt.""" + + def _test_txn(txn: LoggingTransaction) -> NoReturn: + """Simulate a retryable failure on every attempt.""" + raise self.db_pool.engine.module.OperationalError() + + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls + + def test_successful_retry(self) -> None: + """Test callbacks for a failed transaction followed by a successful attempt.""" + first_attempt = True + + def _test_txn(txn: LoggingTransaction) -> None: + """Simulate a retryable failure on the first attempt only.""" + nonlocal first_attempt + if first_attempt: + first_attempt = False + raise self.db_pool.engine.module.OperationalError() + else: + return None + + after_callback, exception_callback = self._run_interaction(_test_txn) + + # Calling both `after_callback`s when the first attempt failed is rather + # dubious (#12184). But let's document the behaviour in a test. + after_callback.assert_has_calls( + [ + call(123, 456, extra=789), + call(123, 456, extra=789), + ] + ) + self.assertEqual(after_callback.call_count, 2) # no additional calls + exception_callback.assert_not_called() From 5397ae47818169ea0985a82d076e1479faf48036 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 9 Mar 2022 17:39:52 +0000 Subject: [PATCH 3/5] Handle cancellation in `DatabasePool.runInteraction()` To handle cancellation, we ensure that `after_callback`s and `exception_callback`s are always run, since the transaction will complete on another thread regardless of cancellation. We also wait until everything is done before releasing the `CancelledError`, so that logging contexts won't get used after they have been finished. Signed-off-by: Sean Quah --- synapse/storage/database.py | 61 ++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 99802228c9f7..9749f0c06e86 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -41,6 +41,7 @@ from typing_extensions import Literal from twisted.enterprise import adbapi +from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -55,6 +56,7 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -732,34 +734,45 @@ async def runInteraction( Returns: The result of func """ - after_callbacks: List[_CallbackListEntry] = [] - exception_callbacks: List[_CallbackListEntry] = [] - if not current_context(): - logger.warning("Starting db txn '%s' from sentinel context", desc) + async def _runInteraction() -> R: + after_callbacks: List[_CallbackListEntry] = [] + exception_callbacks: List[_CallbackListEntry] = [] - try: - with opentracing.start_active_span(f"db.{desc}"): - result = await self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - db_autocommit=db_autocommit, - isolation_level=isolation_level, - **kwargs, - ) + if not current_context(): + logger.warning("Starting db txn '%s' from sentinel context", desc) - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise + try: + with opentracing.start_active_span(f"db.{desc}"): + result = await self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + db_autocommit=db_autocommit, + isolation_level=isolation_level, + **kwargs, + ) - return cast(R, result) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + + return cast(R, result) + except Exception: + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + # To handle cancellation, we ensure that `after_callback`s and + # `exception_callback`s are always run, since the transaction will complete + # on another thread regardless of cancellation. + # + # We also wait until everything above is done before releasing the + # `CancelledError`, so that logging contexts won't get used after they have been + # finished. + return await delay_cancellation(defer.ensureDeferred(_runInteraction())) async def runWithConnection( self, From f2029228a078dba664953c86c26423fcf55a95ab Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 9 Mar 2022 16:55:22 +0000 Subject: [PATCH 4/5] Add tests for database callbacks after cancellation Signed-off-by: Sean Quah --- tests/storage/test_database.py | 58 ++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index a3821b41081b..15a578d45eb4 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -15,6 +15,8 @@ from typing import Callable, NoReturn, Tuple from unittest.mock import Mock, call +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer @@ -133,3 +135,59 @@ def _test_txn(txn: LoggingTransaction) -> None: ) self.assertEqual(after_callback.call_count, 2) # no additional calls exception_callback.assert_not_called() + + +class CancellationTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + # Simulate a retryable failure on every attempt. + raise self.db_pool.engine.module.OperationalError() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls From 1df9d429d30564e7f2243333de1861cc991d286f Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 9 Mar 2022 18:21:26 +0000 Subject: [PATCH 5/5] Add newsfile --- changelog.d/12199.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12199.misc diff --git a/changelog.d/12199.misc b/changelog.d/12199.misc new file mode 100644 index 000000000000..16dec1d26d24 --- /dev/null +++ b/changelog.d/12199.misc @@ -0,0 +1 @@ +Handle cancellation in `DatabasePool.runInteraction()`.