diff --git a/changelog.d/368.misc b/changelog.d/368.misc new file mode 100644 index 00000000..ce3f6734 --- /dev/null +++ b/changelog.d/368.misc @@ -0,0 +1 @@ +Convert inlineCallbacks to async/await. \ No newline at end of file diff --git a/sydent/hs_federation/verifier.py b/sydent/hs_federation/verifier.py index 47ee6fde..ead33826 100644 --- a/sydent/hs_federation/verifier.py +++ b/sydent/hs_federation/verifier.py @@ -14,12 +14,11 @@ import logging import time -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import signedjson.key # type: ignore import signedjson.sign # type: ignore from signedjson.sign import SignatureVerifyException -from twisted.internet import defer from twisted.web.server import Request from unpaddedbase64 import decode_base64 @@ -63,8 +62,7 @@ def __init__(self, sydent: "Sydent") -> None: # server_name: , } - @defer.inlineCallbacks - def _getKeysForServer(self, server_name: str) -> Generator: + async def _getKeysForServer(self, server_name: str): """Get the signing key data from a homeserver. :param server_name: The name of the server to request the keys from. @@ -81,7 +79,7 @@ def _getKeysForServer(self, server_name: str) -> Generator: return self.cache[server_name]["verify_keys"] client = FederationHttpClient(self.sydent) - result = yield client.get_json( + result = await client.get_json( "matrix://%s/_matrix/key/v2/server/" % server_name, 1024 * 50 ) @@ -106,12 +104,11 @@ def _getKeysForServer(self, server_name: str) -> Generator: return result["verify_keys"] - @defer.inlineCallbacks - def verifyServerSignedJson( + async def verifyServerSignedJson( self, signed_json: Dict[str, Any], acceptable_server_names: Optional[List[str]] = None, - ) -> Generator: + ): """Given a signed json object, try to verify any one of the signatures on it @@ -137,7 +134,7 @@ def verifyServerSignedJson( if server_name not in acceptable_server_names: continue - server_keys = yield self._getKeysForServer(server_name) + server_keys = await self._getKeysForServer(server_name) for key_name, sig in sigs.items(): if key_name in server_keys: if "key" not in server_keys[key_name]: @@ -167,10 +164,7 @@ def verifyServerSignedJson( ) raise SignatureVerifyException("No matching signature found") - @defer.inlineCallbacks - def authenticate_request( - self, request: "Request", content: Optional[bytes] - ) -> Generator: + async def authenticate_request(self, request: "Request", content: Optional[bytes]): """Authenticates a Matrix federation request based on the X-Matrix header XXX: Copied largely from synapse @@ -241,7 +235,7 @@ def strip_quotes(value): "X-Matrix header's origin parameter must be a valid Matrix server name" ) - yield self.verifyServerSignedJson(json_request, [origin]) + await self.verifyServerSignedJson(json_request, [origin]) logger.info("Verified request from HS %s", origin) diff --git a/sydent/http/httpclient.py b/sydent/http/httpclient.py index d69ed5e2..dfff5f77 100644 --- a/sydent/http/httpclient.py +++ b/sydent/http/httpclient.py @@ -15,9 +15,8 @@ import json import logging from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, Generator, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional -from twisted.internet import defer from twisted.web.client import Agent, FileBodyProducer from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent @@ -41,8 +40,7 @@ class HTTPClient: agent: IAgent - @defer.inlineCallbacks - def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator: + async def get_json(self, uri: str, max_size: Optional[int] = None): """Make a GET request to an endpoint returning JSON and parse result :param uri: The URI to make a GET request to. @@ -56,11 +54,11 @@ def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator: """ logger.debug("HTTP GET %s", uri) - response = yield self.agent.request( + response = await self.agent.request( b"GET", uri.encode("utf8"), ) - body = yield read_body_with_max_size(response, max_size) + body = await read_body_with_max_size(response, max_size) try: # json.loads doesn't allow bytes in Python 3.5 json_body = json_decoder.decode(body.decode("UTF-8")) @@ -69,10 +67,9 @@ def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator: raise return json_body - @defer.inlineCallbacks - def post_json_get_nothing( + async def post_json_get_nothing( self, uri: str, post_json: Dict[Any, Any], opts: Dict[str, Any] - ) -> Generator: + ): """Make a POST request to an endpoint returning JSON and parse result :param uri: The URI to make a POST request to. @@ -102,7 +99,7 @@ def post_json_get_nothing( logger.debug("HTTP POST %s -> %s", json_bytes, uri) - response = yield self.agent.request( + response = await self.agent.request( b"POST", uri.encode("utf8"), headers, @@ -114,7 +111,7 @@ def post_json_get_nothing( # https://twistedmatrix.com/documents/current/web/howto/client.html try: # TODO Will this cause the server to think the request was a failure? - yield read_body_with_max_size(response, 0) + await read_body_with_max_size(response, 0) except BodyExceededMaxSize: pass diff --git a/sydent/http/matrixfederationagent.py b/sydent/http/matrixfederationagent.py index 114b0e90..f9bce0a9 100644 --- a/sydent/http/matrixfederationagent.py +++ b/sydent/http/matrixfederationagent.py @@ -146,7 +146,7 @@ def request( :rtype: Deferred[twisted.web.iweb.IResponse] """ parsed_uri = URI.fromBytes(uri, defaultPort=-1) - res = yield self._route_matrix_uri(parsed_uri) + res = yield defer.ensureDeferred(self._route_matrix_uri(parsed_uri)) # set up the TLS connection params # @@ -186,10 +186,9 @@ def endpointForURI(_uri): res = yield agent.request(method, uri, headers, bodyProducer) return res - @defer.inlineCallbacks - def _route_matrix_uri( + async def _route_matrix_uri( self, parsed_uri: "URI", lookup_well_known: bool = True - ) -> Generator: + ): """Helper for `request`: determine the routing for a Matrix URI :param parsed_uri: uri to route. Note that it should be parsed with @@ -232,7 +231,7 @@ def _route_matrix_uri( if lookup_well_known: # try a .well-known lookup - well_known_server = yield self._get_well_known(parsed_uri.host) + well_known_server = await self._get_well_known(parsed_uri.host) if well_known_server: # if we found a .well-known, start again, but don't do another @@ -263,14 +262,12 @@ def _route_matrix_uri( fragment=parsed_uri.fragment, ) - res = yield self._route_matrix_uri(new_uri, lookup_well_known=False) + res = await self._route_matrix_uri(new_uri, lookup_well_known=False) return res # try a SRV lookup service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) - server_list = yield defer.ensureDeferred( - self._srv_resolver.resolve_service(service_name) - ) + server_list = await self._srv_resolver.resolve_service(service_name) if not server_list: target_host = parsed_uri.host @@ -297,8 +294,7 @@ def _route_matrix_uri( target_port=port, ) - @defer.inlineCallbacks - def _get_well_known(self, server_name: bytes) -> Generator: + async def _get_well_known(self, server_name: bytes): """Attempt to fetch and parse a .well-known file for the given server :param server_name: Name of the server, from the requested url. @@ -313,15 +309,14 @@ def _get_well_known(self, server_name: bytes) -> Generator: except KeyError: # TODO: should we linearise so that we don't end up doing two .well-known # requests for the same server in parallel? - result, cache_period = yield self._do_get_well_known(server_name) + result, cache_period = await self._do_get_well_known(server_name) if cache_period > 0: self._well_known_cache.set(server_name, result, cache_period) return result - @defer.inlineCallbacks - def _do_get_well_known(self, server_name: bytes) -> Generator: + async def _do_get_well_known(self, server_name: bytes): """Actually fetch and parse a .well-known, without checking the cache :param server_name: Name of the server, from the requested url @@ -337,8 +332,8 @@ def _do_get_well_known(self, server_name: bytes) -> Generator: uri_str = uri.decode("ascii") logger.info("Fetching %s", uri_str) try: - response = yield self._well_known_agent.request(b"GET", uri) - body = yield read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE) + response = await self._well_known_agent.request(b"GET", uri) + body = await read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE) if response.code != 200: raise Exception("Non-200 response %s" % (response.code,)) diff --git a/sydent/http/servlets/registerservlet.py b/sydent/http/servlets/registerservlet.py index ecac8d10..c0e92578 100644 --- a/sydent/http/servlets/registerservlet.py +++ b/sydent/http/servlets/registerservlet.py @@ -57,13 +57,15 @@ def render_POST(self, request: Request) -> Generator: "error": "matrix_server_name must be a valid Matrix server name (IP address or hostname)", } - result = yield self.client.get_json( - "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s" - % ( - matrix_server, - urllib.parse.quote(args["access_token"]), - ), - 1024 * 5, + result = yield defer.ensureDeferred( + self.client.get_json( + "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s" + % ( + matrix_server, + urllib.parse.quote(args["access_token"]), + ), + 1024 * 5, + ) ) if "sub" not in result: diff --git a/sydent/threepid/bind.py b/sydent/threepid/bind.py index 7172536c..a8409116 100644 --- a/sydent/threepid/bind.py +++ b/sydent/threepid/bind.py @@ -109,7 +109,7 @@ def addBinding(self, medium: str, address: str, mxid: str) -> Dict[str, Any]: signer = Signer(self.sydent) sgassoc = signer.signedThreePidAssociation(assoc) - self._notify(sgassoc, 0) + defer.ensureDeferred(self._notify(sgassoc, 0)) return sgassoc @@ -126,8 +126,7 @@ def removeBinding(self, threepid: Dict[str, str], mxid: str) -> None: localAssocStore.removeAssociation(threepid, mxid) self.sydent.pusher.doLocalPush() - @defer.inlineCallbacks - def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator: + async def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator: """ Sends data about a new association (and, if necessary, the associated invites) to the associated MXID's homeserver. @@ -163,7 +162,7 @@ def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator: # Make a POST to the chosen Synapse server http_client = FederationHttpClient(self.sydent) try: - response = yield http_client.post_json_get_nothing(post_url, assoc, {}) + response = await http_client.post_json_get_nothing(post_url, assoc, {}) except Exception as e: self._notifyErrback(assoc, attempt, e) return diff --git a/tests/test_invites.py b/tests/test_invites.py index 1d79d960..bca426f2 100644 --- a/tests/test_invites.py +++ b/tests/test_invites.py @@ -34,7 +34,7 @@ def test_delete_on_bind(self): address = "john@example.com" # Mock post_json_get_nothing so the /onBind call doesn't fail. - def post_json_get_nothing(uri, post_json, opts): + async def post_json_get_nothing(uri, post_json, opts): return Response((b"HTTP", 1, 1), 200, b"OK", None, None) FederationHttpClient.post_json_get_nothing = Mock( @@ -119,7 +119,7 @@ def test_no_delete_on_bind(self): address = "john@example.com" # Mock post_json_get_nothing so the /onBind call doesn't fail. - def post_json_get_nothing(uri, post_json, opts): + async def post_json_get_nothing(uri, post_json, opts): return Response((b"HTTP", 1, 1), 200, b"OK", None, None) FederationHttpClient.post_json_get_nothing = Mock(