From ca7c38257ce26603bc825a7e6faefc7ef22b1a50 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 30 Dec 2020 11:20:45 -0500 Subject: [PATCH 1/4] Add more type hints to FederationClient. --- synapse/federation/federation_client.py | 88 ++++++++++++++----------- 1 file changed, 50 insertions(+), 38 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 302b2f69bcdd..6851a98a5ad6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,6 +18,7 @@ import itertools import logging from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -26,7 +27,6 @@ List, Mapping, Optional, - Sequence, Tuple, TypeVar, Union, @@ -61,6 +61,9 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) @@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError): class FederationClient(FederationBase): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.pdu_destination_tried = {} + self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]] self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -118,26 +121,26 @@ def _clear_tried_cache(self): @log_function def make_query( self, - destination, - query_type, - args, - retry_on_dns_fail=False, - ignore_backoff=False, - ): + destination: str, + query_type: str, + args: dict, + retry_on_dns_fail: bool = False, + ignore_backoff: bool = False, + ) -> Awaitable[JsonDict]: """Sends a federation Query to a remote homeserver of the given type and arguments. Args: - destination (str): Domain name of the remote homeserver - query_type (str): Category of the query type; should match the + destination: Domain name of the remote homeserver + query_type: Category of the query type; should match the handler name used in register_query_handler(). - args (dict): Mapping of strings to strings containing the details + args: Mapping of strings to strings containing the details of the query request. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. Returns: - a Awaitable which will eventually yield a JSON object from the + An Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels(query_type).inc() @@ -151,22 +154,26 @@ def make_query( ) @log_function - def query_client_keys(self, destination, content, timeout): + def query_client_keys( + self, destination: str, content: JsonDict, timeout: int + ) -> Awaitable[JsonDict]: """Query device keys for a device hosted on a remote server. Args: - destination (str): Domain name of the remote homeserver - content (dict): The query content. + destination: Domain name of the remote homeserver + content: The query content. Returns: - an Awaitable which will eventually yield a JSON object from the + An Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_device_keys").inc() return self.transport_layer.query_client_keys(destination, content, timeout) @log_function - def query_user_devices(self, destination, user_id, timeout=30000): + def query_user_devices( + self, destination: str, user_id: str, timeout: int = 30000 + ) -> Awaitable[JsonDict]: """Query the device keys for a list of user ids hosted on a remote server. """ @@ -174,15 +181,17 @@ def query_user_devices(self, destination, user_id, timeout=30000): return self.transport_layer.query_user_devices(destination, user_id, timeout) @log_function - def claim_client_keys(self, destination, content, timeout): + def claim_client_keys( + self, destination: str, content: JsonDict, timeout: int + ) -> Awaitable[JsonDict]: """Claims one-time keys for a device hosted on a remote server. Args: - destination (str): Domain name of the remote homeserver - content (dict): The query content. + destination: Domain name of the remote homeserver + content: The query content. Returns: - an Awaitable which will eventually yield a JSON object from the + An Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() @@ -195,10 +204,10 @@ async def backfill( given destination server. Args: - dest (str): The remote homeserver to ask. - room_id (str): The room_id to backfill. - limit (int): The maximum number of events to return. - extremities (list): our current backwards extremities, to backfill from + dest: The remote homeserver to ask. + room_id: The room_id to backfill. + limit: The maximum number of events to return. + extremities: our current backwards extremities, to backfill from """ logger.debug("backfill extrem=%s", extremities) @@ -370,7 +379,7 @@ async def _check_sigs_and_hash_and_fetch( for events that have failed their checks Returns: - Deferred : A list of PDUs that have valid signatures and hashes. + A list of PDUs that have valid signatures and hashes. """ deferreds = self._check_sigs_and_hashes(room_version, pdus) @@ -418,7 +427,9 @@ async def handle_check_result(pdu: EventBase, deferred: Deferred): else: return [p for p in valid_pdus if p] - async def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth( + self, destination: str, room_id: str, event_id: str + ) -> List[EventBase]: res = await self.transport_layer.get_event_auth(destination, room_id, event_id) room_version = await self.store.get_room_version(room_id) @@ -700,7 +711,7 @@ async def send_request(destination) -> Dict[str, Any]: return await self._try_destination_list("send_join", destinations, send_request) - async def _do_send_join(self, destination: str, pdu: EventBase): + async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: @@ -842,7 +853,7 @@ async def send_request(destination: str) -> None: "send_leave", destinations, send_request ) - async def _do_send_leave(self, destination, pdu): + async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: @@ -887,7 +898,7 @@ def get_public_rooms( search_filter: Optional[Dict] = None, include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, - ): + ) -> Awaitable[JsonDict]: """Get the list of public rooms from a remote homeserver Args: @@ -901,8 +912,7 @@ def get_public_rooms( party instance Returns: - Awaitable[Dict[str, Any]]: The response from the remote server, or None if - `remote_server` is the same as the local server_name + The response from the remote server. Raises: HttpResponseException: There was an exception returned from the remote server @@ -923,7 +933,7 @@ async def get_missing_events( self, destination: str, room_id: str, - earliest_events_ids: Sequence[str], + earliest_events_ids: Iterable[str], latest_events: Iterable[EventBase], limit: int, min_depth: int, @@ -974,7 +984,9 @@ async def get_missing_events( return signed_events - async def forward_third_party_invite(self, destinations, room_id, event_dict): + async def forward_third_party_invite( + self, destinations: Iterable[str], room_id: str, event_dict: JsonDict + ) -> None: for destination in destinations: if destination == self.server_name: continue @@ -983,7 +995,7 @@ async def forward_third_party_invite(self, destinations, room_id, event_dict): await self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) - return None + return except CodeMessageException: raise except Exception as e: @@ -995,7 +1007,7 @@ async def forward_third_party_invite(self, destinations, room_id, event_dict): async def get_room_complexity( self, destination: str, room_id: str - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """ Fetch the complexity of a remote room from another server. From b8f2522e9f813a3d0311b1e4b1450b1b73508e50 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 30 Dec 2020 11:23:36 -0500 Subject: [PATCH 2/4] Use async functions instead of returning awaitables. --- synapse/federation/federation_client.py | 45 +++++++++++++------------ 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 6851a98a5ad6..891f1d9919c4 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -119,14 +119,14 @@ def _clear_tried_cache(self): self.pdu_destination_tried[event_id] = destination_dict @log_function - def make_query( + async def make_query( self, destination: str, query_type: str, args: dict, retry_on_dns_fail: bool = False, ignore_backoff: bool = False, - ) -> Awaitable[JsonDict]: + ) -> JsonDict: """Sends a federation Query to a remote homeserver of the given type and arguments. @@ -140,12 +140,11 @@ def make_query( and try the request anyway. Returns: - An Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels(query_type).inc() - return self.transport_layer.make_query( + return await self.transport_layer.make_query( destination, query_type, args, @@ -154,9 +153,9 @@ def make_query( ) @log_function - def query_client_keys( + async def query_client_keys( self, destination: str, content: JsonDict, timeout: int - ) -> Awaitable[JsonDict]: + ) -> JsonDict: """Query device keys for a device hosted on a remote server. Args: @@ -164,26 +163,29 @@ def query_client_keys( content: The query content. Returns: - An Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels("client_device_keys").inc() - return self.transport_layer.query_client_keys(destination, content, timeout) + return await self.transport_layer.query_client_keys( + destination, content, timeout + ) @log_function - def query_user_devices( + async def query_user_devices( self, destination: str, user_id: str, timeout: int = 30000 - ) -> Awaitable[JsonDict]: + ) -> JsonDict: """Query the device keys for a list of user ids hosted on a remote server. """ sent_queries_counter.labels("user_devices").inc() - return self.transport_layer.query_user_devices(destination, user_id, timeout) + return await self.transport_layer.query_user_devices( + destination, user_id, timeout + ) @log_function - def claim_client_keys( + async def claim_client_keys( self, destination: str, content: JsonDict, timeout: int - ) -> Awaitable[JsonDict]: + ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. Args: @@ -191,11 +193,12 @@ def claim_client_keys( content: The query content. Returns: - An Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() - return self.transport_layer.claim_client_keys(destination, content, timeout) + return await self.transport_layer.claim_client_keys( + destination, content, timeout + ) async def backfill( self, dest: str, room_id: str, limit: int, extremities: Iterable[str] @@ -890,7 +893,7 @@ async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict: # content. return resp[1] - def get_public_rooms( + async def get_public_rooms( self, remote_server: str, limit: Optional[int] = None, @@ -898,7 +901,7 @@ def get_public_rooms( search_filter: Optional[Dict] = None, include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, - ) -> Awaitable[JsonDict]: + ) -> JsonDict: """Get the list of public rooms from a remote homeserver Args: @@ -920,7 +923,7 @@ def get_public_rooms( requests over federation """ - return self.transport_layer.get_public_rooms( + return await self.transport_layer.get_public_rooms( remote_server, limit, since_token, From 82230e163ddcf8f426df00ec600e6ab887eac28c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 30 Dec 2020 11:59:09 -0500 Subject: [PATCH 3/4] Remove some unnecessary variables. --- synapse/federation/federation_client.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 891f1d9919c4..d330ae5dbc57 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -718,14 +718,12 @@ async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_join_v2( + return await self.transport_layer.send_join_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -783,7 +781,7 @@ async def _do_send_invite( time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_invite_v2( + return await self.transport_layer.send_invite_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -793,7 +791,6 @@ async def _do_send_invite( "invite_room_state": pdu.unsigned.get("invite_room_state", []), }, ) - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -860,14 +857,12 @@ async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_leave_v2( + return await self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -1023,10 +1018,9 @@ async def get_room_complexity( could not fetch the complexity. """ try: - complexity = await self.transport_layer.get_room_complexity( + return await self.transport_layer.get_room_complexity( destination=destination, room_id=room_id ) - return complexity except CodeMessageException as e: # We didn't manage to get it -- probably a 404. We are okay if other # servers don't give it to us. From ec1310984d9ac2ad266ecdce46a2f8d86d55b916 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 Jan 2021 14:10:42 -0500 Subject: [PATCH 4/4] Newsfragment. --- changelog.d/9129.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/9129.misc diff --git a/changelog.d/9129.misc b/changelog.d/9129.misc new file mode 100644 index 000000000000..7800be3e7ed6 --- /dev/null +++ b/changelog.d/9129.misc @@ -0,0 +1 @@ +Various improvements to the federation client.