diff --git a/sydent/hs_federation/verifier.py b/sydent/hs_federation/verifier.py index a708d407..48b92c63 100644 --- a/sydent/hs_federation/verifier.py +++ b/sydent/hs_federation/verifier.py @@ -25,7 +25,7 @@ from signedjson.sign import SignatureVerifyException from sydent.http.httpclient import FederationHttpClient -from sydent.util.stringutils import is_valid_hostname +from sydent.util.stringutils import is_valid_matrix_server_name logger = logging.getLogger(__name__) @@ -205,8 +205,8 @@ def strip_quotes(value): if not json_request["signatures"]: raise NoAuthenticationError("Missing X-Matrix Authorization header") - if not is_valid_hostname(json_request["origin"]): - raise InvalidServerName("X-Matrix header's origin parameter must be a valid hostname") + if not is_valid_matrix_server_name(json_request["origin"]): + raise InvalidServerName("X-Matrix header's origin parameter must be a valid Matrix server name") yield self.verifyServerSignedJson(json_request, [origin]) diff --git a/sydent/http/servlets/registerservlet.py b/sydent/http/servlets/registerservlet.py index 7267efc2..9950b671 100644 --- a/sydent/http/servlets/registerservlet.py +++ b/sydent/http/servlets/registerservlet.py @@ -25,7 +25,7 @@ from sydent.http.servlets import get_args, jsonwrap, deferjsonwrap, send_cors from sydent.http.httpclient import FederationHttpClient from sydent.users.tokens import issueToken -from sydent.util.stringutils import is_valid_hostname +from sydent.util.stringutils import is_valid_matrix_server_name logger = logging.getLogger(__name__) @@ -49,11 +49,11 @@ def render_POST(self, request): matrix_server = args['matrix_server_name'].lower() - if not is_valid_hostname(matrix_server): + if not is_valid_matrix_server_name(matrix_server): request.setResponseCode(400) return { 'errcode': 'M_INVALID_PARAM', - 'error': 'matrix_server_name must be a valid hostname' + 'error': 'matrix_server_name must be a valid Matrix server name (IP address or hostname)' } result = yield 
self.client.get_json( @@ -89,7 +89,7 @@ def render_POST(self, request): user_id_server = user_id_components[1] - if not is_valid_hostname(user_id_server): + if not is_valid_matrix_server_name(user_id_server): request.setResponseCode(500) return { 'errcode': 'M_UNKNOWN', diff --git a/sydent/threepid/bind.py b/sydent/threepid/bind.py index 84160bfe..6326c666 100644 --- a/sydent/threepid/bind.py +++ b/sydent/threepid/bind.py @@ -32,7 +32,7 @@ from sydent.threepid import ThreepidAssociation -from sydent.util.stringutils import is_valid_hostname +from sydent.util.stringutils import is_valid_matrix_server_name from twisted.internet import defer @@ -143,9 +143,9 @@ def _notify(self, assoc, attempt): matrix_server = mxid_parts[1] - if not is_valid_hostname(matrix_server): + if not is_valid_matrix_server_name(matrix_server): logger.error( - "MXID server part '%s' not a valid hostname. Not retrying.", + "MXID server part '%s' not a valid Matrix server name. Not retrying.", matrix_server, ) return @@ -184,7 +184,7 @@ def _notify(self, assoc, attempt): "Successfully deleted invite for %s from the store", assoc["address"], ) - except Exception as e: + except Exception: logger.exception( "Couldn't remove invite for %s from the store", assoc["address"], diff --git a/sydent/util/stringutils.py b/sydent/util/stringutils.py index 3f1a91e8..fdf24009 100644 --- a/sydent/util/stringutils.py +++ b/sydent/util/stringutils.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
    """Split a server name into host/port parts.

    No validation is done on the host part. The port part, when present, is
    validated to be a canonical decimal port number in the range 1-65535.

    Args:
        server_name: server name to parse

    Returns:
        A (host, port) tuple. The port is an int, or None when the server
        name carries no port. A name ending in "]" is returned whole as the
        host, on the assumption that it is a bare IPv6 literal.

    Raises:
        ValueError if the server name could not be parsed (this includes an
        empty or non-string input) or if the port is invalid.
    """
    try:
        if server_name[-1] == "]":
            # ipv6 literal without a port, hopefully: every colon belongs to
            # the address, so there is nothing to split off.
            return server_name, None

        host, colon, port = server_name.rpartition(":")
        if not colon:
            # No port separator at all: the whole string is the host.
            return server_name, None

        # int("") raises, so a dangling colon ("example.com:") is rejected
        # rather than silently treated as "no port".
        port_num = int(port)

        # exclude things like '08090' or ' 8090'
        if port != str(port_num) or not (1 <= port_num < 65536):
            raise ValueError("Invalid port")

        # Return the parsed integer so the result matches the annotation.
        return host, port_num
    except Exception as exc:
        # Any failure (bad port, empty string, wrong type) is reported
        # uniformly so callers only need to catch ValueError.
        raise ValueError("Invalid server name '%s'" % (server_name,)) from exc


def is_valid_matrix_server_name(string: str) -> bool:
    """Validate that the given string is a valid Matrix server name.

    A string is a valid Matrix server name if it is one of the following, plus
    an optional port:

    a. IPv4 address
    b. IPv6 literal (`[IPV6_ADDRESS]`)
    c. A valid hostname

    :param string: The string to validate
    :type string: str

    :return: Whether the input is a valid Matrix server name
    :rtype: bool
    """

    try:
        host, _port = parse_server_name(string)
    except ValueError:
        return False

    if not host:
        # e.g. ":8448" parses to an empty host, which can never be valid.
        return False

    valid_ipv4_addr = isIPAddress(host)
    # startswith/endswith instead of indexing: never raises on short input.
    valid_ipv6_literal = (
        host.startswith("[")
        and host.endswith("]")
        and isIPv6Address(host[1:-1])
    )

    return valid_ipv4_addr or valid_ipv6_literal or is_valid_hostname(host)
class UtilTests(unittest.TestCase):
    """Tests Sydent utility functions."""

    def test_is_valid_matrix_server_name(self):
        """Tests that the is_valid_matrix_server_name function accepts only
        IPv4 addresses, bracketed IPv6 literals and hostnames (or domain
        names), each with an optional port number.
        """
        valid_names = (
            # IP addresses and IPv6 literals, with and without a port
            "9.9.9.9",
            "9.9.9.9:4242",
            "[::]",
            "[::]:4242",
            "[a:b:c::]:4242",
            # hostnames in any letter case, with and without a port
            "example.com",
            "EXAMPLE.COM",
            "ExAmPlE.CoM",
            "example.com:4242",
            "localhost",
            "localhost:9000",
            "a.b.c.d:1234",
        )
        invalid_names = (
            # malformed or unbracketed IPv6
            "[:::]",
            "a:b:c::",
            # out-of-range, non-numeric or non-canonical ports
            "example.com:65536",
            "example.com:0",
            "example.com:-1",
            "example.com:a",
            "example.com: ",
            "example.com:04242",
            "example.com: 4242",
            # characters that are not allowed in a hostname
            "example.com/example.com",
            "example.com#example.com",
        )

        # The message names the offending input, since a loop hides the
        # line number that would otherwise identify the failing case.
        for name in valid_names:
            self.assertTrue(
                is_valid_matrix_server_name(name),
                "%r should be a valid Matrix server name" % (name,),
            )

        for name in invalid_names:
            self.assertFalse(
                is_valid_matrix_server_name(name),
                "%r should not be a valid Matrix server name" % (name,),
            )