Skip to content

Commit

Permalink
Convert sydent.http.srvresolver and associated modules to async/awai…
Browse files Browse the repository at this point in the history
…t. (#365)
  • Loading branch information
H-Shay authored Jun 22, 2021
1 parent 6626623 commit 712b3c6
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 44 deletions.
1 change: 1 addition & 0 deletions changelog.d/365.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert inlineCallbacks to async/await.
4 changes: 3 additions & 1 deletion sydent/http/matrixfederationagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,9 @@ def _route_matrix_uri(

# try a SRV lookup
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
server_list = yield self._srv_resolver.resolve_service(service_name)
server_list = yield defer.ensureDeferred(
self._srv_resolver.resolve_service(service_name)
)

if not server_list:
target_host = parsed_uri.host
Expand Down
8 changes: 3 additions & 5 deletions sydent/http/srvresolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
import logging
import random
import time
from typing import Callable, Dict, Generator, List, SupportsInt, Tuple
from typing import Callable, Dict, List, SupportsInt, Tuple

import attr
from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.internet.interfaces import IResolver
from twisted.names import client, dns
Expand Down Expand Up @@ -107,8 +106,7 @@ def __init__(
self._cache = cache
self._get_time = get_time

@defer.inlineCallbacks
def resolve_service(self, service_name: bytes) -> "Generator":
async def resolve_service(self, service_name: bytes) -> List["Server"]:
"""Look up a SRV record
:param service_name: The record to look up.
Expand All @@ -129,7 +127,7 @@ def resolve_service(self, service_name: bytes) -> "Generator":
return servers

try:
answers, _, _ = yield self._dns_client.lookupService(service_name)
answers, _, _ = await self._dns_client.lookupService(service_name)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
Expand Down
72 changes: 34 additions & 38 deletions tests/test_blacklisting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from unittest.mock import patch

from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import StringTransport
from twisted.trial.unittest import TestCase
from twisted.web.client import Agent

from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
from sydent.http.srvresolver import Server
from tests.utils import make_request, make_sydent
from tests.utils import AsyncMock, make_request, make_sydent


class BlacklistingAgentTest(TestCase):
Expand Down Expand Up @@ -92,7 +90,9 @@ def test_reactor(self):
_bindAddress,
) = self.reactor.tcpClients.pop()

@patch("sydent.http.srvresolver.SrvResolver.resolve_service")
@patch(
"sydent.http.srvresolver.SrvResolver.resolve_service", new_callable=AsyncMock
)
def test_federation_client_allowed_ip(self, resolver):
self.sydent.run()

Expand All @@ -108,17 +108,15 @@ def test_federation_client_allowed_ip(self, resolver):
},
)

resolver.return_value = defer.succeed(
[
Server(
host=self.allowed_domain,
port=443,
priority=1,
weight=1,
expires=100,
)
]
)
resolver.return_value = [
Server(
host=self.allowed_domain,
port=443,
priority=1,
weight=1,
expires=100,
)
]

request.render(self.sydent.servlets.registerServlet)

Expand All @@ -144,7 +142,9 @@ def test_federation_client_allowed_ip(self, resolver):

self.assertEqual(channel.code, 200)

@patch("sydent.http.srvresolver.SrvResolver.resolve_service")
@patch(
"sydent.http.srvresolver.SrvResolver.resolve_service", new_callable=AsyncMock
)
def test_federation_client_safe_ip(self, resolver):
self.sydent.run()

Expand All @@ -160,17 +160,15 @@ def test_federation_client_safe_ip(self, resolver):
},
)

resolver.return_value = defer.succeed(
[
Server(
host=self.safe_domain,
port=443,
priority=1,
weight=1,
expires=100,
)
]
)
resolver.return_value = [
Server(
host=self.safe_domain,
port=443,
priority=1,
weight=1,
expires=100,
)
]

request.render(self.sydent.servlets.registerServlet)

Expand Down Expand Up @@ -210,17 +208,15 @@ def test_federation_client_unsafe_ip(self, resolver):
},
)

resolver.return_value = defer.succeed(
[
Server(
host=self.unsafe_domain,
port=443,
priority=1,
weight=1,
expires=100,
)
]
)
resolver.return_value = [
Server(
host=self.unsafe_domain,
port=443,
priority=1,
weight=1,
expires=100,
)
]

request.render(self.sydent.servlets.registerServlet)

Expand Down
6 changes: 6 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from io import BytesIO
from typing import Dict
from unittest.mock import MagicMock

import attr
import twisted.logger
Expand Down Expand Up @@ -298,3 +299,8 @@ def getHostByName(self, name, timeout=None):

def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()


class AsyncMock(MagicMock):
async def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)

0 comments on commit 712b3c6

Please sign in to comment.