Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert inline callbacks to async/await #368

Merged
merged 12 commits into from
Jun 24, 2021
1 change: 1 addition & 0 deletions changelog.d/368.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert inlineCallbacks to async/await.
22 changes: 8 additions & 14 deletions sydent/hs_federation/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

import logging
import time
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import signedjson.key # type: ignore
import signedjson.sign # type: ignore
from signedjson.sign import SignatureVerifyException
from twisted.internet import defer
from twisted.web.server import Request
from unpaddedbase64 import decode_base64

Expand Down Expand Up @@ -63,8 +62,7 @@ def __init__(self, sydent: "Sydent") -> None:
# server_name: <result from keys query>,
}

@defer.inlineCallbacks
def _getKeysForServer(self, server_name: str) -> Generator:
async def _getKeysForServer(self, server_name: str):
"""Get the signing key data from a homeserver.

:param server_name: The name of the server to request the keys from.
Expand All @@ -81,7 +79,7 @@ def _getKeysForServer(self, server_name: str) -> Generator:
return self.cache[server_name]["verify_keys"]

client = FederationHttpClient(self.sydent)
result = yield client.get_json(
result = await client.get_json(
"matrix://%s/_matrix/key/v2/server/" % server_name, 1024 * 50
)

Expand All @@ -106,12 +104,11 @@ def _getKeysForServer(self, server_name: str) -> Generator:

return result["verify_keys"]

@defer.inlineCallbacks
def verifyServerSignedJson(
async def verifyServerSignedJson(
self,
signed_json: Dict[str, Any],
acceptable_server_names: Optional[List[str]] = None,
) -> Generator:
):
"""Given a signed json object, try to verify any one
of the signatures on it

