diff --git a/CHANGELOG.md b/CHANGELOG.md index 39e11d4b..38b1d8b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,16 +4,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - +--> + ## [3.0.1] - 2024-09-07 ### Fixed diff --git a/docs/settings.rst b/docs/settings.rst index 0b76129f..545736cc 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -63,6 +63,37 @@ assigned ports. Note that you may override ``Application.get_allowed_schemes()`` to set this on a per-application basis. +ALLOW_URI_WILDCARDS +~~~~~~~~~~~~~~~~~~~ + +Default: ``False`` + +SECURITY WARNING: Enabling this setting can introduce security vulnerabilities. Only enable +this setting if you understand the risks. https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.2 +states "The redirection endpoint URI MUST be an absolute URI as defined by [RFC3986] Section 4.3." The +intent of the URI restrictions is to prevent open redirects and phishing attacks. If you do enable this +ensure that the wildcards restrict URIs to resources under your control. You are strongly encouragd not +to use this feature in production. + +When set to ``True``, the server will allow wildcard characters in the domains for allowed_origins and +redirect_uris. + +``*`` is the only wildcard character allowed. + +``*`` can only be used as a prefix to a domain, must be the first character in +the domain, and cannot be in the top or second level domain. Matching is done using an +endsWith check. + +For example, +``https://*.example.com`` is allowed, +``https://*-myproject.example.com`` is allowed, +``https://*.sub.example.com`` is not allowed, +``https://*.com`` is not allowed, and +``https://example.*.com`` is not allowed. + +This feature is useful for working with CI service such as cloudflare, netlify, and vercel that offer branch +deployments for development previews and user acceptance testing. + ALLOWED_SCHEMES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 621ce5b3..0467ddfa 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -213,7 +213,11 @@ def clean(self): if redirect_uris: validator = AllowedURIValidator( - allowed_schemes, name="redirect uri", allow_path=True, allow_query=True + allowed_schemes, + name="redirect uri", + allow_path=True, + allow_query=True, + allow_hostname_wildcard=oauth2_settings.ALLOW_URI_WILDCARDS, ) for uri in redirect_uris: validator(uri) @@ -227,7 +231,11 @@ def clean(self): allowed_origins = self.allowed_origins.strip().split() if allowed_origins: # oauthlib allows only https scheme for CORS - validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "allowed origin") + validator = AllowedURIValidator( + oauth2_settings.ALLOWED_SCHEMES, + "allowed origin", + allow_hostname_wildcard=oauth2_settings.ALLOW_URI_WILDCARDS, + ) for uri in allowed_origins: validator(uri) @@ -777,12 +785,28 @@ def redirect_to_uri_allowed(uri, allowed_uris): :param allowed_uris: A list of URIs that are allowed """ + if not isinstance(allowed_uris, list): + raise ValueError("allowed_uris must be a list") + parsed_uri = urlparse(uri) uqs_set = set(parse_qsl(parsed_uri.query)) for allowed_uri in allowed_uris: parsed_allowed_uri = urlparse(allowed_uri) + if parsed_allowed_uri.scheme != parsed_uri.scheme: + # match failed, continue + continue + + """ check hostname """ + if oauth2_settings.ALLOW_URI_WILDCARDS and parsed_allowed_uri.hostname.startswith("*"): + """ wildcard hostname """ + if not parsed_uri.hostname.endswith(parsed_allowed_uri.hostname[1:]): + continue + elif parsed_allowed_uri.hostname != parsed_uri.hostname: + continue + # From RFC 8252 (Section 7.3) + # https://datatracker.ietf.org/doc/html/rfc8252#section-7.3 # # Loopback redirect URIs use the "http" scheme # [...] @@ -790,26 +814,26 @@ def redirect_to_uri_allowed(uri, allowed_uris): # time of the request for loopback IP redirect URIs, to accommodate # clients that obtain an available ephemeral port from the operating # system at the time of the request. + allowed_uri_is_loopback = parsed_allowed_uri.scheme == "http" and parsed_allowed_uri.hostname in [ + "127.0.0.1", + "::1", + ] + """ check port """ + if not allowed_uri_is_loopback and parsed_allowed_uri.port != parsed_uri.port: + continue + + """ check path """ + if parsed_allowed_uri.path != parsed_uri.path: + continue + + """ check querystring """ + aqs_set = set(parse_qsl(parsed_allowed_uri.query)) + if not aqs_set.issubset(uqs_set): + continue # circuit break - allowed_uri_is_loopback = ( - parsed_allowed_uri.scheme == "http" - and parsed_allowed_uri.hostname in ["127.0.0.1", "::1"] - and parsed_allowed_uri.port is None - ) - if ( - allowed_uri_is_loopback - and parsed_allowed_uri.scheme == parsed_uri.scheme - and parsed_allowed_uri.hostname == parsed_uri.hostname - and parsed_allowed_uri.path == parsed_uri.path - ) or ( - parsed_allowed_uri.scheme == parsed_uri.scheme - and parsed_allowed_uri.netloc == parsed_uri.netloc - and parsed_allowed_uri.path == parsed_uri.path - ): - aqs_set = set(parse_qsl(parsed_allowed_uri.query)) - if aqs_set.issubset(uqs_set): - return True + return True + # if uris matched then it's not allowed return False @@ -833,4 +857,5 @@ def is_origin_allowed(origin, allowed_origins): and parsed_allowed_origin.netloc == parsed_origin.netloc ): return True + return False diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index f5a6a25d..9771aa4e 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -71,6 +71,7 @@ "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], "ALLOWED_SCHEMES": ["https"], + "ALLOW_URI_WILDCARDS": False, "OIDC_ENABLED": False, "OIDC_ISS_ENDPOINT": "", "OIDC_USERINFO_ENDPOINT": "", diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index b238b12d..b2370cfd 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -21,7 +21,15 @@ class URIValidator(URLValidator): class AllowedURIValidator(URIValidator): # TODO: find a way to get these associated with their form fields in place of passing name # TODO: submit PR to get `cause` included in the parent class ValidationError params` - def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False): + def __init__( + self, + schemes, + name, + allow_path=False, + allow_query=False, + allow_fragments=False, + allow_hostname_wildcard=False, + ): """ :param schemes: List of allowed schemes. E.g.: ["https"] :param name: Name of the validated URI. It is required for validation message. E.g.: "Origin" @@ -34,6 +42,7 @@ def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fra self.allow_path = allow_path self.allow_query = allow_query self.allow_fragments = allow_fragments + self.allow_hostname_wildcard = allow_hostname_wildcard def __call__(self, value): value = force_str(value) @@ -68,8 +77,57 @@ def __call__(self, value): params={"name": self.name, "value": value, "cause": "path not allowed"}, ) + if self.allow_hostname_wildcard and "*" in netloc: + domain_parts = netloc.split(".") + if netloc.count("*") > 1: + raise ValidationError( + "%(name)s URI validation error. %(cause)s: %(value)s", + params={ + "name": self.name, + "value": value, + "cause": "only one wildcard is allowed in the hostname", + }, + ) + if not netloc.startswith("*"): + raise ValidationError( + "%(name)s URI validation error. %(cause)s: %(value)s", + params={ + "name": self.name, + "value": value, + "cause": "wildcards must be at the beginning of the hostname", + }, + ) + if len(domain_parts) < 3: + raise ValidationError( + "%(name)s URI validation error. %(cause)s: %(value)s", + params={ + "name": self.name, + "value": value, + "cause": "wildcards cannot be in the top level or second level domain", + }, + ) + + # strip the wildcard from the netloc, we'll reassamble the value later to pass to URI Validator + if netloc.startswith("*."): + netloc = netloc[2:] + else: + netloc = netloc[1:] + + # domains cannot start with a hyphen, but can have them in the middle, so we strip hyphens + # after the wildcard so the final domain is valid and will succeed in URIVAlidator + if netloc.startswith("-"): + netloc = netloc[1:] + + # we stripped the wildcard from the netloc and path if they were allowed and present since they would + # fail validation we'll reassamble the URI to pass to the URIValidator + reassambled_uri = f"{scheme}://{netloc}{path}" + if query: + reassambled_uri += f"?{query}" + if fragment: + reassambled_uri += f"#{fragment}" + try: - super().__call__(value) + super().__call__(reassambled_uri) except ValidationError as e: raise ValidationError( "%(name)s URI validation error. %(cause)s: %(value)s", diff --git a/tests/test_application_views.py b/tests/test_application_views.py index 88617807..d4c7e28a 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -63,6 +63,156 @@ def test_application_registration_user(self): self.assertEqual(app.algorithm, form_data["algorithm"]) +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings({"ALLOW_URI_WILDCARDS": True}) +class TestApplicationRegistrationViewRedirectURIWithWildcard(BaseTest): + def _test_valid(self, redirect_uri): + self.client.login(username="foo_user", password="123456") + + form_data = { + "name": "Foo app", + "client_id": "client_id", + "client_secret": "client_secret", + "client_type": Application.CLIENT_CONFIDENTIAL, + "redirect_uris": redirect_uri, + "post_logout_redirect_uris": "http://example.com", + "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, + "algorithm": "", + } + + response = self.client.post(reverse("oauth2_provider:register"), form_data) + self.assertEqual(response.status_code, 302) + + app = get_application_model().objects.get(name="Foo app") + self.assertEqual(app.user.username, "foo_user") + app = Application.objects.get() + self.assertEqual(app.name, form_data["name"]) + self.assertEqual(app.client_id, form_data["client_id"]) + self.assertEqual(app.redirect_uris, form_data["redirect_uris"]) + self.assertEqual(app.post_logout_redirect_uris, form_data["post_logout_redirect_uris"]) + self.assertEqual(app.client_type, form_data["client_type"]) + self.assertEqual(app.authorization_grant_type, form_data["authorization_grant_type"]) + self.assertEqual(app.algorithm, form_data["algorithm"]) + + def _test_invalid(self, uri, error_message): + self.client.login(username="foo_user", password="123456") + + form_data = { + "name": "Foo app", + "client_id": "client_id", + "client_secret": "client_secret", + "client_type": Application.CLIENT_CONFIDENTIAL, + "redirect_uris": uri, + "post_logout_redirect_uris": "http://example.com", + "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, + "algorithm": "", + } + + response = self.client.post(reverse("oauth2_provider:register"), form_data) + self.assertEqual(response.status_code, 200) + self.assertContains(response, error_message) + + def test_application_registration_valid_3ld_wildcard(self): + self._test_valid("https://*.example.com") + + def test_application_registration_valid_3ld_partial_wildcard(self): + self._test_valid("https://*-partial.example.com") + + def test_application_registration_invalid_star(self): + self._test_invalid("*", "invalid_scheme: *") + + def test_application_registration_invalid_tld_wildcard(self): + self._test_invalid("https://*", "wildcards cannot be in the top level or second level domain") + + def test_application_registration_invalid_tld_partial_wildcard(self): + self._test_invalid("https://*-partial", "wildcards cannot be in the top level or second level domain") + + def test_application_registration_invalid_tld_not_startswith_wildcard_tld(self): + self._test_invalid("https://example.*", "wildcards must be at the beginning of the hostname") + + def test_application_registration_invalid_2ld_wildcard(self): + self._test_invalid("https://*.com", "wildcards cannot be in the top level or second level domain") + + def test_application_registration_invalid_2ld_partial_wildcard(self): + self._test_invalid( + "https://*-partial.com", "wildcards cannot be in the top level or second level domain" + ) + + def test_application_registration_invalid_2ld_not_startswith_wildcard_tld(self): + self._test_invalid("https://example.*.com", "wildcards must be at the beginning of the hostname") + + def test_application_registration_invalid_3ld_partial_not_startswith_wildcard_2ld(self): + self._test_invalid( + "https://invalid-*.example.com", "wildcards must be at the beginning of the hostname" + ) + + def test_application_registration_invalid_4ld_not_startswith_wildcard_3ld(self): + self._test_invalid( + "https://invalid.*.invalid.example.com", + "wildcards must be at the beginning of the hostname", + ) + + def test_application_registration_invalid_4ld_partial_not_startswith_wildcard_2ld(self): + self._test_invalid( + "https://invalid-*.invalid.example.com", + "wildcards must be at the beginning of the hostname", + ) + + +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings({"ALLOW_URI_WILDCARDS": True}) +class TestApplicationRegistrationViewAllowedOriginWithWildcard( + TestApplicationRegistrationViewRedirectURIWithWildcard +): + def _test_valid(self, uris): + self.client.login(username="foo_user", password="123456") + + form_data = { + "name": "Foo app", + "client_id": "client_id", + "client_secret": "client_secret", + "client_type": Application.CLIENT_CONFIDENTIAL, + "allowed_origins": uris, + "redirect_uris": "https://example.com", + "post_logout_redirect_uris": "http://example.com", + "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, + "algorithm": "", + } + + response = self.client.post(reverse("oauth2_provider:register"), form_data) + self.assertEqual(response.status_code, 302) + + app = get_application_model().objects.get(name="Foo app") + self.assertEqual(app.user.username, "foo_user") + app = Application.objects.get() + self.assertEqual(app.name, form_data["name"]) + self.assertEqual(app.client_id, form_data["client_id"]) + self.assertEqual(app.redirect_uris, form_data["redirect_uris"]) + self.assertEqual(app.post_logout_redirect_uris, form_data["post_logout_redirect_uris"]) + self.assertEqual(app.client_type, form_data["client_type"]) + self.assertEqual(app.authorization_grant_type, form_data["authorization_grant_type"]) + self.assertEqual(app.algorithm, form_data["algorithm"]) + + def _test_invalid(self, uri, error_message): + self.client.login(username="foo_user", password="123456") + + form_data = { + "name": "Foo app", + "client_id": "client_id", + "client_secret": "client_secret", + "client_type": Application.CLIENT_CONFIDENTIAL, + "allowed_origins": uri, + "redirect_uris": "http://example.com", + "post_logout_redirect_uris": "http://example.com", + "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, + "algorithm": "", + } + + response = self.client.post(reverse("oauth2_provider:register"), form_data) + self.assertEqual(response.status_code, 200) + self.assertContains(response, error_message) + + class TestApplicationViews(BaseTest): @classmethod def _create_application(cls, name, user): diff --git a/tests/test_models.py b/tests/test_models.py index 123c41b3..32ca0762 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -16,6 +16,7 @@ get_grant_model, get_id_token_model, get_refresh_token_model, + redirect_to_uri_allowed, ) from . import presets @@ -622,6 +623,79 @@ def test_application_clean(oauth2_settings, application): application.clean() +def _test_wildcard_redirect_uris_valid(oauth2_settings, application, uris): + oauth2_settings.ALLOW_URI_WILDCARDS = True + application.redirect_uris = uris + application.clean() + + +def _test_wildcard_redirect_uris_invalid(oauth2_settings, application, uris): + oauth2_settings.ALLOW_URI_WILDCARDS = True + application.redirect_uris = uris + with pytest.raises(ValidationError): + application.clean() + + +@pytest.mark.django_db(databases=retrieve_current_databases()) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean_wildcard_redirect_uris_valid_3ld(oauth2_settings, application): + _test_wildcard_redirect_uris_valid(oauth2_settings, application, "https://*.example.com/path") + + +@pytest.mark.django_db(databases=retrieve_current_databases()) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean_wildcard_redirect_uris_valid_partial_3ld(oauth2_settings, application): + _test_wildcard_redirect_uris_valid(oauth2_settings, application, "https://*-partial.example.com/path") + + +@pytest.mark.django_db(databases=retrieve_current_databases()) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean_wildcard_redirect_uris_invalid_3ld_not_starting_with_wildcard( + oauth2_settings, application +): + _test_wildcard_redirect_uris_invalid(oauth2_settings, application, "https://invalid-*.example.com/path") + + +@pytest.mark.django_db(databases=retrieve_current_databases()) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean_wildcard_redirect_uris_invalid_2ld(oauth2_settings, application): + _test_wildcard_redirect_uris_invalid(oauth2_settings, application, "https://*.com/path") + + +@pytest.mark.django_db(databases=retrieve_current_databases()) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean_wildcard_redirect_uris_invalid_partial_2ld(oauth2_settings, application): + _test_wildcard_redirect_uris_invalid(oauth2_settings, application, "https://*-partial.com/path") + + +@pytest.mark.django_db(databases=retrieve_current_databases()) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean_wildcard_redirect_uris_invalid_2ld_not_starting_with_wildcard( + oauth2_settings, application +): + _test_wildcard_redirect_uris_invalid(oauth2_settings, application, "https://invalid-*.com/path") + + +@pytest.mark.django_db(databases=retrieve_current_databases()) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean_wildcard_redirect_uris_invalid_tld(oauth2_settings, application): + _test_wildcard_redirect_uris_invalid(oauth2_settings, application, "https://*/path") + + +@pytest.mark.django_db(databases=retrieve_current_databases()) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean_wildcard_redirect_uris_invalid_tld_partial(oauth2_settings, application): + _test_wildcard_redirect_uris_invalid(oauth2_settings, application, "https://*-partial/path") + + +@pytest.mark.django_db(databases=retrieve_current_databases()) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean_wildcard_redirect_uris_invalid_tld_not_starting_with_wildcard( + oauth2_settings, application +): + _test_wildcard_redirect_uris_invalid(oauth2_settings, application, "https://invalid-*/path") + + @pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT) def test_application_origin_allowed_default_https(oauth2_settings, cors_application): @@ -636,3 +710,35 @@ def test_application_origin_allowed_http(oauth2_settings, cors_application): """Test that http schemes are allowed because http was added to ALLOWED_SCHEMES""" assert cors_application.origin_allowed("https://example.com") assert cors_application.origin_allowed("http://example.com") + + +def test_redirect_to_uri_allowed_expects_allowed_uri_list(): + with pytest.raises(ValueError): + redirect_to_uri_allowed("https://example.com", "https://example.com") + assert redirect_to_uri_allowed("https://example.com", ["https://example.com"]) + + +valid_wildcard_redirect_to_params = [ + ("https://valid.example.com", ["https://*.example.com"]), + ("https://valid.valid.example.com", ["https://*.example.com"]), + ("https://valid-partial.example.com", ["https://*-partial.example.com"]), + ("https://valid.valid-partial.example.com", ["https://*-partial.example.com"]), +] + + +@pytest.mark.parametrize("uri, allowed_uri", valid_wildcard_redirect_to_params) +def test_wildcard_redirect_to_uri_allowed_valid(uri, allowed_uri, oauth2_settings): + oauth2_settings.ALLOW_URI_WILDCARDS = True + assert redirect_to_uri_allowed(uri, allowed_uri) + + +invalid_wildcard_redirect_to_params = [ + ("https://invalid.com", ["https://*.example.com"]), + ("https://invalid.example.com", ["https://*-partial.example.com"]), +] + + +@pytest.mark.parametrize("uri, allowed_uri", invalid_wildcard_redirect_to_params) +def test_wildcard_redirect_to_uri_allowed_invalid(uri, allowed_uri, oauth2_settings): + oauth2_settings.ALLOW_URI_WILDCARDS = True + assert not redirect_to_uri_allowed(uri, allowed_uri) diff --git a/tests/test_validators.py b/tests/test_validators.py index eb382c15..a77a1e16 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -171,3 +171,27 @@ def test_allow_fragment_invalid_urls(self): for uri in bad_uris: with self.assertRaises(ValidationError): validator(uri) + + def test_allow_hostname_wildcard(self): + validator = AllowedURIValidator(["https"], "test", allow_hostname_wildcard=True) + good_uris = [ + "https://*.example.com", + "https://*-partial.example.com", + "https://*.partial.example.com", + "https://*-partial.valid.example.com", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + + bad_uris = [ + "https://*/", + "https://*-partial", + "https://*.com", + "https://*-partial.com", + "https://*.*.example.com", + "https://invalid.*.example.com", + ] + for uri in bad_uris: + with self.assertRaises(ValidationError): + validator(uri)