Skip to content

Commit

Permalink
Convert inlineCallbacks to async/await. (#368)
Browse files Browse the repository at this point in the history
  • Loading branch information
H-Shay authored Jun 24, 2021
1 parent c210458 commit f5a5bbf
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 54 deletions.
1 change: 1 addition & 0 deletions changelog.d/368.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert inlineCallbacks to async/await.
22 changes: 8 additions & 14 deletions sydent/hs_federation/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

import logging
import time
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import signedjson.key # type: ignore
import signedjson.sign # type: ignore
from signedjson.sign import SignatureVerifyException
from twisted.internet import defer
from twisted.web.server import Request
from unpaddedbase64 import decode_base64

Expand Down Expand Up @@ -63,8 +62,7 @@ def __init__(self, sydent: "Sydent") -> None:
# server_name: <result from keys query>,
}

@defer.inlineCallbacks
def _getKeysForServer(self, server_name: str) -> Generator:
async def _getKeysForServer(self, server_name: str):
"""Get the signing key data from a homeserver.
:param server_name: The name of the server to request the keys from.
Expand All @@ -80,7 +78,7 @@ def _getKeysForServer(self, server_name: str) -> Generator:
return self.cache[server_name]["verify_keys"]

client = FederationHttpClient(self.sydent)
result = yield client.get_json(
result = await client.get_json(
"matrix://%s/_matrix/key/v2/server/" % server_name, 1024 * 50
)

Expand All @@ -105,12 +103,11 @@ def _getKeysForServer(self, server_name: str) -> Generator:

return result["verify_keys"]

@defer.inlineCallbacks
def verifyServerSignedJson(
async def verifyServerSignedJson(
self,
signed_json: Dict[str, Any],
acceptable_server_names: Optional[List[str]] = None,
) -> Generator:
):
"""Given a signed json object, try to verify any one
of the signatures on it
Expand All @@ -135,7 +132,7 @@ def verifyServerSignedJson(
if server_name not in acceptable_server_names:
continue

server_keys = yield self._getKeysForServer(server_name)
server_keys = await self._getKeysForServer(server_name)
for key_name, sig in sigs.items():
if key_name in server_keys:
if "key" not in server_keys[key_name]:
Expand Down Expand Up @@ -165,10 +162,7 @@ def verifyServerSignedJson(
)
raise SignatureVerifyException("No matching signature found")

@defer.inlineCallbacks
def authenticate_request(
self, request: "Request", content: Optional[bytes]
) -> Generator:
async def authenticate_request(self, request: "Request", content: Optional[bytes]):
"""Authenticates a Matrix federation request based on the X-Matrix header
XXX: Copied largely from synapse
Expand Down Expand Up @@ -235,7 +229,7 @@ def strip_quotes(value):
"X-Matrix header's origin parameter must be a valid Matrix server name"
)

yield self.verifyServerSignedJson(json_request, [origin])
await self.verifyServerSignedJson(json_request, [origin])

logger.info("Verified request from HS %s", origin)

Expand Down
19 changes: 8 additions & 11 deletions sydent/http/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import json
import logging
from io import BytesIO
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from twisted.internet import defer
from twisted.web.client import Agent, FileBodyProducer
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
Expand All @@ -41,8 +40,7 @@ class HTTPClient:

agent: IAgent

@defer.inlineCallbacks
def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator:
async def get_json(self, uri: str, max_size: Optional[int] = None):
"""Make a GET request to an endpoint returning JSON and parse result
:param uri: The URI to make a GET request to.
Expand All @@ -54,11 +52,11 @@ def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator:
"""
logger.debug("HTTP GET %s", uri)

response = yield self.agent.request(
response = await self.agent.request(
b"GET",
uri.encode("utf8"),
)
body = yield read_body_with_max_size(response, max_size)
body = await read_body_with_max_size(response, max_size)
try:
# json.loads doesn't allow bytes in Python 3.5
json_body = json_decoder.decode(body.decode("UTF-8"))
Expand All @@ -67,10 +65,9 @@ def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator:
raise
return json_body

@defer.inlineCallbacks
def post_json_get_nothing(
async def post_json_get_nothing(
self, uri: str, post_json: Dict[Any, Any], opts: Dict[str, Any]
) -> Generator:
):
"""Make a POST request to an endpoint returning JSON and parse result
:param uri: The URI to make a POST request to.
Expand All @@ -97,7 +94,7 @@ def post_json_get_nothing(

logger.debug("HTTP POST %s -> %s", json_bytes, uri)

response = yield self.agent.request(
response = await self.agent.request(
b"POST",
uri.encode("utf8"),
headers,
Expand All @@ -109,7 +106,7 @@ def post_json_get_nothing(
# https://twistedmatrix.com/documents/current/web/howto/client.html
try:
# TODO Will this cause the server to think the request was a failure?
yield read_body_with_max_size(response, 0)
await read_body_with_max_size(response, 0)
except BodyExceededMaxSize:
pass

Expand Down
27 changes: 11 additions & 16 deletions sydent/http/matrixfederationagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def request(
:rtype: Deferred[twisted.web.iweb.IResponse]
"""
parsed_uri = URI.fromBytes(uri, defaultPort=-1)
res = yield self._route_matrix_uri(parsed_uri)
res = yield defer.ensureDeferred(self._route_matrix_uri(parsed_uri))

# set up the TLS connection params
#
Expand Down Expand Up @@ -182,10 +182,9 @@ def endpointForURI(_uri):
res = yield agent.request(method, uri, headers, bodyProducer)
return res

@defer.inlineCallbacks
def _route_matrix_uri(
async def _route_matrix_uri(
self, parsed_uri: "URI", lookup_well_known: bool = True
) -> Generator:
):
"""Helper for `request`: determine the routing for a Matrix URI
:param parsed_uri: uri to route. Note that it should be parsed with
Expand Down Expand Up @@ -226,7 +225,7 @@ def _route_matrix_uri(

if lookup_well_known:
# try a .well-known lookup
well_known_server = yield self._get_well_known(parsed_uri.host)
well_known_server = await self._get_well_known(parsed_uri.host)

if well_known_server:
# if we found a .well-known, start again, but don't do another
Expand Down Expand Up @@ -257,14 +256,12 @@ def _route_matrix_uri(
fragment=parsed_uri.fragment,
)

res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
res = await self._route_matrix_uri(new_uri, lookup_well_known=False)
return res

# try a SRV lookup
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
server_list = yield defer.ensureDeferred(
self._srv_resolver.resolve_service(service_name)
)
server_list = await self._srv_resolver.resolve_service(service_name)

if not server_list:
target_host = parsed_uri.host
Expand All @@ -291,8 +288,7 @@ def _route_matrix_uri(
target_port=port,
)

@defer.inlineCallbacks
def _get_well_known(self, server_name: bytes) -> Generator:
async def _get_well_known(self, server_name: bytes):
"""Attempt to fetch and parse a .well-known file for the given server
:param server_name: Name of the server, from the requested url.
Expand All @@ -306,15 +302,14 @@ def _get_well_known(self, server_name: bytes) -> Generator:
except KeyError:
# TODO: should we linearise so that we don't end up doing two .well-known
# requests for the same server in parallel?
result, cache_period = yield self._do_get_well_known(server_name)
result, cache_period = await self._do_get_well_known(server_name)

if cache_period > 0:
self._well_known_cache.set(server_name, result, cache_period)

return result

@defer.inlineCallbacks
def _do_get_well_known(self, server_name: bytes) -> Generator:
async def _do_get_well_known(self, server_name: bytes):
"""Actually fetch and parse a .well-known, without checking the cache
:param server_name: Name of the server, from the requested url
Expand All @@ -329,8 +324,8 @@ def _do_get_well_known(self, server_name: bytes) -> Generator:
uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str)
try:
response = yield self._well_known_agent.request(b"GET", uri)
body = yield read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE)
response = await self._well_known_agent.request(b"GET", uri)
body = await read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE)
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code,))

Expand Down
16 changes: 9 additions & 7 deletions sydent/http/servlets/registerservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def render_POST(self, request: Request) -> Generator:
"error": "matrix_server_name must be a valid Matrix server name (IP address or hostname)",
}

result = yield self.client.get_json(
"matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s"
% (
matrix_server,
urllib.parse.quote(args["access_token"]),
),
1024 * 5,
result = yield defer.ensureDeferred(
self.client.get_json(
"matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s"
% (
matrix_server,
urllib.parse.quote(args["access_token"]),
),
1024 * 5,
)
)

if "sub" not in result:
Expand Down
7 changes: 3 additions & 4 deletions sydent/threepid/bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def addBinding(self, medium: str, address: str, mxid: str) -> Dict[str, Any]:
signer = Signer(self.sydent)
sgassoc = signer.signedThreePidAssociation(assoc)

self._notify(sgassoc, 0)
defer.ensureDeferred(self._notify(sgassoc, 0))

return sgassoc

Expand All @@ -120,8 +120,7 @@ def removeBinding(self, threepid: Dict[str, str], mxid: str) -> None:
localAssocStore.removeAssociation(threepid, mxid)
self.sydent.pusher.doLocalPush()

@defer.inlineCallbacks
def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator:
async def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator:
"""
Sends data about a new association (and, if necessary, the associated invites)
to the associated MXID's homeserver.
Expand Down Expand Up @@ -155,7 +154,7 @@ def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator:
# Make a POST to the chosen Synapse server
http_client = FederationHttpClient(self.sydent)
try:
response = yield http_client.post_json_get_nothing(post_url, assoc, {})
response = await http_client.post_json_get_nothing(post_url, assoc, {})
except Exception as e:
self._notifyErrback(assoc, attempt, e)
return
Expand Down
4 changes: 2 additions & 2 deletions tests/test_invites.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_delete_on_bind(self):
address = "john@example.com"

# Mock post_json_get_nothing so the /onBind call doesn't fail.
def post_json_get_nothing(uri, post_json, opts):
async def post_json_get_nothing(uri, post_json, opts):
return Response((b"HTTP", 1, 1), 200, b"OK", None, None)

FederationHttpClient.post_json_get_nothing = Mock(
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_no_delete_on_bind(self):
address = "john@example.com"

# Mock post_json_get_nothing so the /onBind call doesn't fail.
def post_json_get_nothing(uri, post_json, opts):
async def post_json_get_nothing(uri, post_json, opts):
return Response((b"HTTP", 1, 1), 200, b"OK", None, None)

FederationHttpClient.post_json_get_nothing = Mock(
Expand Down

0 comments on commit f5a5bbf

Please sign in to comment.