Expand All @@ -137,7 +134,7 @@ def verifyServerSignedJson(
if server_name not in acceptable_server_names:
continue

server_keys = yield self._getKeysForServer(server_name)
server_keys = await self._getKeysForServer(server_name)
for key_name, sig in sigs.items():
if key_name in server_keys:
if "key" not in server_keys[key_name]:
Expand Down Expand Up @@ -167,10 +164,7 @@ def verifyServerSignedJson(
)
raise SignatureVerifyException("No matching signature found")

@defer.inlineCallbacks
def authenticate_request(
self, request: "Request", content: Optional[bytes]
) -> Generator:
async def authenticate_request(self, request: "Request", content: Optional[bytes]):
"""Authenticates a Matrix federation request based on the X-Matrix header
XXX: Copied largely from synapse

Expand Down Expand Up @@ -241,7 +235,7 @@ def strip_quotes(value):
"X-Matrix header's origin parameter must be a valid Matrix server name"
)

yield self.verifyServerSignedJson(json_request, [origin])
await self.verifyServerSignedJson(json_request, [origin])

logger.info("Verified request from HS %s", origin)

Expand Down
19 changes: 8 additions & 11 deletions sydent/http/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import json
import logging
from io import BytesIO
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from twisted.internet import defer
from twisted.web.client import Agent, FileBodyProducer
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
Expand All @@ -41,8 +40,7 @@ class HTTPClient:

agent: IAgent

@defer.inlineCallbacks
def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator:
async def get_json(self, uri: str, max_size: Optional[int] = None):
"""Make a GET request to an endpoint returning JSON and parse result

:param uri: The URI to make a GET request to.
Expand All @@ -56,11 +54,11 @@ def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator:
"""
logger.debug("HTTP GET %s", uri)

response = yield self.agent.request(
response = await self.agent.request(
b"GET",
uri.encode("utf8"),
)
body = yield read_body_with_max_size(response, max_size)
body = await read_body_with_max_size(response, max_size)
try:
# json.loads doesn't allow bytes in Python 3.5
json_body = json_decoder.decode(body.decode("UTF-8"))
Expand All @@ -69,10 +67,9 @@ def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator:
raise
return json_body

@defer.inlineCallbacks
def post_json_get_nothing(
async def post_json_get_nothing(
self, uri: str, post_json: Dict[Any, Any], opts: Dict[str, Any]
) -> Generator:
):
"""Make a POST request to an endpoint returning JSON and parse result

:param uri: The URI to make a POST request to.
Expand Down Expand Up @@ -102,7 +99,7 @@ def post_json_get_nothing(

logger.debug("HTTP POST %s -> %s", json_bytes, uri)

response = yield self.agent.request(
response = await self.agent.request(
b"POST",
uri.encode("utf8"),
headers,
Expand All @@ -114,7 +111,7 @@ def post_json_get_nothing(
# https://twistedmatrix.com/documents/current/web/howto/client.html
try:
# TODO Will this cause the server to think the request was a failure?
yield read_body_with_max_size(response, 0)
await read_body_with_max_size(response, 0)
except BodyExceededMaxSize:
pass

Expand Down
33 changes: 13 additions & 20 deletions sydent/http/matrixfederationagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import attr
from netaddr import IPAddress # type: ignore
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent
Expand Down Expand Up @@ -114,8 +113,7 @@ def __init__(
# `None`: there is no (valid) .well-known here
self._well_known_cache = _well_known_cache

@defer.inlineCallbacks
def request(
async def request(
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
self,
method: bytes,
uri: bytes,
Expand Down Expand Up @@ -146,7 +144,7 @@ def request(
:rtype: Deferred[twisted.web.iweb.IResponse]
"""
parsed_uri = URI.fromBytes(uri, defaultPort=-1)
res = yield self._route_matrix_uri(parsed_uri)
res = await self._route_matrix_uri(parsed_uri)

# set up the TLS connection params
#
Expand Down Expand Up @@ -183,13 +181,12 @@ def endpointForURI(_uri):
return ep

agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
res = yield agent.request(method, uri, headers, bodyProducer)
res = await agent.request(method, uri, headers, bodyProducer)
return res

@defer.inlineCallbacks
def _route_matrix_uri(
async def _route_matrix_uri(
self, parsed_uri: "URI", lookup_well_known: bool = True
) -> Generator:
):
"""Helper for `request`: determine the routing for a Matrix URI

:param parsed_uri: uri to route. Note that it should be parsed with
Expand Down Expand Up @@ -232,7 +229,7 @@ def _route_matrix_uri(

if lookup_well_known:
# try a .well-known lookup
well_known_server = yield self._get_well_known(parsed_uri.host)
well_known_server = await self._get_well_known(parsed_uri.host)

if well_known_server:
# if we found a .well-known, start again, but don't do another
Expand Down Expand Up @@ -263,14 +260,12 @@ def _route_matrix_uri(
fragment=parsed_uri.fragment,
)

res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
res = await self._route_matrix_uri(new_uri, lookup_well_known=False)
return res

# try a SRV lookup
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
server_list = yield defer.ensureDeferred(
self._srv_resolver.resolve_service(service_name)
)
server_list = await self._srv_resolver.resolve_service(service_name)

if not server_list:
target_host = parsed_uri.host
Expand All @@ -297,8 +292,7 @@ def _route_matrix_uri(
target_port=port,
)

@defer.inlineCallbacks
def _get_well_known(self, server_name: bytes) -> Generator:
async def _get_well_known(self, server_name: bytes):
"""Attempt to fetch and parse a .well-known file for the given server

:param server_name: Name of the server, from the requested url.
Expand All @@ -313,15 +307,14 @@ def _get_well_known(self, server_name: bytes) -> Generator:
except KeyError:
# TODO: should we linearise so that we don't end up doing two .well-known
# requests for the same server in parallel?
result, cache_period = yield self._do_get_well_known(server_name)
result, cache_period = await self._do_get_well_known(server_name)

if cache_period > 0:
self._well_known_cache.set(server_name, result, cache_period)

return result

@defer.inlineCallbacks
def _do_get_well_known(self, server_name: bytes) -> Generator:
async def _do_get_well_known(self, server_name: bytes):
"""Actually fetch and parse a .well-known, without checking the cache

:param server_name: Name of the server, from the requested url
Expand All @@ -337,8 +330,8 @@ def _do_get_well_known(self, server_name: bytes) -> Generator:
uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str)
try:
response = yield self._well_known_agent.request(b"GET", uri)
body = yield read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE)
response = await self._well_known_agent.request(b"GET", uri)
body = await read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE)
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code,))

Expand Down
16 changes: 9 additions & 7 deletions sydent/http/servlets/registerservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def render_POST(self, request: Request) -> Generator:
"error": "matrix_server_name must be a valid Matrix server name (IP address or hostname)",
}

result = yield self.client.get_json(
"matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s"
% (
matrix_server,
urllib.parse.quote(args["access_token"]),
),
1024 * 5,
result = yield defer.ensureDeferred(
self.client.get_json(
"matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s"
% (
matrix_server,
urllib.parse.quote(args["access_token"]),
),
1024 * 5,
)
)

if "sub" not in result:
Expand Down
7 changes: 3 additions & 4 deletions sydent/threepid/bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def addBinding(self, medium: str, address: str, mxid: str) -> Dict[str, Any]:
signer = Signer(self.sydent)
sgassoc = signer.signedThreePidAssociation(assoc)

self._notify(sgassoc, 0)
defer.ensureDeferred(self._notify(sgassoc, 0))

return sgassoc

Expand All @@ -126,8 +126,7 @@ def removeBinding(self, threepid: Dict[str, str], mxid: str) -> None:
localAssocStore.removeAssociation(threepid, mxid)
self.sydent.pusher.doLocalPush()

@defer.inlineCallbacks
def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator:
async def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator:
"""
Sends data about a new association (and, if necessary, the associated invites)
to the associated MXID's homeserver.
Expand Down Expand Up @@ -163,7 +162,7 @@ def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator:
# Make a POST to the chosen Synapse server
http_client = FederationHttpClient(self.sydent)
try:
response = yield http_client.post_json_get_nothing(post_url, assoc, {})
response = await http_client.post_json_get_nothing(post_url, assoc, {})
except Exception as e:
self._notifyErrback(assoc, attempt, e)
return
Expand Down
4 changes: 2 additions & 2 deletions tests/test_invites.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_delete_on_bind(self):
address = "john@example.com"

# Mock post_json_get_nothing so the /onBind call doesn't fail.
def post_json_get_nothing(uri, post_json, opts):
async def post_json_get_nothing(uri, post_json, opts):
return Response((b"HTTP", 1, 1), 200, b"OK", None, None)

FederationHttpClient.post_json_get_nothing = Mock(
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_no_delete_on_bind(self):
address = "john@example.com"

# Mock post_json_get_nothing so the /onBind call doesn't fail.
def post_json_get_nothing(uri, post_json, opts):
async def post_json_get_nothing(uri, post_json, opts):
return Response((b"HTTP", 1, 1), 200, b"OK", None, None)

FederationHttpClient.post_json_get_nothing = Mock(
Expand Down