diff --git a/.gitignore b/.gitignore index af644d1e3..c22ef00fa 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,7 @@ __pycache__ pip-log.txt # Unit test / coverage reports -.cache +.pytest_cache .coverage .tox .pytest_cache/ diff --git a/oauth2_provider/admin.py b/oauth2_provider/admin.py index c6bbe44b7..d316642e0 100644 --- a/oauth2_provider/admin.py +++ b/oauth2_provider/admin.py @@ -1,8 +1,11 @@ from django.contrib import admin from .models import ( - get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model + get_access_token_model, + get_application_model, + get_grant_model, + get_refresh_token_model, + get_id_token_model, ) @@ -26,6 +29,11 @@ class AccessTokenAdmin(admin.ModelAdmin): raw_id_fields = ("user", ) +class IDTokenAdmin(admin.ModelAdmin): + list_display = ("token", "user", "application", "expires") + raw_id_fields = ("user", ) + + class RefreshTokenAdmin(admin.ModelAdmin): list_display = ("token", "user", "application") raw_id_fields = ("user", "access_token") @@ -34,9 +42,11 @@ class RefreshTokenAdmin(admin.ModelAdmin): Application = get_application_model() Grant = get_grant_model() AccessToken = get_access_token_model() +IDToken = get_id_token_model() RefreshToken = get_refresh_token_model() admin.site.register(Application, ApplicationAdmin) admin.site.register(Grant, GrantAdmin) admin.site.register(AccessToken, AccessTokenAdmin) +admin.site.register(IDToken, IDTokenAdmin) admin.site.register(RefreshToken, RefreshTokenAdmin) diff --git a/oauth2_provider/forms.py b/oauth2_provider/forms.py index 2e465959a..41129c449 100644 --- a/oauth2_provider/forms.py +++ b/oauth2_provider/forms.py @@ -5,6 +5,7 @@ class AllowForm(forms.Form): allow = forms.BooleanField(required=False) redirect_uri = forms.CharField(widget=forms.HiddenInput()) scope = forms.CharField(widget=forms.HiddenInput()) + nonce = forms.CharField(required=False, widget=forms.HiddenInput()) client_id = forms.CharField(widget=forms.HiddenInput()) state = forms.CharField(required=False, widget=forms.HiddenInput()) response_type = forms.CharField(widget=forms.HiddenInput()) diff --git a/oauth2_provider/migrations/0003_auto_20190413_2007.py b/oauth2_provider/migrations/0003_auto_20190413_2007.py new file mode 100644 index 000000000..472886147 --- /dev/null +++ b/oauth2_provider/migrations/0003_auto_20190413_2007.py @@ -0,0 +1,23 @@ +# Generated by Django 2.2 on 2019-04-13 20:07 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0002_auto_20190406_1805'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='algorithm', + field=models.CharField(choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256', max_length=5), + ), + migrations.AlterField( + model_name='application', + name='authorization_grant_type', + field=models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32), + ), + ] diff --git a/oauth2_provider/migrations/0004_idtoken.py b/oauth2_provider/migrations/0004_idtoken.py new file mode 100644 index 000000000..e0d43b2dc --- /dev/null +++ b/oauth2_provider/migrations/0004_idtoken.py @@ -0,0 +1,33 @@ +# Generated by Django 2.2 on 2019-04-16 14:36 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('oauth2_provider', '0003_auto_20190413_2007'), + ] + + operations = [ + migrations.CreateModel( + name='IDToken', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('token', models.TextField(unique=True)), + ('expires', models.DateTimeField()), + ('scope', models.TextField(blank=True)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_idtoken', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'abstract': False, + 'swappable': 'OAUTH2_PROVIDER_ID_TOKEN_MODEL', + }, + ), + ] diff --git a/oauth2_provider/migrations/0005_accesstoken_id_token.py b/oauth2_provider/migrations/0005_accesstoken_id_token.py new file mode 100644 index 000000000..a6ca7dd25 --- /dev/null +++ b/oauth2_provider/migrations/0005_accesstoken_id_token.py @@ -0,0 +1,20 @@ +# Generated by Django 2.2 on 2019-04-16 14:39 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0004_idtoken'), + ] + + operations = [ + migrations.AddField( + model_name='accesstoken', + name='id_token', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL), + ), + ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 1489a8845..fa10c57c3 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -1,7 +1,10 @@ +import json from datetime import timedelta from urllib.parse import parse_qsl, urlparse import logging +from jwcrypto import jwk, jwt + from django.apps import apps from django.conf import settings from django.core.exceptions import ImproperlyConfigured @@ -49,11 +52,20 @@ class AbstractApplication(models.Model): GRANT_IMPLICIT = "implicit" GRANT_PASSWORD = "password" GRANT_CLIENT_CREDENTIALS = "client-credentials" + GRANT_OPENID_HYBRID = "openid-hybrid" GRANT_TYPES = ( (GRANT_AUTHORIZATION_CODE, _("Authorization code")), (GRANT_IMPLICIT, _("Implicit")), (GRANT_PASSWORD, _("Resource owner password-based")), (GRANT_CLIENT_CREDENTIALS, _("Client credentials")), + (GRANT_OPENID_HYBRID, _("OpenID connect hybrid")), + ) + + RS256_ALGORITHM = "RS256" + HS256_ALGORITHM = "HS256" + ALGORITHM_TYPES = ( + (RS256_ALGORITHM, _("RSA with SHA-2 256")), + (HS256_ALGORITHM, _("HMAC with SHA-2 256")), ) id = models.BigAutoField(primary_key=True) @@ -81,6 +93,7 @@ class AbstractApplication(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) + algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=RS256_ALGORITHM) class Meta: abstract = True @@ -281,6 +294,10 @@ class AbstractAccessToken(models.Model): related_name="refreshed_access_token" ) token = models.CharField(max_length=255, unique=True, ) + id_token = models.OneToOneField( + oauth2_settings.ID_TOKEN_MODEL, on_delete=models.CASCADE, blank=True, null=True, + related_name="access_token" + ) application = models.ForeignKey( oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, ) @@ -414,6 +431,104 @@ class Meta(AbstractRefreshToken.Meta): swappable = "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL" +class AbstractIDToken(models.Model): + """ + An IDToken instance represents the actual token to + access user's resources, as in :openid:`2`. + + Fields: + + * :attr:`user` The Django user representing resources' owner + * :attr:`token` ID token + * :attr:`application` Application instance + * :attr:`expires` Date and time of token expiration, in DateTime format + * :attr:`scope` Allowed scopes + """ + id = models.BigAutoField(primary_key=True) + user = models.ForeignKey( + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, blank=True, null=True, + related_name="%(app_label)s_%(class)s" + ) + token = models.TextField(unique=True) + application = models.ForeignKey( + oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, + ) + expires = models.DateTimeField() + scope = models.TextField(blank=True) + + created = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + + def is_valid(self, scopes=None): + """ + Checks if the access token is valid. + + :param scopes: An iterable containing the scopes to check or None + """ + return not self.is_expired() and self.allow_scopes(scopes) + + def is_expired(self): + """ + Check token expiration with timezone awareness + """ + if not self.expires: + return True + + return timezone.now() >= self.expires + + def allow_scopes(self, scopes): + """ + Check if the token allows the provided scopes + + :param scopes: An iterable containing the scopes to check + """ + if not scopes: + return True + + provided_scopes = set(self.scope.split()) + resource_scopes = set(scopes) + + return resource_scopes.issubset(provided_scopes) + + def revoke(self): + """ + Convenience method to uniform tokens' interface, for now + simply remove this token from the database in order to revoke it. + """ + self.delete() + + @property + def scopes(self): + """ + Returns a dictionary of allowed scope names (as keys) with their descriptions (as values) + """ + all_scopes = get_scopes_backend().get_all_scopes() + token_scopes = self.scope.split() + return {name: desc for name, desc in all_scopes.items() if name in token_scopes} + + @property + def claims(self): + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + jwt_token = jwt.JWT(key=key, jwt=self.token) + return json.loads(jwt_token.claims) + + def get_claims(self, check_claims=None): + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + jwt_token = jwt.JWT(key=key, jwt=self.token, check_claims=check_claims) + return json.loads(jwt_token.claims) + + def __str__(self): + return self.token + + class Meta: + abstract = True + + +class IDToken(AbstractIDToken): + class Meta(AbstractIDToken.Meta): + swappable = "OAUTH2_PROVIDER_ID_TOKEN_MODEL" + + def get_application_model(): """ Return the Application model that is active in this project. """ return apps.get_model(oauth2_settings.APPLICATION_MODEL) @@ -429,6 +544,11 @@ def get_access_token_model(): return apps.get_model(oauth2_settings.ACCESS_TOKEN_MODEL) +def get_id_token_model(): + """ Return the AccessToken model that is active in this project. """ + return apps.get_model(oauth2_settings.ID_TOKEN_MODEL) + + def get_refresh_token_model(): """ Return the RefreshToken model that is active in this project. """ return apps.get_model(oauth2_settings.REFRESH_TOKEN_MODEL) diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index f71f46e9b..facb10885 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -94,7 +94,7 @@ def validate_authorization_request(self, request): except oauth2.OAuth2Error as error: raise OAuthToolkitError(error=error) - def create_authorization_response(self, request, scopes, credentials, allow): + def create_authorization_response(self, uri, request, scopes, credentials, body, allow): """ A wrapper method that calls create_authorization_response on `server_class` instance. @@ -102,7 +102,8 @@ def create_authorization_response(self, request, scopes, credentials, allow): :param request: The current django.http.HttpRequest object :param scopes: A list of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri`, `response_type` + `client_id`, `state`, `redirect_uri` and `response_type` + :param body: Other body parameters not used in credentials dictionary :param allow: True if the user authorize the client, otherwise False """ try: @@ -114,10 +115,10 @@ def create_authorization_response(self, request, scopes, credentials, allow): credentials["user"] = request.user headers, body, status = self.server.create_authorization_response( - uri=credentials["redirect_uri"], scopes=scopes, credentials=credentials) - uri = headers.get("Location", None) + uri=uri, scopes=scopes, credentials=credentials, body=body) + redirect_uri = headers.get("Location", None) - return uri, headers, body, status + return redirect_uri, headers, body, status except oauth2.FatalClientError as error: raise FatalClientError(error=error, redirect_uri=credentials["redirect_uri"]) @@ -166,6 +167,28 @@ def verify_request(self, request, scopes): valid, r = self.server.verify_request(uri, http_method, body, headers, scopes=scopes) return valid, r + def validate_userinfo_request(self, request): + """ + """ + try: + self.server.validate_userinfo_request(request) + except oauth2.InvalidTokenError as error: + raise FatalClientError(error=error) + except oauth2.InsufficientScopeError as error: + raise OAuthToolkitError(error=error) + + def create_userinfo_response(self, request): + """ + """ + uri, http_method, body, headers = self._extract_params(request) + + headers, body, status = self.server.create_userinfo_response(uri, http_method, body, + headers) + + uri = headers.get("Location", None) + + return uri, headers, body, status + class JSONOAuthLibCore(OAuthLibCore): """ diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 1e80a5cb9..5bbe55d85 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -1,5 +1,7 @@ import base64 import binascii +import json +import hashlib import logging from collections import OrderedDict from datetime import datetime, timedelta @@ -11,15 +13,24 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.models import Q -from django.utils import timezone +from django.utils import dateformat, timezone from django.utils.timezone import make_aware from django.utils.translation import ugettext_lazy as _ from oauthlib.oauth2 import RequestValidator +from oauthlib.oauth2.rfc6749 import utils + +from jwcrypto.common import JWException +from jwcrypto import jwk, jwt +from jwcrypto.jwt import JWTExpired from .exceptions import FatalClientError from .models import ( - AbstractApplication, get_access_token_model, - get_application_model, get_grant_model, get_refresh_token_model + AbstractApplication, + get_access_token_model, + get_id_token_model, + get_application_model, + get_grant_model, + get_refresh_token_model, ) from .scopes import get_scopes_backend from .settings import oauth2_settings @@ -28,18 +39,19 @@ log = logging.getLogger("oauth2_provider") GRANT_TYPE_MAPPING = { - "authorization_code": (AbstractApplication.GRANT_AUTHORIZATION_CODE, ), + "authorization_code": (AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_OPENID_HYBRID), "password": (AbstractApplication.GRANT_PASSWORD, ), "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS, ), "refresh_token": ( AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_PASSWORD, AbstractApplication.GRANT_CLIENT_CREDENTIALS, - ) + ), } Application = get_application_model() AccessToken = get_access_token_model() +IDToken = get_id_token_model() Grant = get_grant_model() RefreshToken = get_refresh_token_model() UserModel = get_user_model() @@ -409,6 +421,16 @@ def validate_response_type(self, client_id, response_type, client, request, *arg return client.allows_grant_type(AbstractApplication.GRANT_AUTHORIZATION_CODE) elif response_type == "token": return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) + elif response_type == "id_token": + return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) + elif response_type == "id_token token": + return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) + elif response_type == "code id_token": + return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) + elif response_type == "code token": + return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) + elif response_type == "code id_token token": + return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) else: return False @@ -460,6 +482,24 @@ def save_authorization_code(self, client_id, code, request, *args, **kwargs): ) g.save() + def get_authorization_code_scopes(self, client_id, code, redirect_uri, request): + scopes = [] + fields = { + "code": code, + } + + if client_id: + fields["application__client_id"] = client_id + + if redirect_uri: + fields["redirect_uri"] = redirect_uri + + grant = Grant.objects.filter(**fields).values() + if grant.exists(): + grant_dict = dict(grant[0]) + scopes = utils.scope_to_list(grant_dict["scope"]) + return scopes + def rotate_refresh_token(self, request): """ Checks if rotate refresh token is enabled @@ -563,11 +603,15 @@ def save_bearer_token(self, token, request, *args, **kwargs): token["expires_in"] = oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS def _create_access_token(self, expires, request, token, source_refresh_token=None): + id_token = token.get('id_token', None) + if id_token: + id_token = IDToken.objects.get(token=id_token) access_token = AccessToken( user=request.user, scope=token["scope"], expires=expires, token=token["access_token"], + id_token=id_token, application=request.client, source_refresh_token=source_refresh_token, ) @@ -647,3 +691,114 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs # Temporary store RefreshToken instance to be reused by get_original_scopes and save_bearer_token. request.refresh_token_instance = rt return rt.application == client + + @transaction.atomic + def _save_id_token(self, token, request, expires, *args, **kwargs): + + scopes = request.scope or " ".join(request.scopes) + + if request.grant_type == "client_credentials": + request.user = None + + id_token = IDToken.objects.create( + user=request.user, + scope=scopes, + expires=expires, + token=token.serialize(), + application=request.client, + ) + return id_token + + def get_jwt_bearer_token(self, token, token_handler, request): + return self.get_id_token(token, token_handler, request) + + def get_id_token(self, token, token_handler, request): + + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + + # TODO: http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken2 + # Save the id_token on database bound to code when the request come to + # Authorization Endpoint and return the same one when request come to + # Token Endpoint + + # TODO: Check if at this point this request parameters are alredy validated + + expiration_time = timezone.now() + timedelta(seconds=oauth2_settings.ID_TOKEN_EXPIRE_SECONDS) + # Required ID Token claims + claims = { + "iss": oauth2_settings.OIDC_ISS_ENDPOINT, + "sub": str(request.user.id), + "aud": request.client_id, + "exp": int(dateformat.format(expiration_time, "U")), + "iat": int(dateformat.format(datetime.utcnow(), "U")), + "auth_time": int(dateformat.format(request.user.last_login, "U")) + } + + nonce = getattr(request, "nonce", None) + if nonce: + claims["nonce"] = nonce + + # TODO: create a function to check if we should add at_hash + # http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken + # http://openid.net/specs/openid-connect-core-1_0.html#ImplicitIDToken + # if request.grant_type in 'authorization_code' and 'access_token' in token: + if (request.grant_type is "authorization_code" and "access_token" in token) or request.response_type == "code id_token token" or (request.response_type == "id_token token" and "access_token" in token): + acess_token = token["access_token"] + sha256 = hashlib.sha256(acess_token.encode("ascii")) + bits128 = sha256.hexdigest()[:16] + at_hash = base64.urlsafe_b64encode(bits128.encode("ascii")) + claims['at_hash'] = at_hash.decode("utf8") + + # TODO: create a function to check if we should include c_hash + # http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken + if request.response_type in ("code id_token", "code id_token token"): + code = token["code"] + sha256 = hashlib.sha256(code.encode("ascii")) + bits256 = sha256.hexdigest()[:32] + c_hash = base64.urlsafe_b64encode(bits256.encode("ascii")) + claims["c_hash"] = c_hash.decode("utf8") + + jwt_token = jwt.JWT(header=json.dumps({"alg": "RS256"}, default=str), claims=json.dumps(claims, default=str)) + jwt_token.make_signed_token(key) + + id_token = self._save_id_token(jwt_token, request, expiration_time) + # this is needed by django rest framework + request.access_token = id_token + request.id_token = id_token + return jwt_token.serialize() + + def validate_jwt_bearer_token(self, token, scopes, request): + return self.validate_id_token(token, scopes, request) + + def validate_id_token(self, token, scopes, request): + """ + When users try to access resources, check that provided id_token is valid + """ + if not token: + return False + + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + + try: + jwt_token = jwt.JWT(key=key, jwt=token) + id_token = IDToken.objects.get(token=jwt_token.serialize()) + request.client = id_token.application + request.user = id_token.user + request.scopes = scopes + # this is needed by django rest framework + request.access_token = id_token + return True + except (JWException, JWTExpired): + # TODO: This is the base exception of all jwcrypto + return False + + return False + + def validate_user_match(self, id_token_hint, scopes, claims, request): + # TODO: Fix to validate when necessary acording + # https://github.com/idan/oauthlib/blob/master/oauthlib/oauth2/rfc6749/request_validator.py#L556 + # http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest id_token_hint section + return True + + def get_userinfo_claims(self, request): + return [] \ No newline at end of file diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index 53f163142..978ad215e 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -25,6 +25,7 @@ APPLICATION_MODEL = getattr(settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application") ACCESS_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken") +ID_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_ID_TOKEN_MODEL", "oauth2_provider.IDToken") GRANT_MODEL = getattr(settings, "OAUTH2_PROVIDER_GRANT_MODEL", "oauth2_provider.Grant") REFRESH_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken") @@ -32,7 +33,7 @@ "CLIENT_ID_GENERATOR_CLASS": "oauth2_provider.generators.ClientIdGenerator", "CLIENT_SECRET_GENERATOR_CLASS": "oauth2_provider.generators.ClientSecretGenerator", "CLIENT_SECRET_GENERATOR_LENGTH": 128, - "OAUTH2_SERVER_CLASS": "oauthlib.oauth2.Server", + "OAUTH2_SERVER_CLASS": "oauthlib.openid.connect.core.endpoints.pre_configured.Server", "OAUTH2_VALIDATOR_CLASS": "oauth2_provider.oauth2_validators.OAuth2Validator", "OAUTH2_BACKEND_CLASS": "oauth2_provider.oauth2_backends.OAuthLibCore", "SCOPES": {"read": "Reading scope", "write": "Writing scope"}, @@ -42,16 +43,33 @@ "WRITE_SCOPE": "write", "AUTHORIZATION_CODE_EXPIRE_SECONDS": 60, "ACCESS_TOKEN_EXPIRE_SECONDS": 36000, + "ID_TOKEN_EXPIRE_SECONDS": 36000, "REFRESH_TOKEN_EXPIRE_SECONDS": None, "REFRESH_TOKEN_GRACE_PERIOD_SECONDS": 0, "ROTATE_REFRESH_TOKEN": True, "ERROR_RESPONSE_WITH_SCOPES": False, "APPLICATION_MODEL": APPLICATION_MODEL, "ACCESS_TOKEN_MODEL": ACCESS_TOKEN_MODEL, + "ID_TOKEN_MODEL": ID_TOKEN_MODEL, "GRANT_MODEL": GRANT_MODEL, "REFRESH_TOKEN_MODEL": REFRESH_TOKEN_MODEL, "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], + "OIDC_ISS_ENDPOINT": "", + "OIDC_USERINFO_ENDPOINT": "", + "OIDC_RSA_PRIVATE_KEY": "", + "OIDC_RESPONSE_TYPES_SUPPORTED": [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token", + ], + "OIDC_SUBJECT_TYPES_SUPPORTED": ["public"], + "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED": ["RS256", "HS256"], + "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED": ["client_secret_post", "client_secret_basic"], # Special settings that will be evaluated at runtime "_SCOPES": [], @@ -76,6 +94,13 @@ "OAUTH2_BACKEND_CLASS", "SCOPES", "ALLOWED_REDIRECT_URI_SCHEMES", + "OIDC_ISS_ENDPOINT", + "OIDC_USERINFO_ENDPOINT", + "OIDC_RSA_PRIVATE_KEY", + "OIDC_RESPONSE_TYPES_SUPPORTED", + "OIDC_SUBJECT_TYPES_SUPPORTED", + "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED", + "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED" ) # List of settings that may be in string import notation. diff --git a/oauth2_provider/urls.py b/oauth2_provider/urls.py index 86d97d053..333f11933 100644 --- a/oauth2_provider/urls.py +++ b/oauth2_provider/urls.py @@ -27,5 +27,11 @@ name="authorized-token-delete"), ] +oidc_urlpatterns = [ + url(r"^\.well-known/openid-configuration/$", views.ConnectDiscoveryInfoView.as_view(), name="oidc-connect-discovery-info"), + url(r"^jwks/$", views.JwksInfoView.as_view(), name="jwks-info"), + url(r"^userinfo/$", views.UserInfoView.as_view(), name="user-info") +] + -urlpatterns = base_urlpatterns + management_urlpatterns +urlpatterns = base_urlpatterns + management_urlpatterns + oidc_urlpatterns diff --git a/oauth2_provider/views/__init__.py b/oauth2_provider/views/__init__.py index 7bf60cece..9f2ac4ff7 100644 --- a/oauth2_provider/views/__init__.py +++ b/oauth2_provider/views/__init__.py @@ -1,7 +1,13 @@ # flake8: noqa -from .base import AuthorizationView, TokenView, RevokeTokenView -from .application import ApplicationRegistration, ApplicationDetail, ApplicationList, \ - ApplicationDelete, ApplicationUpdate -from .generic import ProtectedResourceView, ScopedProtectedResourceView, ReadWriteScopedResourceView -from .token import AuthorizedTokensListView, AuthorizedTokenDeleteView +from .application import ( + ApplicationDelete, ApplicationDetail, ApplicationList, + ApplicationRegistration, ApplicationUpdate +) +from .base import AuthorizationView, RevokeTokenView, TokenView +from .generic import ( + ProtectedResourceView, ReadWriteScopedResourceView, + ScopedProtectedResourceView +) from .introspect import IntrospectTokenView +from .oidc import ConnectDiscoveryInfoView, JwksInfoView, UserInfoView +from .token import AuthorizedTokenDeleteView, AuthorizedTokensListView diff --git a/oauth2_provider/views/application.py b/oauth2_provider/views/application.py index c925493f5..b38c907ab 100644 --- a/oauth2_provider/views/application.py +++ b/oauth2_provider/views/application.py @@ -32,7 +32,7 @@ def get_form_class(self): get_application_model(), fields=( "name", "client_id", "client_secret", "client_type", - "authorization_grant_type", "redirect_uris" + "authorization_grant_type", "redirect_uris", "algorithm", ) ) @@ -81,6 +81,6 @@ def get_form_class(self): get_application_model(), fields=( "name", "client_id", "client_secret", "client_type", - "authorization_grant_type", "redirect_uris" + "authorization_grant_type", "redirect_uris", "algorithm", ) ) diff --git a/oauth2_provider/views/base.py b/oauth2_provider/views/base.py index 51a1ecccb..abd0ab75e 100644 --- a/oauth2_provider/views/base.py +++ b/oauth2_provider/views/base.py @@ -1,5 +1,6 @@ import json import logging +from urllib import parse from django.contrib.auth.mixins import LoginRequiredMixin from django.http import HttpResponse @@ -95,6 +96,7 @@ def get_initial(self): initial_data = { "redirect_uri": self.oauth2_data.get("redirect_uri", None), "scope": " ".join(scopes), + "nonce": self.oauth2_data.get("nonce", None), "client_id": self.oauth2_data.get("client_id", None), "state": self.oauth2_data.get("state", None), "response_type": self.oauth2_data.get("response_type", None), @@ -116,17 +118,25 @@ def form_valid(self, form): credentials["code_challenge"] = form.cleaned_data.get("code_challenge") if form.cleaned_data.get("code_challenge_method", False): credentials["code_challenge_method"] = form.cleaned_data.get("code_challenge_method") + body = { + "nonce": form.cleaned_data.get("nonce") + } scopes = form.cleaned_data.get("scope") allow = form.cleaned_data.get("allow") try: - uri, headers, body, status = self.create_authorization_response( - request=self.request, scopes=scopes, credentials=credentials, allow=allow + redirect_uri, headers, body, status = self.create_authorization_response( + self.request.get_raw_uri(), + request=self.request, + scopes=scopes, + credentials=credentials, + body=body, + allow=allow ) except OAuthToolkitError as error: return self.error_response(error, application) - self.success_url = uri + self.success_url = redirect_uri log.debug("Success url for the request: {0}".format(self.success_url)) return self.redirect(self.success_url, application) @@ -155,6 +165,9 @@ def get(self, request, *args, **kwargs): # TODO: Cache this! application = get_application_model().objects.get(client_id=credentials["client_id"]) + uri_query = parse.urlparse(self.request.get_raw_uri()).query + uri_query_params = dict(parse.parse_qsl(uri_query, keep_blank_values=True, strict_parsing=True)) + kwargs["application"] = application kwargs["client_id"] = credentials["client_id"] kwargs["redirect_uri"] = credentials["redirect_uri"] @@ -162,6 +175,7 @@ def get(self, request, *args, **kwargs): kwargs["state"] = credentials["state"] kwargs["code_challenge"] = credentials["code_challenge"] kwargs["code_challenge_method"] = credentials["code_challenge_method"] + kwargs["nonce"] = uri_query_params.get('nonce', None) self.oauth2_data = kwargs # following two loc are here only because of https://code.djangoproject.com/ticket/17795 @@ -178,11 +192,14 @@ def get(self, request, *args, **kwargs): # This is useful for in-house applications-> assume an in-house applications # are already approved. if application.skip_authorization: - uri, headers, body, status = self.create_authorization_response( - request=self.request, scopes=" ".join(scopes), - credentials=credentials, allow=True + redirect_uri, headers, body, status = self.create_authorization_response( + self.request.get_raw_uri(), + request=self.request, + scopes=" ".join(scopes), + credentials=credentials, + allow=True ) - return self.redirect(uri, application) + return self.redirect(redirect_uri, application) elif require_approval == "auto": tokens = get_access_token_model().objects.filter( @@ -194,11 +211,14 @@ def get(self, request, *args, **kwargs): # check past authorizations regarded the same scopes as the current one for token in tokens: if token.allow_scopes(scopes): - uri, headers, body, status = self.create_authorization_response( - request=self.request, scopes=" ".join(scopes), - credentials=credentials, allow=True + redirect_uri, headers, body, status = self.create_authorization_response( + self.request.get_raw_uri(), + request=self.request, + scopes=" ".join(scopes), + credentials=credentials, + allow=True ) - return self.redirect(uri, application) + return self.redirect(redirect_uri, application) except OAuthToolkitError as error: return self.error_response(error, application) diff --git a/oauth2_provider/views/mixins.py b/oauth2_provider/views/mixins.py index 00065644a..dbcb292d8 100644 --- a/oauth2_provider/views/mixins.py +++ b/oauth2_provider/views/mixins.py @@ -96,7 +96,7 @@ def validate_authorization_request(self, request): core = self.get_oauthlib_core() return core.validate_authorization_request(request) - def create_authorization_response(self, request, scopes, credentials, allow): + def create_authorization_response(self, uri, request, scopes, credentials, allow, body=None): """ A wrapper method that calls create_authorization_response on `server_class` instance. @@ -104,14 +104,15 @@ def create_authorization_response(self, request, scopes, credentials, allow): :param request: The current django.http.HttpRequest object :param scopes: A space-separated string of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri`, `response_type` + `client_id`, `state`, `redirect_uri` and `response_type` :param allow: True if the user authorize the client, otherwise False + :param body: Other body parameters not used in credentials dictionary """ # TODO: move this scopes conversion from and to string into a utils function scopes = scopes.split(" ") if scopes else [] core = self.get_oauthlib_core() - return core.create_authorization_response(request, scopes, credentials, allow) + return core.create_authorization_response(uri, request, scopes, credentials, body, allow) def create_token_response(self, request): """ @@ -132,6 +133,18 @@ def create_revocation_response(self, request): core = self.get_oauthlib_core() return core.create_revocation_response(request) + def create_userinfo_response(self, request): + """ + """ + core = self.get_oauthlib_core() + return core.create_userinfo_response(request) + + def validate_userinfo_request(self, request): + """ + """ + core = self.get_oauthlib_core() + return core.validate_userinfo_request(request) + def verify_request(self, request): """ A wrapper method that calls verify_request on `server_class` instance. diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py new file mode 100644 index 000000000..839b9124b --- /dev/null +++ b/oauth2_provider/views/oidc.py @@ -0,0 +1,86 @@ +from __future__ import absolute_import, unicode_literals + +import json + +from django.http import JsonResponse +from django.urls import reverse_lazy +from django.http import HttpResponse +from oauthlib.oauth2.rfc6749.errors import ServerError + +from django.views.generic import View + +from rest_framework.views import APIView + +from jwcrypto import jwk + +from .mixins import OAuthLibMixin +from ..settings import oauth2_settings + +class ConnectDiscoveryInfoView(View): + """ + View used to show oidc provider configuration information + """ + def get(self, request, *args, **kwargs): + issuer_url = oauth2_settings.OIDC_ISS_ENDPOINT + data = { + "issuer": issuer_url, + "authorization_endpoint": "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:authorize")), + "token_endpoint": "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:token")), + "userinfo_endpoint": oauth2_settings.OIDC_USERINFO_ENDPOINT, + "jwks_uri": "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:jwks-info")), + "response_types_supported": oauth2_settings.OIDC_RESPONSE_TYPES_SUPPORTED, + "subject_types_supported": oauth2_settings.OIDC_SUBJECT_TYPES_SUPPORTED, + "id_token_signing_alg_values_supported": oauth2_settings.OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED, + "token_endpoint_auth_methods_supported": oauth2_settings.OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED, + } + response = JsonResponse(data) + response["Access-Control-Allow-Origin"] = "*" + return response + + +class JwksInfoView(View): + """ + View used to show oidc json web key set document + """ + def get(self, request, *args, **kwargs): + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + data = { + "keys": [{ + "alg": "RS256", + "use": "sig", + "kid": key.thumbprint() + }] + } + data["keys"][0].update(json.loads(key.export_public())) + response = JsonResponse(data) + response["Access-Control-Allow-Origin"] = "*" + return response + + +class UserInfoView(OAuthLibMixin, APIView): + server_class = oauth2_settings.OAUTH2_SERVER_CLASS + validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS + oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS + + """ + View used to show Claims about the authenticated End-User + """ + + def get_userinfo_response(self, request, *args, **kwargs): + try: + uri, headers, body, status = self.create_userinfo_response(request) + except ServerError as error: + return HttpResponse(content=error, status=error.status_code) + except Exception as error: + return HttpResponse(content=error, status=500) + + response = HttpResponse(content=body, status=status) + for k, v in headers.items(): + response[k] = v + return response + + def get(self, request, *args, **kwargs): + return self.get_userinfo_response(request, *args, **kwargs) + + def post(self, request, *args, **kwargs): + return self.get_userinfo_response(request, *args, **kwargs) diff --git a/setup.cfg b/setup.cfg index 1901b5e36..1c5fd4036 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,6 +30,7 @@ install_requires = django >= 2.0 requests >= 2.13.0 oauthlib >= 3.0.1 + jwcrypto >= 0.4.2 [options.packages.find] exclude = tests diff --git a/tests/settings.py b/tests/settings.py index 5e145ac3b..f1ad8dd55 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -124,3 +124,14 @@ }, } } + +OAUTH2_PROVIDER = { + "OIDC_ISS_ENDPOINT": "http://localhost", + "OIDC_USERINFO_ENDPOINT": "http://localhost/userinfo/", + "OIDC_RSA_PRIVATE_KEY": "-----BEGIN RSA PRIVATE KEY-----\nMIICXQIBAAKBgQCbCYh5h2NmQuBqVO6G+/CO+cHm9VBzsb0MeA6bbQfDnbhstVOT\nj0hcnZJzDjYc6ajBZZf6gxVP9xrdm9Uh599VI3X5PFXLbMHrmzTAMzCGIyg+/fnP\n0gocYxmCX2+XKyj/Zvt1pUX8VAN2AhrJSfxNDKUHERTVEV9bRBJg4F0C3wIDAQAB\nAoGAP+i4nNw+Ec/8oWh8YSFm4xE6qKG0NdTtSMAOyWwy+KTB+vHuT1QPsLn1vj77\n+IQrX/moogg6F1oV9YdA3vat3U7rwt1sBGsRrLhA+Spp9WEQtglguNo4+QfVo2ju\nYBa2rG+h75qjiA3xnU//F3rvwnAsOWv0NUVdVeguyR+u6okCQQDBUmgWeH2WHmUn\n2nLNCz+9wj28rqhfOr9Ptem2gqk+ywJmuIr4Y5S1OdavOr2UZxOcEwncJ/MLVYQq\nMH+x4V5HAkEAzU2GMR5OdVLcxfVTjzuIC76paoHVWnLibd1cdANpPmE6SM+pf5el\nfVSwuH9Fmlizu8GiPCxbJUoXB/J1tGEKqQJBALhClEU+qOzpoZ6/voYi/6kdN3zc\nuEy0EN6n09AKb8gS9QH1STgAqh+ltjMkeMe3C2DKYK5/QU9/Pc58lWl1FkcCQG67\nZamQgxjcvJ85FvymS1aqW45KwNysIlzHjFo2jMlMf7dN6kobbPMQftDENLJvLWIT\nqoFyGycdsxZiPAIyZSECQQCZFn3Dl6hnJxWZH8Fsa9hj79kZ/WVkIXGmtdgt0fNr\ndTnvCVtA59ne4LEVie/PMH/odQWY0SxVm/76uBZv/1vY\n-----END RSA PRIVATE KEY-----" +} + +OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = 'oauth2_provider.AccessToken' +OAUTH2_PROVIDER_APPLICATION_MODEL = 'oauth2_provider.Application' +OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL = 'oauth2_provider.RefreshToken' +OAUTH2_PROVIDER_ID_TOKEN_MODEL = 'oauth2_provider.IDToken' diff --git a/tests/test_application_views.py b/tests/test_application_views.py index 6130876ce..64e112da3 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -50,6 +50,7 @@ def test_application_registration_user(self): "client_type": Application.CLIENT_CONFIDENTIAL, "redirect_uris": "http://example.com", "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, + "algorithm": "RS256", } response = self.client.post(reverse("oauth2_provider:register"), form_data) diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 26788a6e5..f8a75d33c 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -53,8 +53,13 @@ def setUp(self): ) self.application.save() - oauth2_settings._SCOPES = ["read", "write"] + oauth2_settings._SCOPES = ["read", "write", "openid"] oauth2_settings._DEFAULT_SCOPES = ["read", "write"] + oauth2_settings.SCOPES = { + "read": "Reading scope", + "write": "Writing scope", + "openid": "OpenID connect" + } def tearDown(self): self.application.delete() @@ -105,6 +110,26 @@ def test_skip_authorization_completely(self): response = self.client.get(url) self.assertEqual(response.status_code, 302) + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + def test_pre_auth_invalid_client(self): """ Test error for an invalid client_id with response_type: code @@ -151,6 +176,33 @@ def test_pre_auth_valid_client(self): self.assertEqual(form["scope"].value(), "read write") self.assertEqual(form["client_id"].value(), self.application.client_id) + def test_id_token_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "openid") + self.assertEqual(form["client_id"].value(), self.application.client_id) + def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): """ Test response for a valid client_id with response_type: code @@ -317,6 +369,27 @@ def test_code_post_auth_allow(self): self.assertIn("state=random_state_string", response["Location"]) self.assertIn("code=", response["Location"]) + def test_id_token_code_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org?", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + def test_code_post_auth_deny(self): """ Test error when resource owner deny access @@ -522,14 +595,14 @@ def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): class TestAuthorizationCodeTokenView(BaseTest): - def get_auth(self): + def get_auth(self, scope="read write"): """ Helper method to retrieve a valid authorization code """ authcode_data = { "client_id": self.application.client_id, "state": "random_state_string", - "scope": "read write", + "scope": scope, "redirect_uri": "http://example.org", "response_type": "code", "allow": True, @@ -1017,6 +1090,34 @@ def test_public(self): self.assertEqual(content["scope"], "read write") self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + def test_id_token_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + def test_public_pkce_S256_authorize_get(self): """ Request an access token using client_type: public @@ -1447,6 +1548,45 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param self.assertEqual(content["scope"], "read write") self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).query) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + class TestAuthorizationCodeProtectedResource(BaseTest): def test_resource_access_allowed(self): @@ -1488,6 +1628,57 @@ def test_resource_access_allowed(self): response = view(request) self.assertEqual(response, "This is a protected resource") + def test_id_token_resource_access_allowed(self): + self.client.login(username="test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).query) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token = content["access_token"] + id_token = content["id_token"] + + # use token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + access_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + # use id_token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + id_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + def test_resource_access_deny(self): auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "faketoken", diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py new file mode 100644 index 000000000..da3d0b5e6 --- /dev/null +++ b/tests/test_hybrid.py @@ -0,0 +1,1263 @@ +import base64 +import datetime +import json + +from urllib.parse import parse_qs, urlencode, urlparse + +from django.contrib.auth import get_user_model +from django.test import RequestFactory, TestCase +from django.urls import reverse +from django.utils import timezone +from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors + +from oauth2_provider.models import ( + get_access_token_model, get_application_model, + get_grant_model, get_refresh_token_model +) +from oauth2_provider.settings import oauth2_settings +from oauth2_provider.views import ProtectedResourceView + +from .utils import get_basic_auth_header + + +Application = get_application_model() +AccessToken = get_access_token_model() +Grant = get_grant_model() +RefreshToken = get_refresh_token_model() +UserModel = get_user_model() + + +# mocking a protected resource view +class ResourceView(ProtectedResourceView): + def get(self, request, *args, **kwargs): + return "This is a protected resource" + + +class BaseTest(TestCase): + def setUp(self): + self.factory = RequestFactory() + self.hy_test_user = UserModel.objects.create_user("hy_test_user", "test_hy@example.com", "123456") + self.hy_dev_user = UserModel.objects.create_user("hy_dev_user", "dev_hy@example.com", "123456") + + oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] + + self.application = Application( + name="Hybrid Test Application", + redirect_uris=( + "http://localhost http://example.com http://example.org custom-scheme://example.com" + ), + user=self.hy_dev_user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_OPENID_HYBRID, + ) + self.application.save() + + oauth2_settings._SCOPES = ["read", "write", "openid"] + oauth2_settings._DEFAULT_SCOPES = ["read", "write"] + oauth2_settings.SCOPES = { + "read": "Reading scope", + "write": "Writing scope", + "openid": "OpenID connect" + } + + def tearDown(self): + self.application.delete() + self.hy_test_user.delete() + self.hy_dev_user.delete() + + +class TestRegressionIssue315Hybrid(BaseTest): + """ + Test to avoid regression for the issue 315: request object + was being reassigned when getting AuthorizationView + """ + + def test_request_is_not_overwritten_code_token(self): + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + assert "request" not in response.context_data + + def test_request_is_not_overwritten_code_id_token(self): + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "nonce": "nonce", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + assert "request" not in response.context_data + + def test_request_is_not_overwritten_code_id_token_token(self): + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "nonce": "nonce", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + assert "request" not in response.context_data + + +class TestHybridView(BaseTest): + def test_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="hy_test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="hy_test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_pre_auth_invalid_client(self): + """ + Test error for an invalid client_id with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": "fakeclientid", + "response_type": "code", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.context_data["url"], + "?error=invalid_request&error_description=Invalid+client_id+parameter+value." + ) + + def test_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "read write") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_id_token_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "nonce": "nonce", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "openid") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): + """ + Test response for a valid client_id with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "custom-scheme://example.com", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "custom-scheme://example.com") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "read write") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_pre_auth_approval_prompt(self): + tok = AccessToken.objects.create( + user=self.hy_test_user, token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write" + ) + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "approval_prompt": "auto", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + # user already authorized the application, but with different scopes: prompt them. + tok.scope = "read" + tok.save() + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + def test_pre_auth_approval_prompt_default(self): + oauth2_settings.REQUEST_APPROVAL_PROMPT = "force" + self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") + + AccessToken.objects.create( + user=self.hy_test_user, token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write" + ) + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + def test_pre_auth_approval_prompt_default_override(self): + oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" + + AccessToken.objects.create( + user=self.hy_test_user, token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write" + ) + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_pre_auth_default_redirect(self): + """ + Test for default redirect uri if omitted from query string with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://localhost") + + def test_pre_auth_forbibben_redirect(self): + """ + Test error when passing a forbidden redirect_uri in query string with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "redirect_uri": "http://forbidden.it", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 400) + + def test_pre_auth_wrong_response_type(self): + """ + Test error when passing a wrong response_type in query string + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "WRONG", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("error=unsupported_response_type", response["Location"]) + + def test_code_post_auth_allow_code_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_allow_code_id_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_allow_code_id_token_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code id_token token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_id_token_code_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_deny(self): + """ + Test error when resource owner deny access + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=access_denied", response["Location"]) + + def test_code_post_auth_bad_responsetype(self): + """ + Test authorization code is given for an allowed request with a response_type not supported + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "UNKNOWN", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org?error", response["Location"]) + + def test_code_post_auth_forbidden_redirect_uri(self): + """ + Test authorization code is given for an allowed request with a forbidden redirect_uri + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://forbidden.it", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 400) + + def test_code_post_auth_malicious_redirect_uri(self): + """ + Test validation of a malicious redirect_uri + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "/../", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 400) + + def test_code_post_auth_allow_custom_redirect_uri_scheme_code_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code id_token token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_deny_custom_redirect_uri_scheme(self): + """ + Test error when resource owner deny access + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com?", response["Location"]) + self.assertIn("error=access_denied", response["Location"]) + + def test_code_post_auth_redirection_uri_with_querystring_code_token(self): + """ + Tests that a redirection uri with query string is allowed + and query string is retained on redirection. + See http://tools.ietf.org/html/rfc6749#section-3.1.2 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.com?foo=bar", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_redirection_uri_with_querystring_code_id_token(self): + """ + Tests that a redirection uri with query string is allowed + and query string is retained on redirection. + See http://tools.ietf.org/html/rfc6749#section-3.1.2 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.com?foo=bar", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_redirection_uri_with_querystring_code_id_token_token(self): + """ + Tests that a redirection uri with query string is allowed + and query string is retained on redirection. + See http://tools.ietf.org/html/rfc6749#section-3.1.2 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code id_token token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.com?foo=bar", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_failing_redirection_uri_with_querystring(self): + """ + Test that in case of error the querystring of the redirection uri is preserved + + See https://github.com/evonove/django-oauth-toolkit/issues/238 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertEqual("http://example.com?foo=bar&error=access_denied&state=random_state_string", response["Location"]) + + def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): + """ + Tests that a redirection uri is matched using scheme + netloc + path + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.com/a?foo=bar", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 400) + + +class TestHybridTokenView(BaseTest): + def get_auth(self, scope="read write"): + """ + Helper method to retrieve a valid authorization code + """ + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": scope, + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + return fragment_dict["code"].pop() + + def test_basic_auth(self): + """ + Request an access token using basic authentication for client authentication + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_basic_auth_bad_authcode(self): + """ + Request an access token using a bad authorization code + """ + self.client.login(username="hy_test_user", password="123456") + + token_request_data = { + "grant_type": "authorization_code", + "code": "BLAH", + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + + def test_basic_auth_bad_granttype(self): + """ + Request an access token using a bad grant_type string + """ + self.client.login(username="hy_test_user", password="123456") + + token_request_data = { + "grant_type": "UNKNOWN", + "code": "BLAH", + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + + def test_basic_auth_grant_expired(self): + """ + Request an access token using an expired grant token + """ + self.client.login(username="hy_test_user", password="123456") + g = Grant( + application=self.application, user=self.hy_test_user, code="BLAH", + expires=timezone.now(), redirect_uri="", scope="") + g.save() + + token_request_data = { + "grant_type": "authorization_code", + "code": "BLAH", + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + + def test_basic_auth_bad_secret(self): + """ + Request an access token using basic authentication for client authentication + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 401) + + def test_basic_auth_wrong_auth_type(self): + """ + Request an access token using basic authentication for client authentication + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org" + } + + user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) + auth_string = base64.b64encode(user_pass.encode("utf-8")) + auth_headers = { + "HTTP_AUTHORIZATION": "Wrong " + auth_string.decode("utf-8"), + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 401) + + def test_request_body_params(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "client_secret": self.application.client_secret, + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="hy_test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_id_token_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="hy_test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_malicious_redirect_uri(self): + """ + Request an access token using client_type: public and ensure redirect_uri is + properly validated. + """ + self.client.login(username="hy_test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "/../", + "client_id": self.application.client_id + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 400) + data = response.json() + self.assertEqual(data["error"], "invalid_request") + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + + def test_code_exchange_succeed_when_redirect_uri_match(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org?foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org?foo=bar" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_code_exchange_fails_when_redirect_uri_does_not_match(self): + """ + Tests code exchange fails when redirect uri does not match the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org?foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org?foo=baraa" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + data = response.json() + self.assertEqual(data["error"], "invalid_request") + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + + def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + +class TestHybridProtectedResource(BaseTest): + def test_resource_access_allowed(self): + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token = content["access_token"] + + # use token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + access_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + def test_id_token_resource_access_allowed(self): + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token = content["access_token"] + id_token = content["id_token"] + + # use token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + access_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + # use id_token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + id_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + def test_resource_access_deny(self): + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + "faketoken", + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response.status_code, 403) + + +class TestDefaultScopesHybrid(BaseTest): + + def test_pre_auth_default_scopes(self): + """ + Test response for a valid client_id with response_type: code using default scopes + """ + self.client.login(username="hy_test_user", password="123456") + oauth2_settings._DEFAULT_SCOPES = ["read"] + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code token", + "state": "random_state_string", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "read") + self.assertEqual(form["client_id"].value(), self.application.client_id) + oauth2_settings._DEFAULT_SCOPES = ["read", "write"] diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 548592377..e32080122 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,9 +1,13 @@ from urllib.parse import parse_qs, urlencode, urlparse +import json + from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse +from jwcrypto import jwk, jwt + from oauth2_provider.models import get_application_model from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView @@ -34,8 +38,14 @@ def setUp(self): ) self.application.save() - oauth2_settings._SCOPES = ["read", "write"] + oauth2_settings._SCOPES = ["read", "write", "openid"] oauth2_settings._DEFAULT_SCOPES = ["read"] + oauth2_settings.SCOPES = { + "read": "Reading scope", + "write": "Writing scope", + "openid": "OpenID connect" + } + self.key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) def tearDown(self): self.application.delete() @@ -273,3 +283,197 @@ def test_resource_access_allowed(self): view = ResourceView.as_view() response = view(request) self.assertEqual(response, "This is a protected resource") + + +class TestOpenIDConnectImplicitFlow(BaseTest): + def test_id_token_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: id_token + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertNotIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertNotIn("at_hash", claims) + + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "id_token", + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertNotIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertNotIn("at_hash", claims) + + def test_id_token_skip_authorization_completely_missing_nonce(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "id_token", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("error=invalid_request", response["Location"]) + self.assertIn("error_description=Request+is+missing+mandatory+nonce+paramete", response["Location"]) + + def test_id_token_post_auth_deny(self): + """ + Test error when resource owner deny access + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=access_denied", response["Location"]) + + def test_access_token_and_id_token_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: token + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertIn("at_hash", claims) + + def test_access_token_and_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "id_token token", + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertIn("at_hash", claims) + + def test_access_token_and_id_token_post_auth_deny(self): + """ + Test error when resource owner deny access + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token token", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=access_denied", response["Location"]) diff --git a/tests/test_models.py b/tests/test_models.py index 8adc3b099..ec8e2f9ef 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,16 +9,18 @@ from oauth2_provider.models import ( clear_expired, get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model + get_grant_model, get_refresh_token_model, get_id_token_model ) from oauth2_provider.settings import oauth2_settings +from .models import SampleRefreshToken Application = get_application_model() Grant = get_grant_model() AccessToken = get_access_token_model() RefreshToken = get_refresh_token_model() UserModel = get_user_model() +IDToken = get_id_token_model() class TestModels(TestCase): @@ -301,29 +303,77 @@ class TestClearExpired(TestCase): def setUp(self): self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") + app1 = Application.objects.create( + name="Test Application", + redirect_uris=( + "http://localhost http://example.com http://example.org custom-scheme://example.com" + ), + user=self.user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + ) + app2 = Application.objects.create( + name="Test Application", + redirect_uris=( + "http://localhost http://example.com http://example.org custom-scheme://example.com" + ), + user=self.user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + ) + id1 = IDToken.objects.create( + token="666", + expires=dt.now(), + scope=2, + application=app1, + user=self.user, + created=dt.now(), + updated=dt.now(), + ) + id2 = IDToken.objects.create( + token="999", + expires=dt.now(), + scope=2, + application=app2, + user=self.user, + created=dt.now(), + updated=dt.now(), + ) + refresh_token1 = SampleRefreshToken.objects.create( + token="test_token", + application=app1, + user=self.user, + ) + refresh_token2 = SampleRefreshToken.objects.create( + token="test_token2", + application=app2, + user=self.user, + ) # Insert two tokens on database. AccessToken.objects.create( id=1, token="555", expires=dt.now(), scope=2, - application_id=3, - user_id=1, + application=app1, + id_token=id1, + user=self.user, created=dt.now(), updated=dt.now(), - source_refresh_token_id="0", - ) + refresh_token=refresh_token1, + ) AccessToken.objects.create( id=2, token="666", expires=dt.now(), scope=2, - application_id=3, - user_id=1, + application=app2, + user=self.user, + id_token=id2, created=dt.now(), updated=dt.now(), - source_refresh_token_id="1", - ) + refresh_token=refresh_token2, + ) def test_clear_expired_tokens(self): oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 60 diff --git a/tests/test_oauth2_backends.py b/tests/test_oauth2_backends.py index d844da5f4..2381e9cdc 100644 --- a/tests/test_oauth2_backends.py +++ b/tests/test_oauth2_backends.py @@ -65,7 +65,7 @@ def test_create_token_response_gets_extra_credentials(self): payload = "grant_type=password&username=john&password=123456" request = self.factory.post("/o/token/", payload, content_type="application/x-www-form-urlencoded") - with mock.patch("oauthlib.oauth2.Server.create_token_response") as create_token_response: + with mock.patch("oauthlib.openid.connect.core.endpoints.pre_configured.Server.create_token_response") as create_token_response: mocked = mock.MagicMock() create_token_response.return_value = mocked, mocked, mocked core = self.MyOAuthLibCore() diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py new file mode 100644 index 000000000..2105b4fd0 --- /dev/null +++ b/tests/test_oidc_views.py @@ -0,0 +1,47 @@ +from __future__ import unicode_literals + +from django.test import TestCase +from django.urls import reverse + + +class TestConnectDiscoveryInfoView(TestCase): + def test_get_connect_discovery_info(self): + expected_response = { + "issuer": "http://localhost", + "authorization_endpoint": "http://localhost/o/authorize/", + "token_endpoint": "http://localhost/o/token/", + "userinfo_endpoint": "http://localhost/userinfo/", + "jwks_uri": "http://localhost/o/jwks/", + "response_types_supported": [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token" + ], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256", "HS256"], + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"] + } + response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == expected_response + + +class TestJwksInfoView(TestCase): + def test_get_jwks_info(self): + expected_response = { + "keys": [{ + "alg": "RS256", + "use": "sig", + "kid": "s4a1o8mFEd1tATAIH96caMlu4hOxzBUaI2QTqbYNBHs", + "e": "AQAB", + "kty": "RSA", + "n": "mwmIeYdjZkLgalTuhvvwjvnB5vVQc7G9DHgOm20Hw524bLVTk49IXJ2Scw42HOmowWWX-oMVT_ca3ZvVIeffVSN1-TxVy2zB65s0wDMwhiMoPv35z9IKHGMZgl9vlyso_2b7daVF_FQDdgIayUn8TQylBxEU1RFfW0QSYOBdAt8" + }] + } + response = self.client.get(reverse("oauth2_provider:jwks-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == expected_response diff --git a/tox.ini b/tox.ini index a492aeaf0..f1156d581 100644 --- a/tox.ini +++ b/tox.ini @@ -11,7 +11,8 @@ envlist = django_find_project = false [testenv] -commands = pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} +commands = + pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} -s setenv = DJANGO_SETTINGS_MODULE = tests.settings PYTHONPATH = {toxinidir} @@ -28,6 +29,7 @@ deps = pytest-django pytest-xdist py27: mock + jwcrypto [testenv:py36-docs] basepython = python