From 2c2e649be29390061d2bc7d7a7aea1daa32e68f6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 28 Aug 2020 09:58:17 +0100 Subject: [PATCH 1/3] Move and refactor LoginRestServlet helper methods (#8182) This is split out from https://github.com/matrix-org/synapse/pull/7438, which had gotten rather large. `LoginRestServlet` has a couple helper methods, `login_submission_legacy_convert` and `login_id_thirdparty_from_phone`. They're primarily used for converting legacy user login submissions to "identifier" dicts ([see spec](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login)). Identifying information such as usernames or 3PID information used to be top-level in the login body. They're now supposed to be put inside an [identifier](https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types) parameter instead. #7438's purpose is to allow using the new identifier parameter during User-Interactive Authentication, which is currently handled in AuthHandler. That's why I've moved these helper methods there. I also moved the refactoring of these method from #7438 as they're relevant. --- changelog.d/8182.misc | 1 + synapse/handlers/auth.py | 88 ++++++++++++++++++++++++++++++++- synapse/rest/client/v1/login.py | 60 +++------------------- 3 files changed, 94 insertions(+), 55 deletions(-) create mode 100644 changelog.d/8182.misc diff --git a/changelog.d/8182.misc b/changelog.d/8182.misc new file mode 100644 index 000000000000..4fcdf1c45213 --- /dev/null +++ b/changelog.d/8182.misc @@ -0,0 +1 @@ +Refactor some of `LoginRestServlet`'s helper methods, and move them to `AuthHandler` for easier reuse. \ No newline at end of file diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 654f58ddaefe..f0b0a4d76ab7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -42,8 +42,9 @@ from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils +from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email from ._base import BaseHandler @@ -51,6 +52,91 @@ logger = logging.getLogger(__name__) +def convert_client_dict_legacy_fields_to_identifier( + submission: JsonDict, +) -> Dict[str, str]: + """ + Convert a legacy-formatted login submission to an identifier dict. + + Legacy login submissions (used in both login and user-interactive authentication) + provide user-identifying information at the top-level instead. + + These are now deprecated and replaced with identifiers: + https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types + + Args: + submission: The client dict to convert + + Returns: + The matching identifier dict + + Raises: + SynapseError: If the format of the client dict is invalid + """ + identifier = submission.get("identifier", {}) + + # Generate an m.id.user identifier if "user" parameter is present + user = submission.get("user") + if user: + identifier = {"type": "m.id.user", "user": user} + + # Generate an m.id.thirdparty identifier if "medium" and "address" parameters are present + medium = submission.get("medium") + address = submission.get("address") + if medium and address: + identifier = { + "type": "m.id.thirdparty", + "medium": medium, + "address": address, + } + + # We've converted valid, legacy login submissions to an identifier. If the + # submission still doesn't have an identifier, it's invalid + if not identifier: + raise SynapseError(400, "Invalid login submission", Codes.INVALID_PARAM) + + # Ensure the identifier has a type + if "type" not in identifier: + raise SynapseError( + 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM, + ) + + return identifier + + +def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: + """ + Convert a phone login identifier type to a generic threepid identifier. + + Args: + identifier: Login identifier dict of type 'm.id.phone' + + Returns: + An equivalent m.id.thirdparty identifier dict + """ + if "country" not in identifier or ( + # The specification requires a "phone" field, while Synapse used to require a "number" + # field. Accept both for backwards compatibility. + "phone" not in identifier + and "number" not in identifier + ): + raise SynapseError( + 400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM + ) + + # Accept both "phone" and "number" as valid keys in m.id.phone + phone_number = identifier.get("phone", identifier["number"]) + + # Convert user-provided phone number to a consistent representation + msisdn = phone_number_to_msisdn(identifier["country"], phone_number) + + return { + "type": "m.id.thirdparty", + "medium": "msisdn", + "address": msisdn, + } + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 379f668d6f8a..a14618ac84fb 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,6 +18,10 @@ from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.handlers.auth import ( + convert_client_dict_legacy_fields_to_identifier, + login_id_phone_to_thirdparty, +) from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -28,56 +32,11 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID -from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) -def login_submission_legacy_convert(submission): - """ - If the input login submission is an old style object - (ie. with top-level user / medium / address) convert it - to a typed object. - """ - if "user" in submission: - submission["identifier"] = {"type": "m.id.user", "user": submission["user"]} - del submission["user"] - - if "medium" in submission and "address" in submission: - submission["identifier"] = { - "type": "m.id.thirdparty", - "medium": submission["medium"], - "address": submission["address"], - } - del submission["medium"] - del submission["address"] - - -def login_id_thirdparty_from_phone(identifier): - """ - Convert a phone login identifier type to a generic threepid identifier - Args: - identifier(dict): Login identifier dict of type 'm.id.phone' - - Returns: Login identifier dict of type 'm.id.threepid' - """ - if "country" not in identifier or ( - # The specification requires a "phone" field, while Synapse used to require a "number" - # field. Accept both for backwards compatibility. - "phone" not in identifier - and "number" not in identifier - ): - raise SynapseError(400, "Invalid phone-type identifier") - - # Accept both "phone" and "number" as valid keys in m.id.phone - phone_number = identifier.get("phone", identifier["number"]) - - msisdn = phone_number_to_msisdn(identifier["country"], phone_number) - - return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} - - class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -194,18 +153,11 @@ async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: login_submission.get("address"), login_submission.get("user"), ) - login_submission_legacy_convert(login_submission) - - if "identifier" not in login_submission: - raise SynapseError(400, "Missing param: identifier") - - identifier = login_submission["identifier"] - if "type" not in identifier: - raise SynapseError(400, "Login identifier has no type") + identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) # convert phone type identifiers to generic threepids if identifier["type"] == "m.id.phone": - identifier = login_id_thirdparty_from_phone(identifier) + identifier = login_id_phone_to_thirdparty(identifier) # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": From d5e73cb6aa56fbd267ca957e64ad893a9ef28708 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 07:28:53 -0400 Subject: [PATCH 2/3] Define StateMap as immutable and add a MutableStateMap type. (#8183) --- changelog.d/8183.misc | 1 + synapse/handlers/federation.py | 20 ++++++++++++++------ synapse/handlers/room.py | 3 ++- synapse/handlers/sync.py | 5 +++-- synapse/state/__init__.py | 32 ++++++++++++++++++++------------ synapse/state/v1.py | 10 +++++----- synapse/state/v2.py | 6 +++--- synapse/types.py | 7 ++++--- 8 files changed, 52 insertions(+), 32 deletions(-) create mode 100644 changelog.d/8183.misc diff --git a/changelog.d/8183.misc b/changelog.d/8183.misc new file mode 100644 index 000000000000..78d8834328a5 --- /dev/null +++ b/changelog.d/8183.misc @@ -0,0 +1 @@ +Add type hints to `synapse.state`. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f8b234cee21a..155d0874137d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -72,7 +72,13 @@ from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + MutableStateMap, + StateMap, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -96,7 +102,7 @@ class _NewEventInfo: event = attr.ib(type=EventBase) state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) + auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) class FederationHandler(BaseHandler): @@ -2053,7 +2059,7 @@ async def _prep_event( origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - auth_events: Optional[StateMap[EventBase]], + auth_events: Optional[MutableStateMap[EventBase]], backfilled: bool, ) -> EventContext: context = await self.state_handler.compute_event_context(event, old_state=state) @@ -2137,7 +2143,9 @@ async def _check_for_soft_fail( current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = {k: e.event_id for k, e in current_states.items()} + current_state_ids = { + k: e.event_id for k, e in current_states.items() + } # type: StateMap[str] else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2223,7 +2231,7 @@ async def do_auth( origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """ @@ -2274,7 +2282,7 @@ async def _update_auth_events_and_context_for_auth( origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """Helper for do_auth. See there for docs. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 236a37f777c2..1419d72e9429 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -41,6 +41,7 @@ from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, + MutableStateMap, Requester, RoomAlias, RoomID, @@ -814,7 +815,7 @@ async def _send_events_for_new_room( room_id: str, preset_config: str, invite_list: List[str], - initial_state: StateMap, + initial_state: MutableStateMap, creation_content: JsonDict, room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c42dac18f5f3..9a86eb01c975 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -31,6 +31,7 @@ from synapse.types import ( Collection, JsonDict, + MutableStateMap, RoomStreamToken, StateMap, StreamToken, @@ -588,7 +589,7 @@ async def compute_summary( room_id: str, sync_config: SyncConfig, batch: TimelineBatch, - state: StateMap[EventBase], + state: MutableStateMap[EventBase], now_token: StreamToken, ) -> Optional[JsonDict]: """ Works out a room summary block for this room, summarising the number @@ -736,7 +737,7 @@ async def compute_state_delta( since_token: Optional[StreamToken], now_token: StreamToken, full_state: bool, - ) -> StateMap[EventBase]: + ) -> MutableStateMap[EventBase]: """ Works out the difference in state between the start of the timeline and the previous sync. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a601303fa34e..9bf2ec368f0d 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -25,6 +25,7 @@ Sequence, Set, Union, + cast, overload, ) @@ -41,7 +42,7 @@ from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import Collection, StateMap +from synapse.types import Collection, MutableStateMap, StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -205,7 +206,7 @@ async def get_current_state_ids( logger.debug("calling resolve_state_groups from get_current_state_ids") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - return dict(ret.state) + return ret.state async def get_current_users_in_room( self, room_id: str, latest_event_ids: Optional[List[str]] = None @@ -302,7 +303,7 @@ async def compute_event_context( # if we're given the state before the event, then we use that state_ids_before_event = { (s.type, s.state_key): s.event_id for s in old_state - } + } # type: StateMap[str] state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None @@ -315,7 +316,7 @@ async def compute_event_context( event.room_id, event.prev_event_ids() ) - state_ids_before_event = dict(entry.state) + state_ids_before_event = entry.state state_group_before_event = entry.state_group state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids @@ -540,7 +541,7 @@ async def resolve_state_groups( # # XXX: is this actually worthwhile, or should we just let # resolve_events_with_store do it? - new_state = {} + new_state = {} # type: MutableStateMap[str] conflicted_state = False for st in state_groups_ids.values(): for key, e_id in st.items(): @@ -554,13 +555,20 @@ async def resolve_state_groups( if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = await resolve_events_with_store( - self.clock, - room_id, - room_version, - list(state_groups_ids.values()), - event_map=event_map, - state_res_store=state_res_store, + # resolve_events_with_store returns a StateMap, but we can + # treat it as a MutableStateMap as it is above. It isn't + # actually mutated anymore (and is frozen in + # _make_state_cache_entry below). + new_state = cast( + MutableStateMap, + await resolve_events_with_store( + self.clock, + room_id, + room_version, + list(state_groups_ids.values()), + event_map=event_map, + state_res_store=state_res_store, + ), ) # if the new state matches any of the input state groups, we can diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 0eb7fdd9e5d3..a493279cbd2e 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -32,7 +32,7 @@ from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap logger = logging.getLogger(__name__) @@ -131,7 +131,7 @@ async def resolve_events_with_store( def _seperate( state_sets: Iterable[StateMap[str]], -) -> Tuple[StateMap[str], StateMap[Set[str]]]: +) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]: """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. @@ -152,7 +152,7 @@ def _seperate( """ state_set_iterator = iter(state_sets) unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} # type: StateMap[Set[str]] + conflicted_state = {} # type: MutableStateMap[Set[str]] for state_set in state_set_iterator: for key, value in state_set.items(): @@ -208,7 +208,7 @@ def _create_auth_events_from_maps( def _resolve_with_state( - unconflicted_state_ids: StateMap[str], + unconflicted_state_ids: MutableStateMap[str], conflicted_state_ids: StateMap[Set[str]], auth_event_ids: StateMap[str], state_map: Dict[str, EventBase], @@ -241,7 +241,7 @@ def _resolve_with_state( def _resolve_state_events( - conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase] + conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase] ) -> StateMap[EventBase]: """ This is where we actually decide which of the conflicted state to use. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 0e9ffbd6e623..edf94e7ad683 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -38,7 +38,7 @@ from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap from synapse.util import Clock logger = logging.getLogger(__name__) @@ -414,7 +414,7 @@ async def _iterative_auth_checks( base_state: StateMap[str], event_map: Dict[str, EventBase], state_res_store: "synapse.state.StateResolutionStore", -) -> StateMap[str]: +) -> MutableStateMap[str]: """Sequentially apply auth checks to each event in given list, updating the state as it goes along. @@ -430,7 +430,7 @@ async def _iterative_auth_checks( Returns: Returns the final updated state """ - resolved_state = base_state.copy() + resolved_state = dict(base_state) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] for idx, event_id in enumerate(event_ids, start=1): diff --git a/synapse/types.py b/synapse/types.py index bc36cdde308c..f8b9b0385007 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,7 @@ import string import sys from collections import namedtuple -from typing import Any, Dict, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -41,8 +41,9 @@ class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore # Define a state map type from type/state_key to T (usually an event ID or # event) T = TypeVar("T") -StateMap = Dict[Tuple[str, str], T] - +StateKey = Tuple[str, str] +StateMap = Mapping[StateKey, T] +MutableStateMap = MutableMapping[StateKey, T] # the type of a JSON-serialisable dict. This could be made stronger, but it will # do for now. From 5c03134d0f8dd157ea1800ce1a4bcddbdb73ddf1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 07:54:27 -0400 Subject: [PATCH 3/3] Convert additional database code to async/await. (#8195) --- changelog.d/8195.misc | 1 + synapse/appservice/__init__.py | 19 +- synapse/federation/persistence.py | 19 +- synapse/handlers/federation.py | 4 +- synapse/storage/databases/main/appservice.py | 15 +- synapse/storage/databases/main/deviceinbox.py | 12 +- .../storage/databases/main/e2e_room_keys.py | 30 +-- .../databases/main/event_federation.py | 71 +++---- .../storage/databases/main/group_server.py | 187 +++++++++++------- synapse/storage/databases/main/keys.py | 24 +-- .../storage/databases/main/transactions.py | 39 ++-- 11 files changed, 246 insertions(+), 175 deletions(-) create mode 100644 changelog.d/8195.misc diff --git a/changelog.d/8195.misc b/changelog.d/8195.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8195.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 1ffdc1ed9591..69a7182ef4a2 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,11 +14,16 @@ # limitations under the License. import logging import re +from typing import TYPE_CHECKING from synapse.api.constants import EventTypes +from synapse.appservice.api import ApplicationServiceApi from synapse.types import GroupID, get_domain_from_id from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + logger = logging.getLogger(__name__) @@ -35,19 +40,19 @@ def __init__(self, service, id, events): self.id = id self.events = events - def send(self, as_api): + async def send(self, as_api: ApplicationServiceApi) -> bool: """Sends this transaction using the provided AS API interface. Args: - as_api(ApplicationServiceApi): The API to use to send. + as_api: The API to use to send. Returns: - An Awaitable which resolves to True if the transaction was sent. + True if the transaction was sent. """ - return as_api.push_bulk( + return await as_api.push_bulk( service=self.service, events=self.events, txn_id=self.id ) - def complete(self, store): + async def complete(self, store: "DataStore") -> None: """Completes this transaction as successful. Marks this transaction ID on the application service and removes the @@ -55,10 +60,8 @@ def complete(self, store): Args: store: The database store to operate on. - Returns: - A Deferred which resolves to True if the transaction was completed. """ - return store.complete_appservice_txn(service=self.service, txn_id=self.id) + await store.complete_appservice_txn(service=self.service, txn_id=self.id) class ApplicationService(object): diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 769cd5de28ab..de1fe7da3865 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -20,6 +20,7 @@ """ import logging +from typing import Optional, Tuple from synapse.federation.units import Transaction from synapse.logging.utils import log_function @@ -36,25 +37,27 @@ def __init__(self, datastore): self.store = datastore @log_function - def have_responded(self, origin, transaction): - """ Have we already responded to a transaction with the same id and + async def have_responded( + self, origin: str, transaction: Transaction + ) -> Optional[Tuple[int, JsonDict]]: + """Have we already responded to a transaction with the same id and origin? Returns: - Deferred: Results in `None` if we have not previously responded to - this transaction or a 2-tuple of `(int, dict)` representing the - response code and response body. + `None` if we have not previously responded to this transaction or a + 2-tuple of `(int, dict)` representing the response code and response body. """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.get_received_txn_response(transaction.transaction_id, origin) + return await self.store.get_received_txn_response(transaction_id, origin) @log_function async def set_response( self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: - """ Persist how we responded to a transaction. + """Persist how we responded to a transaction. """ transaction_id = transaction.transaction_id # type: ignore if not transaction_id: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 155d0874137d..16389a0dca4c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1879,8 +1879,8 @@ async def get_persisted_pdu( else: return None - def get_min_depth_for_context(self, context): - return self.store.get_min_depth(context) + async def get_min_depth_for_context(self, context): + return await self.store.get_min_depth(context) async def _handle_new_event( self, origin, event, state=None, auth_events=None, backfilled=False diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 92f56f1602a2..454c0bc50cb7 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -172,7 +172,7 @@ async def set_appservice_state(self, service, state) -> None: "application_services_state", {"as_id": service.id}, {"state": state} ) - def create_appservice_txn(self, service, events): + async def create_appservice_txn(self, service, events): """Atomically creates a new transaction for this application service with the given list of events. @@ -209,20 +209,17 @@ def _create_appservice_txn(txn): ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn ) - def complete_appservice_txn(self, txn_id, service): + async def complete_appservice_txn(self, txn_id, service) -> None: """Completes an application service transaction. Args: txn_id(str): The transaction ID being completed. service(ApplicationService): The application service which was sent this transaction. - Returns: - A Deferred which resolves if this transaction was stored - successfully. """ txn_id = int(txn_id) @@ -258,7 +255,7 @@ def _complete_appservice_txn(txn): {"txn_id": txn_id, "as_id": service.id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "complete_appservice_txn", _complete_appservice_txn ) @@ -312,13 +309,13 @@ def _get_last_txn(self, txn, service_id): else: return int(last_txn_id[0]) # select 'last_txn' col - def set_appservice_last_pos(self, pos): + async def set_appservice_last_pos(self, pos) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index bb85637a95e3..00444331102e 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -190,15 +190,15 @@ def get_new_messages_for_remote_destination_txn(txn): ) @trace - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + async def delete_device_msgs_for_remote( + self, destination: str, up_to_stream_id: int + ) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. + destination: The destination server_name + up_to_stream_id: Where to delete messages up to. """ def delete_messages_for_remote_destination_txn(txn): @@ -209,7 +209,7 @@ def delete_messages_for_remote_destination_txn(txn): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 82f9d870fd06..12cecceec2a4 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -151,7 +151,7 @@ async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=Non return sessions - def get_e2e_room_keys_multi(self, user_id, version, room_keys): + async def get_e2e_room_keys_multi(self, user_id, version, room_keys): """Get multiple room keys at a time. The difference between this function and get_e2e_room_keys is that this function can be used to retrieve multiple specific keys at a time, whereas get_e2e_room_keys is used for @@ -166,10 +166,10 @@ def get_e2e_room_keys_multi(self, user_id, version, room_keys): that we want to query Returns: - Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -283,7 +283,7 @@ def _get_current_version(txn, user_id): raise StoreError(404, "No current backup version") return row[0] - def get_e2e_room_keys_version_info(self, user_id, version=None): + async def get_e2e_room_keys_version_info(self, user_id, version=None): """Get info metadata about a version of our room_keys backup. Args: @@ -293,7 +293,7 @@ def get_e2e_room_keys_version_info(self, user_id, version=None): Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present Returns: - A deferred dict giving the info metadata for this backup version, with + A dict giving the info metadata for this backup version, with fields including: version(str) algorithm(str) @@ -324,12 +324,12 @@ def _get_e2e_room_keys_version_info_txn(txn): result["etag"] = 0 return result - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @trace - def create_e2e_room_keys_version(self, user_id, info): + async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str: """Atomically creates a new version of this user's e2e_room_keys store with the given version info. @@ -338,7 +338,7 @@ def create_e2e_room_keys_version(self, user_id, info): info(dict): the info about the backup version to be created Returns: - A deferred string for the newly created version ID + The newly created version ID """ def _create_e2e_room_keys_version_txn(txn): @@ -365,7 +365,7 @@ def _create_e2e_room_keys_version_txn(txn): return new_version - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @@ -403,13 +403,15 @@ async def update_e2e_room_keys_version( ) @trace - def delete_e2e_room_keys_version(self, user_id, version=None): + async def delete_e2e_room_keys_version( + self, user_id: str, version: Optional[str] = None + ) -> None: """Delete a given backup version of the user's room keys. Doesn't delete their actual key data. Args: - user_id(str): the user whose backup version we're deleting - version(str): Optional. the version ID of the backup version we're deleting + user_id: the user whose backup version we're deleting + version: Optional. the version ID of the backup version we're deleting If missing, we delete the current backup version info. Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present, @@ -430,13 +432,13 @@ def _delete_e2e_room_keys_version_txn(txn): keyvalues={"user_id": user_id, "version": this_version}, ) - return self.db_pool.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6e5761c7b75a..0b69aa6a940a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -59,7 +59,7 @@ async def get_auth_chain_ids( include_given: include the given events in result Returns: - list of event_ids + An awaitable which resolve to a list of event_ids """ return await self.db_pool.runInteraction( "get_auth_chain_ids", @@ -95,7 +95,7 @@ def _get_auth_chain_ids_txn( return list(results) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -104,10 +104,10 @@ def get_auth_chain_difference(self, state_sets: List[Set[str]]): chain. Returns: - Deferred[Set[str]] + The set of the difference in auth chains. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, @@ -252,8 +252,8 @@ def _get_auth_chain_difference_txn( # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_with_depth_in_room(self, room_id): - return self.db_pool.runInteraction( + async def get_oldest_events_with_depth_in_room(self, room_id): + return await self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -293,7 +293,7 @@ async def get_max_depth_of(self, event_ids: List[str]) -> int: else: return max(row["depth"] for row in rows) - def get_prev_events_for_room(self, room_id: str): + async def get_prev_events_for_room(self, room_id: str) -> List[str]: """ Gets a subset of the current forward extremities in the given room. @@ -301,14 +301,14 @@ def get_prev_events_for_room(self, room_id: str): events which refer to hundreds of prev_events. Args: - room_id (str): room_id + room_id: room_id Returns: - Deferred[List[str]]: the event ids of the forward extremites + The event ids of the forward extremities. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) @@ -328,17 +328,19 @@ def _get_prev_events_for_room_txn(self, txn, room_id: str): return [row[0] for row in txn] - def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): + async def get_rooms_with_many_extremities( + self, min_count: int, limit: int, room_id_filter: Iterable[str] + ) -> List[str]: """Get the top rooms with at least N extremities. Args: - min_count (int): The minimum number of extremities - limit (int): The maximum number of rooms to return. - room_id_filter (iterable[str]): room_ids to exclude from the results + min_count: The minimum number of extremities + limit: The maximum number of rooms to return. + room_id_filter: room_ids to exclude from the results Returns: - Deferred[list]: At most `limit` room IDs that have at least - `min_count` extremities, sorted by extremity count. + At most `limit` room IDs that have at least `min_count` extremities, + sorted by extremity count. """ def _get_rooms_with_many_extremities_txn(txn): @@ -363,7 +365,7 @@ def _get_rooms_with_many_extremities_txn(txn): txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @@ -376,10 +378,10 @@ async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: desc="get_latest_event_ids_in_room", ) - def get_min_depth(self, room_id): - """ For hte given room, get the minimum depth we have seen for it. + async def get_min_depth(self, room_id: str) -> int: + """For the given room, get the minimum depth we have seen for it. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) @@ -394,7 +396,9 @@ def _get_min_depth_interaction(self, txn, room_id): return int(min_depth) if min_depth is not None else None - def get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def get_forward_extremeties_for_room( + self, room_id: str, stream_ordering: int + ) -> List[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -402,11 +406,11 @@ def get_forward_extremeties_for_room(self, room_id, stream_ordering): stream_orderings from that point. Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: - deferred, which resolves to a list of event_ids + A list of event_ids """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. @@ -422,10 +426,10 @@ def get_forward_extremeties_for_room(self, room_id, stream_ordering): if last_change > self.stream_ordering_month_ago: stream_ordering = min(last_change, stream_ordering) - return self._get_forward_extremeties_for_room(room_id, stream_ordering) + return await self._get_forward_extremeties_for_room(room_id, stream_ordering) @cached(max_entries=5000, num_args=2) - def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -450,19 +454,18 @@ def get_forward_extremeties_for_room_txn(txn): txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - async def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id: str, event_list: list, limit: int): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` Args: - txn - room_id (str) - event_list (list) - limit (int) + room_id + event_list + limit """ event_ids = await self.db_pool.runInteraction( "get_backfill_events", @@ -631,8 +634,8 @@ def _delete_old_forward_extrem_cache_txn(txn): _delete_old_forward_extrem_cache_txn, ) - def clean_room_for_join(self, room_id): - return self.db_pool.runInteraction( + async def clean_room_for_join(self, room_id): + return await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 6c6017188845..ccfbb2135eba 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -70,7 +70,9 @@ async def get_invited_users_in_group(self, group_id: str) -> List[str]: desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id: str, include_private: bool = False): + async def get_rooms_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Union[str, bool]]]: """Retrieve the rooms that belong to a given group. Does not return rooms that lack members. @@ -79,8 +81,7 @@ def get_rooms_in_group(self, group_id: str, include_private: bool = False): include_private: Whether to return private rooms in results Returns: - Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the - form of: + A list of dictionaries, each in the form of: { "room_id": "!a_room_id:example.com", # The ID of the room @@ -117,13 +118,13 @@ def _get_rooms_in_group_txn(txn): for room_id, is_public in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_in_group", _get_rooms_in_group_txn ) - def get_rooms_for_summary_by_category( + async def get_rooms_for_summary_by_category( self, group_id: str, include_private: bool = False, - ): + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Get the rooms and categories that should be included in a summary request Args: @@ -131,7 +132,7 @@ def get_rooms_for_summary_by_category( include_private: Whether to return private rooms in results Returns: - Deferred[Tuple[List, Dict]]: A tuple containing: + A tuple containing: * A list of dictionaries with the keys: * "room_id": str, the room ID @@ -207,7 +208,7 @@ def _get_rooms_for_summary_txn(txn): return rooms, categories - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_summary", _get_rooms_for_summary_txn ) @@ -281,10 +282,11 @@ async def get_local_groups_for_room(self, room_id: str) -> List[str]: desc="get_local_groups_for_room", ) - def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role(self, group_id, include_private=False): """Get the users and roles that should be included in a summary request - Returns ([users], [roles]) + Returns: + ([users], [roles]) """ def _get_users_for_summary_txn(txn): @@ -338,7 +340,7 @@ def _get_users_for_summary_txn(txn): return users, roles - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) @@ -376,7 +378,7 @@ async def is_user_invited_to_local_group( allow_none=True, ) - def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group(self, group_id, user_id): """Get a dict describing the membership of a user in a group. Example if joined: @@ -387,7 +389,8 @@ def get_users_membership_info_in_group(self, group_id, user_id): "is_privileged": False, } - Returns an empty dict if the user is not join/invite/etc + Returns: + An empty dict if the user is not join/invite/etc """ def _get_users_membership_in_group_txn(txn): @@ -419,7 +422,7 @@ def _get_users_membership_in_group_txn(txn): return {} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) @@ -433,7 +436,7 @@ async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: desc="get_publicised_groups_for_user", ) - def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals(self, valid_until_ms): """Get all attestations that need to be renewed until givent time """ @@ -445,7 +448,7 @@ def _get_attestations_need_renewals_txn(txn): txn.execute(sql, (valid_until_ms,)) return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) @@ -475,7 +478,7 @@ async def get_joined_groups(self, user_id: str) -> List[str]: desc="get_joined_groups", ) - def get_all_groups_for_user(self, user_id, now_token): + async def get_all_groups_for_user(self, user_id, now_token): def _get_all_groups_for_user_txn(txn): sql = """ SELECT group_id, type, membership, u.content @@ -495,7 +498,7 @@ def _get_all_groups_for_user_txn(txn): for row in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) @@ -600,8 +603,27 @@ async def set_group_join_policy(self, group_id: str, join_policy: str) -> None: desc="set_group_join_policy", ) - def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.db_pool.runInteraction( + async def add_room_to_summary( + self, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) room's entry in summary. + + Args: + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -612,18 +634,26 @@ def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): ) def _add_room_to_summary_txn( - self, txn, group_id, room_id, category_id, order, is_public - ): + self, + txn, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: """Add (or update) room's entry in summary. Args: - group_id (str) - room_id (str) - category_id (str): If not None then adds the category to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the room at that position, e.g. - an order of 1 will put the room first. Otherwise, the room gets - added to the end. + txn + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public """ room_in_group = self.db_pool.simple_select_one_onecol_txn( txn, @@ -818,8 +848,27 @@ async def remove_group_role(self, group_id: str, role_id: str) -> int: desc="remove_group_role", ) - def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.db_pool.runInteraction( + async def add_user_to_summary( + self, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) user's entry in summary. + + Args: + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -830,18 +879,26 @@ def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): ) def _add_user_to_summary_txn( - self, txn, group_id, user_id, role_id, order, is_public + self, + txn, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], ): """Add (or update) user's entry in summary. Args: - group_id (str) - user_id (str) - role_id (str): If not None then adds the role to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the user at that position, e.g. - an order of 1 will put the user first. Otherwise, the user gets - added to the end. + txn + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public """ user_in_group = self.db_pool.simple_select_one_onecol_txn( txn, @@ -963,27 +1020,26 @@ async def add_group_invite(self, group_id: str, user_id: str) -> None: desc="add_group_invite", ) - def add_user_to_group( + async def add_user_to_group( self, - group_id, - user_id, - is_admin=False, - is_public=True, - local_attestation=None, - remote_attestation=None, - ): + group_id: str, + user_id: str, + is_admin: bool = False, + is_public: bool = True, + local_attestation: dict = None, + remote_attestation: dict = None, + ) -> None: """Add a user to the group server. Args: - group_id (str) - user_id (str) - is_admin (bool) - is_public (bool) - local_attestation (dict): The attestation the GS created to give - to the remote server. Optional if the user and group are on the - same server - remote_attestation (dict): The attestation given to GS by remote + group_id + user_id + is_admin + is_public + local_attestation: The attestation the GS created to give to the remote server. Optional if the user and group are on the same server + remote_attestation: The attestation given to GS by remote server. + Optional if the user and group are on the same server """ def _add_user_to_group_txn(txn): @@ -1026,9 +1082,9 @@ def _add_user_to_group_txn(txn): }, ) - return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) + await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) - def remove_user_from_group(self, group_id, user_id): + async def remove_user_from_group(self, group_id: str, user_id: str) -> None: def _remove_user_from_group_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -1056,7 +1112,7 @@ def _remove_user_from_group_txn(txn): keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) @@ -1079,7 +1135,7 @@ async def update_room_in_group_visibility( desc="update_room_in_group_visibility", ) - def remove_room_from_group(self, group_id, room_id): + async def remove_room_from_group(self, group_id: str, room_id: str) -> None: def _remove_room_from_group_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -1093,7 +1149,7 @@ def _remove_room_from_group_txn(txn): keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) @@ -1286,14 +1342,11 @@ async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int: def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def delete_group(self, group_id): + async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. Args: - group_id (str) - - Returns: - Deferred + group_id: The group ID to delete. """ def _delete_group_txn(txn): @@ -1317,4 +1370,4 @@ def _delete_group_txn(txn): txn, table=table, keyvalues={"group_id": group_id} ) - return self.db_pool.runInteraction("delete_group", _delete_group_txn) + await self.db_pool.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 1c0a049c5548..ad43bb05abb5 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,7 +16,7 @@ import itertools import logging -from typing import Iterable, Tuple +from typing import Dict, Iterable, List, Optional, Tuple from signedjson.key import decode_verify_key_bytes @@ -42,16 +42,17 @@ def _get_server_verify_key(self, server_name_and_key_id): @cachedList( cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" ) - def get_server_verify_keys(self, server_name_and_key_ids): + async def get_server_verify_keys( + self, server_name_and_key_ids: Iterable[Tuple[str, str]] + ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: """ Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): + server_name_and_key_ids: iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown + A map from (server_name, key_id) -> FetchKeyResult, or None if the + key is unknown """ keys = {} @@ -87,7 +88,7 @@ def _txn(txn): _get_keys(txn, batch) return keys - return self.db_pool.runInteraction("get_server_verify_keys", _txn) + return await self.db_pool.runInteraction("get_server_verify_keys", _txn) async def store_server_verify_keys( self, @@ -179,7 +180,9 @@ async def store_server_keys_json( desc="store_server_keys_json", ) - def get_server_keys_json(self, server_keys): + async def get_server_keys_json( + self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: """Retrive the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. @@ -188,8 +191,7 @@ def get_server_keys_json(self, server_keys): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts + A mapping from (server_name, key_id, source) triplets to a list of dicts """ def _get_server_keys_json_txn(txn): @@ -215,6 +217,6 @@ def _get_server_keys_json_txn(txn): results[(server_name, key_id, from_server)] = rows return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_server_keys_json", _get_server_keys_json_txn ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 2efcc0dc66f4..5b31aab700f9 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Optional, Tuple from canonicaljson import encode_canonical_json @@ -56,21 +57,23 @@ def __init__(self, database: DatabasePool, db_conn, hs): expiry_ms=5 * 60 * 1000, ) - def get_received_txn_response(self, transaction_id, origin): + async def get_received_txn_response( + self, transaction_id: str, origin: str + ) -> Optional[Tuple[int, JsonDict]]: """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response body (as a dict). Args: - transaction_id (str) - origin(str) + transaction_id + origin Returns: - tuple: None if we have not previously responded to - this transaction or a 2-tuple of (int, dict) + None if we have not previously responded to this transaction or a + 2-tuple of (int, dict) """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -166,21 +169,25 @@ def _get_destination_retry_timings(self, txn, destination): else: return None - def set_destination_retry_timings( - self, destination, failure_ts, retry_last_ts, retry_interval - ): + async def set_destination_retry_timings( + self, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: """Sets the current retry timings for a given destination. Both timings should be zero if retrying is no longer occuring. Args: - destination (str) - failure_ts (int|None) - when the server started failing (ms since epoch) - retry_last_ts (int) - time of last retry attempt in unix epoch ms - retry_interval (int) - how long until next retry in ms + destination + failure_ts: when the server started failing (ms since epoch) + retry_last_ts: time of last retry attempt in unix epoch ms + retry_interval: how long until next retry in ms """ self._destination_retry_cache.pop(destination, None) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -256,13 +263,13 @@ def _start_cleanup_transactions(self): "cleanup_transactions", self._cleanup_transactions ) - def _cleanup_transactions(self): + async def _cleanup_transactions(self) -> None: now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn )