diff --git a/changelog.d/8166.misc b/changelog.d/8166.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8166.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index d68b4bd670f6..769cd5de28ab 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -21,7 +21,9 @@ import logging +from synapse.federation.units import Transaction from synapse.logging.utils import log_function +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -49,15 +51,15 @@ def have_responded(self, origin, transaction): return self.store.get_received_txn_response(transaction.transaction_id, origin) @log_function - def set_response(self, origin, transaction, code, response): + async def set_response( + self, origin: str, transaction: Transaction, code: int, response: JsonDict + ) -> None: """ Persist how we responded to a transaction. - - Returns: - Deferred """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.set_received_txn_response( - transaction.transaction_id, origin, code, response + await self.store.set_received_txn_response( + transaction_id, origin, code, response ) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 6b32e0dcbfb8..64d98fc8f675 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -107,9 +107,7 @@ def __init__(self, transaction_id=None, pdus=[], **kwargs): if "edus" in kwargs and not kwargs["edus"]: del kwargs["edus"] - super(Transaction, self).__init__( - transaction_id=transaction_id, pdus=pdus, **kwargs - ) + super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) @staticmethod def create_new(pdus, **kwargs): diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 77723f7d4dc7..92f56f1602a2 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -161,16 +161,14 @@ async def get_appservice_state(self, service): return result.get("state") return None - def set_appservice_state(self, service, state): + async def set_appservice_state(self, service, state) -> None: """Set the application service state. Args: service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. - Returns: - An Awaitable which resolves when the state was set successfully. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a811a39eb524..ecd3f3b3108b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -716,11 +716,11 @@ async def get_user_ids_requiring_device_list_resync( return {row["user_id"] for row in rows} - def mark_remote_user_device_cache_as_stale(self, user_id: str): + async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: """Records that the server has reason to believe the cache of the devices for the remote users is out of date. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="device_lists_remote_resync", keyvalues={"user_id": user_id}, values={}, diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index c39864f59f8d..f4c668e3a1df 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -741,7 +741,13 @@ def remove_room_from_summary(self, group_id, room_id, category_id): desc="remove_room_from_summary", ) - def upsert_group_category(self, group_id, category_id, profile, is_public): + async def upsert_group_category( + self, + group_id: str, + category_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/update room category for group """ insertion_values = {} @@ -757,7 +763,7 @@ def upsert_group_category(self, group_id, category_id, profile, is_public): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -772,7 +778,13 @@ def remove_group_category(self, group_id, category_id): desc="remove_group_category", ) - def upsert_group_role(self, group_id, role_id, profile, is_public): + async def upsert_group_role( + self, + group_id: str, + role_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/remove user role """ insertion_values = {} @@ -788,7 +800,7 @@ def upsert_group_role(self, group_id, role_id, profile, is_public): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -937,10 +949,10 @@ def remove_user_from_summary(self, group_id, user_id, role_id): desc="remove_user_from_summary", ) - def add_group_invite(self, group_id, user_id): + async def add_group_invite(self, group_id: str, user_id: str) -> None: """Record that the group server has invited a user """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -1043,8 +1055,10 @@ def _remove_user_from_group_txn(txn): "remove_user_from_group", _remove_user_from_group_txn ) - def add_room_to_group(self, group_id, room_id, is_public): - return self.db_pool.simple_insert( + async def add_room_to_group( + self, group_id: str, room_id: str, is_public: bool + ) -> None: + await self.db_pool.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index fadcad51e7a1..1c0a049c5548 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -140,22 +140,28 @@ async def store_server_verify_keys( for i in invalidations: invalidate((i,)) - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): + async def store_server_keys_json( + self, + server_name: str, + key_id: str, + from_server: str, + ts_now_ms: int, + ts_expires_ms: int, + key_json_bytes: bytes, + ) -> None: """Stores the JSON bytes for a set of keys from a server The JSON should be signed by the originating server, the intermediate server, and by this server. Updates the value for the (server_name, key_id, from_server) triplet if one already existed. Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. + server_name: The name of the server. + key_id: The identifer of the key this JSON is for. + from_server: The server this JSON was fetched from. + ts_now_ms: The time now in milliseconds. + ts_valid_until_ms: The time when this json stops being valid. + key_json_bytes: The encoded JSON. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 4ae255ebd8f5..6c151d07cea2 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -60,7 +60,7 @@ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: desc="get_local_media", ) - def store_local_media( + async def store_local_media( self, media_id, media_type, @@ -69,8 +69,8 @@ def store_local_media( media_length, user_id, url_cache=None, - ): - return self.db_pool.simple_insert( + ) -> None: + await self.db_pool.simple_insert( "local_media_repository", { "media_id": media_id, @@ -141,10 +141,10 @@ def get_url_cache_txn(txn): return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) - def store_url_cache( + async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -172,7 +172,7 @@ def get_local_media_thumbnails(self, media_id): desc="get_local_media_thumbnails", ) - def store_local_thumbnail( + async def store_local_thumbnail( self, media_id, thumbnail_width, @@ -181,7 +181,7 @@ def store_local_thumbnail( thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -212,7 +212,7 @@ async def get_cached_remote_media( desc="get_cached_remote_media", ) - def store_cached_remote_media( + async def store_cached_remote_media( self, origin, media_id, @@ -222,7 +222,7 @@ def store_cached_remote_media( upload_name, filesystem_id, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -286,7 +286,7 @@ def get_remote_media_thumbnails(self, origin, media_id): desc="get_remote_media_thumbnails", ) - def store_remote_media_thumbnail( + async def store_remote_media_thumbnail( self, origin, media_id, @@ -297,7 +297,7 @@ def store_remote_media_thumbnail( thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index dcd1ff911a20..4db8949da767 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -2,8 +2,10 @@ class OpenIdStore(SQLBaseStore): - def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.db_pool.simple_insert( + async def insert_open_id_token( + self, token: str, ts_valid_until_ms: int, user_id: str + ) -> None: + await self.db_pool.simple_insert( table="open_id_tokens", values={ "token": token, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8233c4848ae..94820e3895bb 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -66,8 +66,8 @@ async def get_from_remote_profile_cache( desc="get_from_remote_profile_cache", ) - def create_profile(self, user_localpart): - return self.db_pool.simple_insert( + async def create_profile(self, user_localpart: str) -> None: + await self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) @@ -89,13 +89,15 @@ def set_profile_avatar_url(self, user_localpart, new_avatar_url): class ProfileStore(ProfileWorkerStore): - def add_remote_profile_cache(self, user_id, displayname, avatar_url): + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: """Ensure we are caching the remote user's profiles. This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index eced53d470ce..946d56ff88d8 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Any, Awaitable, Dict, List, Optional +from typing import Any, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -549,23 +549,22 @@ def user_delete_threepids(self, user_id: str): desc="user_delete_threepids", ) - def add_user_bound_threepid(self, user_id, medium, address, id_server): + async def add_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ): """The server proxied a bind request to the given identity server on behalf of the given user. We need to remember this in case the user asks us to unbind the threepid. Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Awaitable + user_id + medium + address + id_server """ # We need to use an upsert, in case they user had already bound the # threepid - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -1081,9 +1080,9 @@ def _register_user( self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - def record_user_external_id( + async def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Awaitable: + ) -> None: """Record a mapping from an external user id to a mxid Args: @@ -1091,7 +1090,7 @@ def record_user_external_id( external_id: id on that system user_id: complete mxid that it is mapped to """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1235,12 +1234,12 @@ async def is_guest(self, user_id: str) -> bool: return res if res else False - def add_user_pending_deactivation(self, user_id): + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 97ecdb16e4ec..04ae7d269f7a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -27,7 +27,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore -from synapse.types import ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -1296,11 +1296,17 @@ def f(txn): return self.db_pool.runInteraction("get_rooms", f) - def add_event_report( - self, room_id, event_id, user_id, reason, content, received_ts - ): + async def add_event_report( + self, + room_id: str, + event_id: str, + user_id: str, + reason: str, + content: JsonDict, + received_ts: int, + ) -> None: next_id = self._event_reports_id_gen.get_next() - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="event_reports", values={ "id": next_id, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 9fe97af56adb..7af2608ca486 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,7 +16,7 @@ import logging from itertools import chain -from typing import Tuple +from typing import Any, Dict, Tuple from twisted.internet.defer import DeferredLock @@ -222,11 +222,11 @@ async def get_stats_positions(self) -> int: desc="stats_incremental_position", ) - def update_room_state(self, room_id, fields): + async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: """ Args: - room_id (str) - fields (dict[str:Any]) + room_id + fields """ # For whatever reason some of the fields may contain null bytes, which @@ -244,7 +244,7 @@ def update_room_state(self, room_id, fields): if field and "\0" in field: fields[col] = None - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 52668dbdf9cf..2efcc0dc66f4 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -21,6 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.types import JsonDict from synapse.util.caches.expiringcache import ExpiringCache db_binary_type = memoryview @@ -98,20 +99,21 @@ def _get_received_txn_response(self, txn, transaction_id, origin): else: return None - def set_received_txn_response(self, transaction_id, origin, code, response_dict): - """Persist the response we returened for an incoming transaction, and + async def set_received_txn_response( + self, transaction_id: str, origin: str, code: int, response_dict: JsonDict + ) -> None: + """Persist the response we returned for an incoming transaction, and should return for subsequent transactions with the same transaction_id and origin. Args: - txn - transaction_id (str) - origin (str) - code (int) - response_json (str) + transaction_id: The incoming transaction ID. + origin: The origin server. + code: The response code. + response_dict: The response, to be encoded into JSON. """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id,