diff --git a/changelog.d/8063.misc b/changelog.d/8063.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8063.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index cf039e7f7dd0..82aac2bbf3a3 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -16,15 +16,16 @@ import abc import logging -from typing import List, Tuple +from typing import List, Optional, Tuple from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -97,13 +98,15 @@ def get_account_data_for_user_txn(txn): "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, max_entries=5000) - def get_global_account_data_by_type_for_user(self, data_type, user_id): + @cached(num_args=2, max_entries=5000) + async def get_global_account_data_by_type_for_user( + self, data_type: str, user_id: str + ) -> Optional[JsonDict]: """ Returns: - Deferred: A dict + The account data. """ - result = yield self.db_pool.simple_select_one_onecol( + result = await self.db_pool.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -280,9 +283,11 @@ def get_updated_account_data_for_user_txn(txn): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) - def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): - ignored_account_data = yield self.get_global_account_data_by_type_for_user( + @cached(num_args=2, cache_context=True, max_entries=5000) + async def is_ignored_by( + self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext + ) -> bool: + ignored_account_data = await self.get_global_account_data_by_type_for_user( "m.ignored_user_list", ignorer_user_id, on_invalidate=cache_context.invalidate, @@ -307,24 +312,27 @@ def __init__(self, database: DatabasePool, db_conn, hs): super(AccountDataStore, self).__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self): + def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream Returns: - A deferred int. + The maximum stream ID. """ return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + async def add_account_data_to_room( + self, user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. """ content_json = json_encoder.encode(content) @@ -332,7 +340,7 @@ def add_account_data_to_room(self, user_id, room_id, account_data_type, content) # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -350,7 +358,7 @@ def add_account_data_to_room(self, user_id, room_id, account_data_type, content) # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -359,18 +367,20 @@ def add_account_data_to_room(self, user_id, room_id, account_data_type, content) (user_id, room_id, account_data_type), content ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_for_user(self, user_id, account_data_type, content): + async def add_account_data_for_user( + self, user_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. """ content_json = json_encoder.encode(content) @@ -378,7 +388,7 @@ def add_account_data_for_user(self, user_id, account_data_type, content): # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -396,7 +406,7 @@ def add_account_data_for_user(self, user_id, account_data_type, content): # Note: This is only here for backwards compat to allow admins to # roll back to a previous Synapse version. Next time we update the # database version we can remove this table. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -404,14 +414,13 @@ def add_account_data_for_user(self, user_id, account_data_type, content): (account_data_type, user_id) ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_max_stream_id(self, next_id): + def _update_max_stream_id(self, next_id: int): """Update the max stream_id Args: - next_id(int): The the revision to advance to. + next_id: The the revision to advance to. """ # Note: This is only here for backwards compat to allow admins to diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 1de8249563f8..f211ddbaf88e 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -16,8 +16,6 @@ import logging from typing import TYPE_CHECKING -from twisted.internet import defer - from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore @@ -148,17 +146,16 @@ def _censor_event_txn(self, txn, event_id, pruned_json): updatevalues={"json": pruned_json}, ) - @defer.inlineCallbacks - def expire_event(self, event_id): + async def expire_event(self, event_id: str) -> None: """Retrieve and expire an event that has expired, and delete its associated expiry timestamp. If the event can't be retrieved, delete its associated timestamp so we don't try to expire it again in the future. Args: - event_id (str): The ID of the event to delete. + event_id: The ID of the event to delete. """ # Try to retrieve the event's content from the database or the event cache. - event = yield self.get_event(event_id) + event = await self.get_event(event_id) def delete_expired_event_txn(txn): # Delete the expiry timestamp associated with this event from the database. @@ -193,7 +190,7 @@ def delete_expired_event_txn(txn): txn, "_get_event_cache", (event.event_id,) ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_expired_event", delete_expired_event_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 76ec954f446d..1f6e995c4fef 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -16,8 +16,6 @@ import logging from typing import List, Tuple -from twisted.internet import defer - from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool @@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() - def get_new_messages_for_device( - self, user_id, device_id, last_stream_id, current_stream_id, limit=100 - ): + async def get_new_messages_for_device( + self, + user_id: str, + device_id: str, + last_stream_id: int, + current_stream_id: int, + limit: int = 100, + ) -> Tuple[List[dict], int]: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - current_stream_id(int): The current position of the to device + user_id: The recipient user_id. + device_id: The recipient device_id. + last_stream_id: The last stream ID checked. + current_stream_id: The current position of the to device message stream. + limit: The maximum number of messages to retrieve. + Returns: - Deferred ([dict], int): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. """ has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_stream_id ) if not has_changed: - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) def get_new_messages_for_device_txn(txn): sql = ( @@ -69,20 +74,22 @@ def get_new_messages_for_device_txn(txn): stream_pos = current_stream_id return messages, stream_pos - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @trace - @defer.inlineCallbacks - def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + async def delete_messages_for_device( + self, user_id: str, device_id: str, up_to_stream_id: int + ) -> int: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - up_to_stream_id(int): Where to delete messages up to. + user_id: The recipient user_id. + device_id: The recipient device_id. + up_to_stream_id: Where to delete messages up to. + Returns: - A deferred that resolves to the number of messages deleted. + The number of messages deleted. """ # If we have cached the last stream id we've deleted up to, we can # check if there is likely to be anything that needs deleting @@ -109,7 +116,7 @@ def delete_messages_for_device_txn(txn): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.db_pool.runInteraction( + count = await self.db_pool.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -128,9 +135,9 @@ def delete_messages_for_device_txn(txn): return count @trace - def get_new_device_msgs_for_remote( + async def get_new_device_msgs_for_remote( self, destination, last_stream_id, current_stream_id, limit - ): + ) -> Tuple[List[dict], int]: """ Args: destination(str): The name of the remote server. @@ -139,8 +146,7 @@ def get_new_device_msgs_for_remote( current_stream_id(int|long): The current position of the device message stream. Returns: - Deferred ([dict], int|long): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. """ set_tag("destination", destination) @@ -153,11 +159,11 @@ def get_new_device_msgs_for_remote( ) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. - return defer.succeed(([], last_stream_id)) + return ([], last_stream_id) @trace def get_new_messages_for_remote_destination_txn(txn): @@ -178,7 +184,7 @@ def get_new_messages_for_remote_destination_txn(txn): stream_pos = current_stream_id return messages, stream_pos - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @@ -290,16 +296,15 @@ def __init__(self, database: DatabasePool, db_conn, hs): self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) - @defer.inlineCallbacks - def _background_drop_index_device_inbox(self, progress, batch_size): + async def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() - yield self.db_pool.runWithConnection(reindex_txn) + await self.db_pool.runWithConnection(reindex_txn) - yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 @@ -320,21 +325,21 @@ def __init__(self, database: DatabasePool, db_conn, hs): ) @trace - @defer.inlineCallbacks - def add_messages_to_device_inbox( - self, local_messages_by_user_then_device, remote_messages_by_destination - ): + async def add_messages_to_device_inbox( + self, + local_messages_by_user_then_device: dict, + remote_messages_by_destination: dict, + ) -> int: """Used to send messages from this server. Args: - sender_user_id(str): The ID of the user sending these messages. - local_messages_by_user_and_device(dict): + local_messages_by_user_and_device: Dictionary of user_id to device_id to message. - remote_messages_by_destination(dict): + remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. + Returns: - A deferred stream_id that resolves when the messages have been - inserted. + The new stream_id. """ def add_messages_txn(txn, now_ms, stream_id): @@ -359,7 +364,7 @@ def add_messages_txn(txn, now_ms, stream_id): with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -371,10 +376,9 @@ def add_messages_txn(txn, now_ms, stream_id): return self._device_inbox_id_gen.get_current_token() - @defer.inlineCallbacks - def add_messages_from_remote_to_device_inbox( - self, origin, message_id, local_messages_by_user_then_device - ): + async def add_messages_from_remote_to_device_inbox( + self, origin: str, message_id: str, local_messages_by_user_then_device: dict + ) -> int: def add_messages_txn(txn, now_ms, stream_id): # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our @@ -409,7 +413,7 @@ def add_messages_txn(txn, now_ms, stream_id): with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index b7d0adb10e26..64ddd8243d9e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import register_federation_servlets @@ -151,7 +152,7 @@ def get_users_in_room(room_id): self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( ([], 0) ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None