diff --git a/changelog.d/6840.misc b/changelog.d/6840.misc new file mode 100644 index 000000000000..0496f12de801 --- /dev/null +++ b/changelog.d/6840.misc @@ -0,0 +1 @@ +Port much of `synapse.handlers.federation` to async/await. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f99d17a7de96..3a840e068bd2 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -17,7 +17,18 @@ import copy import itertools import logging -from typing import Dict, Iterable +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + TypeVar, +) from prometheus_client import Counter @@ -35,12 +46,14 @@ from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, + RoomVersion, RoomVersions, ) -from synapse.events import builder, room_version_to_event_format +from synapse.events import EventBase, builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.logging.context import make_deferred_yieldable from synapse.logging.utils import log_function +from synapse.types import JsonDict from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -52,6 +65,8 @@ PDU_RETRY_TIME_MS = 1 * 60 * 1000 +T = TypeVar("T") + class InvalidResponseError(RuntimeError): """Helper for _try_destination_list: indicates that the server returned a response @@ -170,21 +185,17 @@ def claim_client_keys(self, destination, content, timeout): sent_queries_counter.labels("client_one_time_keys").inc() return self.transport_layer.claim_client_keys(destination, content, timeout) - @defer.inlineCallbacks - @log_function - def backfill(self, dest, room_id, limit, extremities): - """Requests some more historic PDUs for the given context from the + async def backfill( + self, dest: str, room_id: str, limit: int, extremities: Iterable[str] + ) -> List[EventBase]: + """Requests some more historic PDUs for the given room from the given destination server. Args: dest (str): The remote homeserver to ask. room_id (str): The room_id to backfill. - limit (int): The maximum number of PDUs to return. - extremities (list): List of PDU id and origins of the first pdus - we have seen from the context - - Returns: - Deferred: Results in the received PDUs. + limit (int): The maximum number of events to return. + extremities (list): our current backwards extremities, to backfill from """ logger.debug("backfill extrem=%s", extremities) @@ -192,13 +203,13 @@ def backfill(self, dest, room_id, limit, extremities): if not extremities: return - transaction_data = yield self.transport_layer.backfill( + transaction_data = await self.transport_layer.backfill( dest, room_id, extremities, limit ) logger.debug("backfill transaction_data=%r", transaction_data) - room_version = yield self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdus = [ @@ -207,7 +218,7 @@ def backfill(self, dest, room_id, limit, extremities): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield make_deferred_yieldable( + pdus[:] = await make_deferred_yieldable( defer.gatherResults( self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True ).addErrback(unwrapFirstError) @@ -215,11 +226,14 @@ def backfill(self, dest, room_id, limit, extremities): return pdus - @defer.inlineCallbacks - @log_function - def get_pdu( - self, destinations, event_id, room_version, outlier=False, timeout=None - ): + async def get_pdu( + self, + destinations: Iterable[str], + event_id: str, + room_version: str, + outlier: bool = False, + timeout: Optional[int] = None, + ) -> Optional[EventBase]: """Requests the PDU with given origin and ID from the remote home servers. @@ -227,18 +241,17 @@ def get_pdu( one succeeds. Args: - destinations (list): Which homeservers to query - event_id (str): event to fetch - room_version (str): version of the room - outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if + destinations: Which homeservers to query + event_id: event to fetch + room_version: version of the room + outlier: Indicates whether the PDU is an `outlier`, i.e. if it's from an arbitary point in the context as opposed to part of the current block of PDUs. Defaults to `False` - timeout (int): How long to try (in ms) each destination for before + timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. Returns: - Deferred: Results in the requested PDU, or None if we were unable to find - it. + The requested PDU, or None if we were unable to find it. """ # TODO: Rate limit the number of times we try and get the same event. @@ -259,7 +272,7 @@ def get_pdu( continue try: - transaction_data = yield self.transport_layer.get_event( + transaction_data = await self.transport_layer.get_event( destination, event_id, timeout=timeout ) @@ -279,7 +292,7 @@ def get_pdu( pdu = pdu_list[0] # Check signatures are correct. - signed_pdu = yield self._check_sigs_and_hash(room_version, pdu) + signed_pdu = await self._check_sigs_and_hash(room_version, pdu) break @@ -309,15 +322,16 @@ def get_pdu( return signed_pdu - @defer.inlineCallbacks - def get_room_state_ids(self, destination: str, room_id: str, event_id: str): + async def get_room_state_ids( + self, destination: str, room_id: str, event_id: str + ) -> Tuple[List[str], List[str]]: """Calls the /state_ids endpoint to fetch the state at a particular point in the room, and the auth events for the given event Returns: - Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids) + a tuple of (state event_ids, auth event_ids) """ - result = yield self.transport_layer.get_room_state_ids( + result = await self.transport_layer.get_room_state_ids( destination, room_id, event_id=event_id ) @@ -331,19 +345,17 @@ def get_room_state_ids(self, destination: str, room_id: str, event_id: str): return state_event_ids, auth_event_ids - @defer.inlineCallbacks - @log_function - def get_event_auth(self, destination, room_id, event_id): - res = yield self.transport_layer.get_event_auth(destination, room_id, event_id) + async def get_event_auth(self, destination, room_id, event_id): + res = await self.transport_layer.get_event_auth(destination, room_id, event_id) - room_version = yield self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] ] - signed_auth = yield self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_and_fetch( destination, auth_chain, outlier=True, room_version=room_version ) @@ -351,17 +363,21 @@ def get_event_auth(self, destination, room_id, event_id): return signed_auth - @defer.inlineCallbacks - def _try_destination_list(self, description, destinations, callback): + async def _try_destination_list( + self, + description: str, + destinations: Iterable[str], + callback: Callable[[str], Awaitable[T]], + ) -> T: """Try an operation on a series of servers, until it succeeds Args: - description (unicode): description of the operation we're doing, for logging + description: description of the operation we're doing, for logging - destinations (Iterable[unicode]): list of server_names to try + destinations: list of server_names to try - callback (callable): Function to run for each server. Passed a single - argument: the server_name to try. May return a deferred. + callback: Function to run for each server. Passed a single + argument: the server_name to try. If the callback raises a CodeMessageException with a 300/400 code, attempts to perform the operation stop immediately and the exception is @@ -372,7 +388,7 @@ def _try_destination_list(self, description, destinations, callback): suppressed if the exception is an InvalidResponseError. Returns: - The [Deferred] result of callback, if it succeeds + The result of callback, if it succeeds Raises: SynapseError if the chosen remote server returns a 300/400 code, or @@ -383,7 +399,7 @@ def _try_destination_list(self, description, destinations, callback): continue try: - res = yield callback(destination) + res = await callback(destination) return res except InvalidResponseError as e: logger.warning("Failed to %s via %s: %s", description, destination, e) @@ -402,12 +418,12 @@ def _try_destination_list(self, description, destinations, callback): ) except Exception: logger.warning( - "Failed to %s via %s", description, destination, exc_info=1 + "Failed to %s via %s", description, destination, exc_info=True ) raise SynapseError(502, "Failed to %s via any server" % (description,)) - def make_membership_event( + async def make_membership_event( self, destinations: Iterable[str], room_id: str, @@ -415,7 +431,7 @@ def make_membership_event( membership: str, content: dict, params: Dict[str, str], - ): + ) -> Tuple[str, EventBase, RoomVersion]: """ Creates an m.room.member event, with context, without participating in the room. @@ -436,19 +452,19 @@ def make_membership_event( content: Any additional data to put into the content field of the event. params: Query parameters to include in the request. - Return: - Deferred[Tuple[str, FrozenEvent, RoomVersion]]: resolves to a tuple of + + Returns: `(origin, event, room_version)` where origin is the remote homeserver which generated the event, and room_version is the version of the room. - Fails with a `UnsupportedRoomVersionError` if remote responds with - a room version we don't understand. + Raises: + UnsupportedRoomVersionError: if remote responds with + a room version we don't understand. - Fails with a ``SynapseError`` if the chosen remote server - returns a 300/400 code. + SynapseError: if the chosen remote server returns a 300/400 code. - Fails with a ``RuntimeError`` if no servers were reachable. + RuntimeError: if no servers were reachable. """ valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: @@ -457,9 +473,8 @@ def make_membership_event( % (membership, ",".join(valid_memberships)) ) - @defer.inlineCallbacks - def send_request(destination): - ret = yield self.transport_layer.make_membership_event( + async def send_request(destination: str) -> Tuple[str, EventBase, RoomVersion]: + ret = await self.transport_layer.make_membership_event( destination, room_id, user_id, membership, params ) @@ -492,33 +507,35 @@ def send_request(destination): event_dict=pdu_dict, ) - return (destination, ev, room_version) + return destination, ev, room_version - return self._try_destination_list( + return await self._try_destination_list( "make_" + membership, destinations, send_request ) - def send_join(self, destinations, pdu, event_format_version): + async def send_join( + self, destinations: Iterable[str], pdu: EventBase, event_format_version: int + ) -> Dict[str, Any]: """Sends a join event to one of a list of homeservers. Doing so will cause the remote server to add the event to the graph, and send the event out to the rest of the federation. Args: - destinations (str): Candidate homeservers which are probably + destinations: Candidate homeservers which are probably participating in the room. - pdu (BaseEvent): event to be sent - event_format_version (int): The event format version + pdu: event to be sent + event_format_version: The event format version - Return: - Deferred: resolves to a dict with members ``origin`` (a string - giving the serer the event was sent to, ``state`` (?) and + Returns: + a dict with members ``origin`` (a string + giving the server the event was sent to, ``state`` (?) and ``auth_chain``. - Fails with a ``SynapseError`` if the chosen remote server - returns a 300/400 code. + Raises: + SynapseError: if the chosen remote server returns a 300/400 code. - Fails with a ``RuntimeError`` if no servers were reachable. + RuntimeError: if no servers were reachable. """ def check_authchain_validity(signed_auth_chain): @@ -538,9 +555,8 @@ def check_authchain_validity(signed_auth_chain): "room appears to have unsupported version %s" % (room_version,) ) - @defer.inlineCallbacks - def send_request(destination): - content = yield self._do_send_join(destination, pdu) + async def send_request(destination) -> Dict[str, Any]: + content = await self._do_send_join(destination, pdu) logger.debug("Got content: %s", content) @@ -569,7 +585,7 @@ def send_request(destination): # invalid, and it would fail auth checks anyway. raise SynapseError(400, "No create event in state") - valid_pdus = yield self._check_sigs_and_hash_and_fetch( + valid_pdus = await self._check_sigs_and_hash_and_fetch( destination, list(pdus.values()), outlier=True, @@ -605,14 +621,13 @@ def send_request(destination): "origin": destination, } - return self._try_destination_list("send_join", destinations, send_request) + return await self._try_destination_list("send_join", destinations, send_request) - @defer.inlineCallbacks - def _do_send_join(self, destination, pdu): + async def _do_send_join(self, destination: str, pdu: EventBase): time_now = self._clock.time_msec() try: - content = yield self.transport_layer.send_join_v2( + content = await self.transport_layer.send_join_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -634,7 +649,7 @@ def _do_send_join(self, destination, pdu): logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") - resp = yield self.transport_layer.send_join_v1( + resp = await self.transport_layer.send_join_v1( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -645,45 +660,42 @@ def _do_send_join(self, destination, pdu): # content. return resp[1] - @defer.inlineCallbacks - def send_invite(self, destination, room_id, event_id, pdu): - room_version = yield self.store.get_room_version_id(room_id) + async def send_invite( + self, destination: str, room_id: str, event_id: str, pdu: EventBase, + ) -> EventBase: + room_version = await self.store.get_room_version_id(room_id) - content = yield self._do_send_invite(destination, pdu, room_version) + content = await self._do_send_invite(destination, pdu, room_version) pdu_dict = content["event"] logger.debug("Got response to send_invite: %s", pdu_dict) - room_version = yield self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(pdu_dict, format_ver) # Check signatures are correct. - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) # FIXME: We should handle signature failures more gracefully. return pdu - @defer.inlineCallbacks - def _do_send_invite(self, destination, pdu, room_version): + async def _do_send_invite( + self, destination: str, pdu: EventBase, room_version: str + ) -> JsonDict: """Actually sends the invite, first trying v2 API and falling back to v1 API if necessary. - Args: - destination (str): Target server - pdu (FrozenEvent) - room_version (str) - Returns: - dict: The event as a dict as returned by the remote server + The event as a dict as returned by the remote server """ time_now = self._clock.time_msec() try: - content = yield self.transport_layer.send_invite_v2( + content = await self.transport_layer.send_invite_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -722,7 +734,7 @@ def _do_send_invite(self, destination, pdu, room_version): # Didn't work, try v1 API. # Note the v1 API returns a tuple of `(200, content)` - _, content = yield self.transport_layer.send_invite_v1( + _, content = await self.transport_layer.send_invite_v1( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -730,7 +742,7 @@ def _do_send_invite(self, destination, pdu, room_version): ) return content - def send_leave(self, destinations, pdu): + async def send_leave(self, destinations: Iterable[str], pdu: EventBase) -> None: """Sends a leave event to one of a list of homeservers. Doing so will cause the remote server to add the event to the graph, @@ -739,34 +751,29 @@ def send_leave(self, destinations, pdu): This is mostly useful to reject received invites. Args: - destinations (str): Candidate homeservers which are probably + destinations: Candidate homeservers which are probably participating in the room. - pdu (BaseEvent): event to be sent + pdu: event to be sent - Return: - Deferred: resolves to None. - - Fails with a ``SynapseError`` if the chosen remote server - returns a 300/400 code. + Raises: + SynapseError if the chosen remote server returns a 300/400 code. - Fails with a ``RuntimeError`` if no servers were reachable. + RuntimeError if no servers were reachable. """ - @defer.inlineCallbacks - def send_request(destination): - content = yield self._do_send_leave(destination, pdu) - + async def send_request(destination: str) -> None: + content = await self._do_send_leave(destination, pdu) logger.debug("Got content: %s", content) - return None - return self._try_destination_list("send_leave", destinations, send_request) + return await self._try_destination_list( + "send_leave", destinations, send_request + ) - @defer.inlineCallbacks - def _do_send_leave(self, destination, pdu): + async def _do_send_leave(self, destination, pdu): time_now = self._clock.time_msec() try: - content = yield self.transport_layer.send_leave_v2( + content = await self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -788,7 +795,7 @@ def _do_send_leave(self, destination, pdu): logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API") - resp = yield self.transport_layer.send_leave_v1( + resp = await self.transport_layer.send_leave_v1( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -820,34 +827,33 @@ def get_public_rooms( third_party_instance_id=third_party_instance_id, ) - @defer.inlineCallbacks - def get_missing_events( + async def get_missing_events( self, - destination, - room_id, - earliest_events_ids, - latest_events, - limit, - min_depth, - timeout, - ): + destination: str, + room_id: str, + earliest_events_ids: Sequence[str], + latest_events: Iterable[EventBase], + limit: int, + min_depth: int, + timeout: int, + ) -> List[EventBase]: """Tries to fetch events we are missing. This is called when we receive an event without having received all of its ancestors. Args: - destination (str) - room_id (str) - earliest_events_ids (list): List of event ids. Effectively the + destination + room_id + earliest_events_ids: List of event ids. Effectively the events we expected to receive, but haven't. `get_missing_events` should only return events that didn't happen before these. - latest_events (list): List of events we have received that we don't + latest_events: List of events we have received that we don't have all previous events for. - limit (int): Maximum number of events to return. - min_depth (int): Minimum depth of events tor return. - timeout (int): Max time to wait in ms + limit: Maximum number of events to return. + min_depth: Minimum depth of events to return. + timeout: Max time to wait in ms """ try: - content = yield self.transport_layer.get_missing_events( + content = await self.transport_layer.get_missing_events( destination=destination, room_id=room_id, earliest_events=earliest_events_ids, @@ -857,14 +863,14 @@ def get_missing_events( timeout=timeout, ) - room_version = yield self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) events = [ event_from_pdu_json(e, format_ver) for e in content.get("events", []) ] - signed_events = yield self._check_sigs_and_hash_and_fetch( + signed_events = await self._check_sigs_and_hash_and_fetch( destination, events, outlier=False, room_version=room_version ) except HttpResponseException as e: