From c426082f3f0ac0574ab9a041d7d7b03d537c5d1f Mon Sep 17 00:00:00 2001 From: Tomas Pazderka Date: Sat, 2 Mar 2019 21:13:14 +0100 Subject: [PATCH 1/2] Synced implementation of token_endpoint All three providers (oauth2, oic and extension) now share common code. --- CHANGELOG.md | 2 + src/oic/extension/provider.py | 31 ---- src/oic/oauth2/provider.py | 117 +++++++++++---- src/oic/oic/provider.py | 119 +++++---------- tests/test_oauth2_provider.py | 264 ++++++++++++++++++++++++++++++++- tests/test_oic_provider.py | 272 ++++++++++++---------------------- 6 files changed, 478 insertions(+), 327 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 46e417ba1..a4723ceb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on the [KeepAChangeLog] project. - [#605] Message.c_param dictionary values have to be a ParamDefinition namedtuple type - [#56] Updated README, CLI help texts, pip requirements.txt and such for OP2, making it into a stand-alone example easy for beginners to take on +- [#624] token_endpoint implementation and kwargs have been changed ### Added - [#441] CookieDealer now accepts secure and httponly params @@ -36,6 +37,7 @@ The format is based on the [KeepAChangeLog] project. [#612]: https://github.com/OpenIDC/pyoidc/pull/612 [#618]: https://github.com/OpenIDC/pyoidc/pull/618 [#56]: https://github.com/OpenIDC/pyoidc/issues/56 +[#624]: https://github.com/OpenIDC/pyoidc/pull/624 ## 0.15.1 [2019-01-31] diff --git a/src/oic/extension/provider.py b/src/oic/extension/provider.py index 972a384b3..26113e793 100644 --- a/src/oic/extension/provider.py +++ b/src/oic/extension/provider.py @@ -25,7 +25,6 @@ from oic.extension.message import TokenIntrospectionRequest from oic.extension.message import TokenIntrospectionResponse from oic.extension.message import TokenRevocationRequest -from oic.oauth2 import AccessTokenRequest from oic.oauth2 import AccessTokenResponse from oic.oauth2 import TokenErrorResponse from oic.oauth2 import compact @@ -662,36 +661,6 @@ def refresh_token_grant_type(self, areq): atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **at)) return Response(atr.to_json(), content="application/json") - def token_endpoint(self, authn="", **kwargs): - """Provide clients their access tokens.""" - logger.debug("- token -") - body = kwargs["request"] - logger.debug("body: %s" % body) - - areq = AccessTokenRequest().deserialize(body, "urlencoded") - - try: - self.client_authn(self, areq, authn) - except FailedAuthentication as err: - logger.error(err) - err = TokenErrorResponse(error="unauthorized_client", - error_description="%s" % err) - return Response(err.to_json(), content="application/json", status_code=401) - - logger.debug("AccessTokenRequest: %s" % areq) - - _grant_type = areq["grant_type"] - if _grant_type == "authorization_code": - return self.code_grant_type(areq) - elif _grant_type == 'client_credentials': - return self.client_credentials_grant_type(areq) - elif _grant_type == 'password': - return self.password_grant_type(areq) - elif _grant_type == 'refresh_token': - return self.refresh_token_grant_type(areq) - else: - raise UnSupported('grant_type: {}'.format(_grant_type)) - @staticmethod def token_access(endpoint, client_id, token_info): # simple rules: if client_id in azp or aud it's allow to introspect diff --git a/src/oic/oauth2/provider.py b/src/oic/oauth2/provider.py index bb4c941ab..37dfcd208 100644 --- a/src/oic/oauth2/provider.py +++ b/src/oic/oauth2/provider.py @@ -36,6 +36,7 @@ from oic.oauth2.message import TokenErrorResponse from oic.oauth2.message import add_non_standard from oic.oauth2.message import by_schema +from oic.utils.authn.client import AuthnFailure from oic.utils.authn.user import NoSuchAuthentication from oic.utils.authn.user import TamperAllert from oic.utils.authn.user import ToOld @@ -153,6 +154,8 @@ def re_authenticate(areq, authn): class Provider(object): endp = [AuthorizationEndpoint, TokenEndpoint] + # Define the message class that in token_enpdoint + atr_class = AccessTokenRequest def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn, symkey=None, urlmap=None, iv=0, default_scope="", @@ -768,51 +771,77 @@ def token_scope_check(self, areq, info): """Not implemented here.""" return None - def token_endpoint(self, authn="", **kwargs): - """Provide clients with access tokens.""" - _sdb = self.sdb + def token_endpoint(self, request='', authn='', dtype='urlencoded', **kwargs): + """ + Provide clients with access tokens. + :param authn: Auhentication info, comes from HTTP header. + :param request: The request. + :param dtype: deserialization method for the request. + """ logger.debug("- token -") - body = kwargs["request"] - logger.debug("body: %s" % sanitize(body)) + logger.debug("token_request: %s" % sanitize(request)) - areq = AccessTokenRequest().deserialize(body, "urlencoded") + areq = self.atr_class().deserialize(request, dtype) + # Verify client authentication try: - self.client_authn(self, areq, authn) - except FailedAuthentication as err: + client_id = self.client_authn(self, areq, authn) + except (FailedAuthentication, AuthnFailure) as err: logger.error(err) - err = TokenErrorResponse(error="unauthorized_client", - error_description="%s" % err) - return Response(err.to_json(), content="application/json", status_code=401) + err = TokenErrorResponse(error="unauthorized_client", error_description="%s" % err) + return Unauthorized(err.to_json(), content="application/json") logger.debug("AccessTokenRequest: %s" % sanitize(areq)) - if areq["grant_type"] != "authorization_code": - error = TokenErrorResponse(error="invalid_request", error_description="Wrong grant type") - return Response(error.to_json(), content="application/json", status="401 Unauthorized") - - # assert that the code is valid - _info = _sdb[areq["code"]] - - resp = self.token_scope_check(areq, _info) - if resp: - return resp + # `code` is not mandatory for all requests + if 'code' in areq: + try: + _info = self.sdb[areq["code"]] + except KeyError: + logger.error('Code not present in SessionDB') + error = TokenErrorResponse(error="unauthorized_client") + return Unauthorized(error.to_json(), content="application/json") + + resp = self.token_scope_check(areq, _info) + if resp: + return resp + # If redirect_uri was in the initial authorization request verify that they match + if "redirect_uri" in _info and areq["redirect_uri"] != _info["redirect_uri"]: + logger.error('Redirect_uri mismatch') + error = TokenErrorResponse(error="unauthorized_client") + return Unauthorized(error.to_json(), content="application/json") + if 'state' in areq: + if _info['state'] != areq['state']: + logger.error('State value mismatch') + error = TokenErrorResponse(error="unauthorized_client") + return Unauthorized(error.to_json(), content="application/json") + + # Propagate the client_id further + areq.setdefault('client_id', client_id) + grant_type = areq["grant_type"] + if grant_type == "authorization_code": + return self.code_grant_type(areq) + elif grant_type == "refresh_token": + return self.refresh_token_grant_type(areq) + elif grant_type == 'client_credentials': + return self.client_credentials_grant_type(areq) + elif grant_type == 'password': + return self.password_grant_type(areq) + else: + raise UnSupported('grant_type: {}'.format(grant_type)) - # If redirect_uri was in the initial authorization request - # verify that the one given here is the correct one. - if "redirect_uri" in _info and areq["redirect_uri"] != _info["redirect_uri"]: - logger.error('Redirect_uri mismatch') - error = TokenErrorResponse(error="unauthorized_client") - return Unauthorized(error.to_json(), content="application/json") + def code_grant_type(self, areq): + """ + Token authorization using Code Grant. + RFC6749 section 4.1 + """ try: - _tinfo = _sdb.upgrade_to_token(areq["code"], issue_refresh=True) + _tinfo = self.sdb.upgrade_to_token(areq["code"], issue_refresh=True) except AccessCodeUsed: - error = TokenErrorResponse(error="invalid_grant", - error_description="Access grant used") - return Response(error.to_json(), content="application/json", - status="401 Unauthorized") + error = TokenErrorResponse(error="invalid_grant", error_description="Access grant used") + return Unauthorized(error.to_json(), content="application/json") logger.debug("_tinfo: %s" % sanitize(_tinfo)) @@ -822,6 +851,30 @@ def token_endpoint(self, authn="", **kwargs): return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) + def refresh_token_grant_type(self, areq): + """ + Token refresh. + + RFC6749 section 6 + """ + raise NotImplementedError('See oic.extension.provider.') + + def client_credentials_grant_type(self, areq): + """ + Token authorization using client credentials. + + RFC6749 section 4.4 + """ + raise NotImplementedError('See oic.extension.provider.') + + def password_grant_type(self, areq): + """ + Token authorization using Resource owner password credentials. + + RFC6749 section 4.3 + """ + raise NotImplementedError('See oic.extension.provider.') + def verify_endpoint(self, request="", cookie=None, **kwargs): _req = parse_qs(request) try: diff --git a/src/oic/oic/provider.py b/src/oic/oic/provider.py index d045a5dff..aa83aaaba 100644 --- a/src/oic/oic/provider.py +++ b/src/oic/oic/provider.py @@ -67,10 +67,8 @@ from oic.oic.message import OpenIDRequest from oic.oic.message import OpenIDSchema from oic.oic.message import ProviderConfigurationResponse -from oic.oic.message import RefreshAccessTokenRequest from oic.oic.message import RegistrationRequest from oic.oic.message import RegistrationResponse -from oic.oic.message import TokenErrorResponse from oic.utils import sort_sign_alg from oic.utils.http_util import OAUTH2_NOCACHE_HEADERS from oic.utils.http_util import BadRequest @@ -86,6 +84,7 @@ from oic.utils.sdb import AccessCodeUsed from oic.utils.sdb import AuthnEvent from oic.utils.sdb import ExpiredToken +from oic.utils.sdb import WrongTokenType from oic.utils.template_render import render_template from oic.utils.time_util import utc_time_sans_frac @@ -206,6 +205,8 @@ class EndSessionEndpoint(Endpoint): class Provider(AProvider): + atr_class = AccessTokenRequest + def __init__(self, name, sdb, cdb, authn_broker, userinfo, authz, client_authn, symkey=None, urlmap=None, keyjar=None, hostname="", template_lookup=None, template=None, @@ -961,17 +962,19 @@ def sign_encrypt_id_token(self, sinfo, client_info, areq, code=None, return id_token - def _access_token_endpoint(self, req, **kwargs): + def code_grant_type(self, areq): + """ + Token authorization using Code Grant. + RFC6749 section 4.1 + """ _sdb = self.sdb _log_debug = logger.debug - client_info = self.cdb[str(req["client_id"])] - - assert req["grant_type"] == "authorization_code" + client_info = self.cdb[str(areq["client_id"])] try: - _access_code = req["code"].replace(' ', '+') + _access_code = areq["code"].replace(' ', '+') except KeyError: # Missing code parameter - absolutely fatal return error_response('invalid_request', descr='Missing code') @@ -985,27 +988,20 @@ def _access_token_endpoint(self, req, **kwargs): except KeyError: return error_response("invalid_request", descr="Code is invalid") - # If redirect_uri was in the initial authorization request - # verify that the one given here is the correct one. - if "redirect_uri" in _info: - if 'redirect_uri' not in req: - return error_response('invalid_request', descr='Missing redirect_uri') - if req["redirect_uri"] != _info["redirect_uri"]: - return error_response("invalid_request", descr="redirect_uri mismatch") + # If redirect_uri was in the initial authorization request verify that it is here as well + # Mismatch would raise in oic.oauth2.provider.Provider.token_endpoint + if "redirect_uri" in _info and 'redirect_uri' not in areq: + return error_response('invalid_request', descr='Missing redirect_uri') _log_debug("All checks OK") issue_refresh = False - if "issue_refresh" in kwargs: - issue_refresh = kwargs["issue_refresh"] - permissions = _info.get('permission', ['offline_access']) or ['offline_access'] if 'offline_access' in _info['scope'] and 'offline_access' in permissions: issue_refresh = True try: - _tinfo = _sdb.upgrade_to_token(_access_code, - issue_refresh=issue_refresh) + _tinfo = _sdb.upgrade_to_token(_access_code, issue_refresh=issue_refresh) except AccessCodeUsed as err: logger.error("%s" % err) # Should revoke the token issued to this access code @@ -1015,8 +1011,7 @@ def _access_token_endpoint(self, req, **kwargs): if "openid" in _info["scope"]: userinfo = self.userinfo_in_id_token_claims(_info) try: - _idtoken = self.sign_encrypt_id_token( - _info, client_info, req, user_info=userinfo) + _idtoken = self.sign_encrypt_id_token(_info, client_info, areq, user_info=userinfo) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) return error_response("invalid_request", descr="Could not sign/encrypt id_token") @@ -1034,25 +1029,30 @@ def _access_token_endpoint(self, req, **kwargs): return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) - def _refresh_access_token_endpoint(self, req, **kwargs): + def refresh_token_grant_type(self, areq): + """ + Token refresh. + + RFC6749 section 6 + """ _sdb = self.sdb _log_debug = logger.debug - client_id = str(req['client_id']) + client_id = str(areq['client_id']) client_info = self.cdb[client_id] - assert req["grant_type"] == "refresh_token" - rtoken = req["refresh_token"] + rtoken = areq["refresh_token"] try: _info = _sdb.refresh_token(rtoken, client_id=client_id) except ExpiredToken: return error_response("invalid_request", descr="Refresh token is expired") + except WrongTokenType: + return error_response("invalid_request", descr="Not a refresh token") if "openid" in _info["scope"] and "authn_event" in _info: userinfo = self.userinfo_in_id_token_claims(_info) try: - _idtoken = self.sign_encrypt_id_token( - _info, client_info, req, user_info=userinfo) + _idtoken = self.sign_encrypt_id_token(_info, client_info, areq, user_info=userinfo) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) return error_response("invalid_request", descr="Could not sign/encrypt id_token") @@ -1068,64 +1068,23 @@ def _refresh_access_token_endpoint(self, req, **kwargs): return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) - def token_endpoint(self, request="", authn=None, dtype='urlencoded', - **kwargs): + def client_credentials_grant_type(self, areq): """ - Give clients their access tokens. + Token authorization using client credentials. - :param request: The request - :param authn: Authentication info, comes from HTTP header - :returns: + RFC6749 section 4.4 """ - logger.debug("- token -") - logger.info("token_request: %s" % sanitize(request)) - - req = AccessTokenRequest().deserialize(request, dtype) - - if 'state' in req: - try: - state = self.sdb[req['code']]['state'] - except KeyError: - logger.error('Code not present in SessionDB') - err = TokenErrorResponse(error="unauthorized_client") - return Unauthorized(err.to_json(), content="application/json") - - if state != req['state']: - logger.error('State value mismatch') - err = TokenErrorResponse(error="unauthorized_client") - return Unauthorized(err.to_json(), content="application/json") + # Not supported in OpenID Connect + return error_response('invalid_request', descr='Unsupported grant_type') - if "refresh_token" in req: - req = RefreshAccessTokenRequest().deserialize(request, dtype) - - logger.debug("%s: %s" % (req.__class__.__name__, sanitize(req))) - - try: - client_id = self.client_authn(self, req, authn) - msg = '' - except Exception as err: - msg = "Failed to verify client due to: {}".format(err) - logger.error(msg) - client_id = "" - - if not client_id: - logger.error('No client_id, authentication failed') - error = TokenErrorResponse(error="unauthorized_client", - error_description=msg) - return Unauthorized(error.to_json(), content="application/json") - - if "client_id" not in req: # Optional for access token request - req["client_id"] = client_id - - if isinstance(req, AccessTokenRequest): - try: - return self._access_token_endpoint(req, **kwargs) - except JWEException as err: - return error_response("invalid_request", - descr="%s" % err) + def password_grant_type(self, areq): + """ + Token authorization using Resource owner password credentials. - else: - return self._refresh_access_token_endpoint(req, **kwargs) + RFC6749 section 4.3 + """ + # Not supported in OpenID Connect + return error_response('invalid_request', descr='Unsupported grant_type') def _collect_user_info(self, session, userinfo_claims=None): """ diff --git a/tests/test_oauth2_provider.py b/tests/test_oauth2_provider.py index 9cb959e63..a31167f31 100644 --- a/tests/test_oauth2_provider.py +++ b/tests/test_oauth2_provider.py @@ -7,11 +7,15 @@ import pytest from testfixtures import LogCapture +from oic.exception import UnSupported from oic.oauth2.consumer import Consumer from oic.oauth2.message import AccessTokenRequest from oic.oauth2.message import AccessTokenResponse from oic.oauth2.message import AuthorizationRequest from oic.oauth2.message import AuthorizationResponse +from oic.oauth2.message import CCAccessTokenRequest +from oic.oauth2.message import Message +from oic.oauth2.message import ROPCAccessTokenRequest from oic.oauth2.message import TokenErrorResponse from oic.oauth2.provider import Provider from oic.utils.authn.authn_context import AuthnBroker @@ -19,6 +23,7 @@ from oic.utils.authn.user import UserAuthnMethod from oic.utils.authz import Implicit from oic.utils.http_util import Response +from oic.utils.sdb import AuthnEvent CLIENT_CONFIG = { "client_id": "client1", @@ -53,6 +58,12 @@ "redirect_uris": [("http://localhost:8087/authz", None)], 'token_endpoint_auth_method': 'client_secret_post', 'response_types': ['code', 'token'] + }, + "client2": { + "client_secret": "verysecret", + "redirect_uris": [("http://localhost:8087/authz", None)], + 'token_endpoint_auth_method': 'client_secret_basic', + 'response_types': ['code', 'token'] } } @@ -254,11 +265,9 @@ def test_token_endpoint(self): assert _eq(atr.keys(), ['access_token', 'token_type', 'refresh_token']) expected = ( - 'body: code=&client_secret=&grant_type' - '=authorization_code' - ' &client_id=client1&redirect_uri=http%3A%2F%2Fexample.com' - '%2Fauthz') - assert _eq(parse_qs(logcap.records[1].msg[6:]), parse_qs(expected[6:])) + 'token_request: code=&client_secret=&grant_type=authorization_code' + '&client_id=client1&redirect_uri=http%3A%2F%2Fexample.com%2Fauthz') + assert _eq(parse_qs(logcap.records[1].msg[15:]), parse_qs(expected[15:])) expected = {u'code': '', u'client_secret': '', u'redirect_uri': u'http://example.com/authz', u'client_id': 'client1', @@ -342,6 +351,251 @@ def test_token_endpoint_unauth(self): atr = TokenErrorResponse().deserialize(resp.message, "json") assert _eq(atr.keys(), ['error_description', 'error']) + def test_token_endpoint_malformed_code(self): + authreq = AuthorizationRequest(state="state", + redirect_uri="http://example.com/authz", + client_id='client1', + response_type="code", + scope=["openid"]) + + _sdb = self.provider.sdb + sid = _sdb.access_token.key(user="sub", areq=authreq) + access_grant = _sdb.access_token(sid=sid) + _sdb[sid] = { + "oauth_state": "authz", + "authn_event": '', + "authzreq": '', + "client_id": 'client1', + "code": access_grant, + "code_used": False, + "scope": ["openid"], + "redirect_uri": "http://example.com/authz", + } + + # Construct Access token request + areq = AccessTokenRequest(code=access_grant[0:len(access_grant) - 1], + client_id='client1', + redirect_uri="http://example.com/authz", + client_secret='hemlighet', + grant_type='authorization_code') + + txt = areq.to_urlencoded() + + resp = self.provider.token_endpoint(request=txt) + atr = TokenErrorResponse().deserialize(resp.message, "json") + assert atr['error'] == "unauthorized_client" + + def test_token_endpoint_bad_redirect_uri(self): + authreq = AuthorizationRequest(state="state", + redirect_uri="http://example.com/authz", + client_id='client1', + response_type="code", + scope=["openid"]) + + _sdb = self.provider.sdb + sid = _sdb.access_token.key(user="sub", areq=authreq) + access_grant = _sdb.access_token(sid=sid) + _sdb[sid] = { + "oauth_state": "authz", + "authn_event": '', + "authzreq": '', + "client_id": 'client1', + "code": access_grant, + "code_used": False, + "scope": ["openid"], + "redirect_uri": "http://example.com/authz", + } + + # Construct Access token request + areq = AccessTokenRequest(code=access_grant, + client_id='client1', + redirect_uri="http://example.com/authz2", + client_secret='hemlighet', + grant_type='authorization_code') + + txt = areq.to_urlencoded() + + resp = self.provider.token_endpoint(request=txt) + atr = TokenErrorResponse().deserialize(resp.message, "json") + assert atr['error'] == "unauthorized_client" + + def test_token_endpoint_ok_state(self): + authreq = AuthorizationRequest(state="state", + redirect_uri="http://example.com/authz", + client_id='client1', + response_type="code", + scope=["openid"]) + + _sdb = self.provider.sdb + sid = _sdb.access_token.key(user="sub", areq=authreq) + access_grant = _sdb.access_token(sid=sid) + ae = AuthnEvent("user", "salt") + _sdb[sid] = { + "oauth_state": "authz", + "authn_event": ae.to_json(), + "authzreq": '', + "client_id": 'client1', + "code": access_grant, + 'state': 'state', + "code_used": False, + "scope": ["openid"], + "redirect_uri": "http://example.com/authz", + } + _sdb.do_sub(sid, "client_salt") + + # Construct Access token request + areq = AccessTokenRequest(code=access_grant, + client_id='client1', + redirect_uri="http://example.com/authz", + client_secret='hemlighet', + grant_type='authorization_code', + state='state') + + txt = areq.to_urlencoded() + + resp = self.provider.token_endpoint(request=txt) + atr = AccessTokenResponse().deserialize(resp.message, "json") + assert atr['token_type'] == "Bearer" + + def test_token_endpoint_bad_state(self): + authreq = AuthorizationRequest(state="state", + redirect_uri="http://example.com/authz", + client_id='client1', + response_type="code", + scope=["openid"]) + + _sdb = self.provider.sdb + sid = _sdb.access_token.key(user="sub", areq=authreq) + access_grant = _sdb.access_token(sid=sid) + _sdb[sid] = { + "oauth_state": "authz", + "authn_event": '', + "authzreq": '', + "client_id": 'client1', + "code": access_grant, + 'state': 'state', + "code_used": False, + "scope": ["openid"], + "redirect_uri": "http://example.com/authz", + } + + # Construct Access token request + areq = AccessTokenRequest(code=access_grant, + client_id='client1', + redirect_uri="http://example.com/authz", + client_secret='hemlighet', + grant_type='authorization_code', + state='other_state') + + txt = areq.to_urlencoded() + + resp = self.provider.token_endpoint(request=txt) + atr = TokenErrorResponse().deserialize(resp.message, "json") + assert atr['error'] == "unauthorized_client" + + def test_token_endpoint_client_credentials(self): + authreq = AuthorizationRequest(state="state", + redirect_uri="http://example.com/authz", + client_id="client1") + + _sdb = self.provider.sdb + sid = _sdb.access_token.key(user="sub", areq=authreq) + access_grant = _sdb.access_token(sid=sid) + _sdb[sid] = { + "oauth_state": "authz", + "sub": "sub", + "authzreq": "", + "client_id": "client1", + "code": access_grant, + "code_used": False, + "redirect_uri": "http://example.com/authz", + 'token_endpoint_auth_method': 'client_secret_basic', + } + areq = CCAccessTokenRequest(grant_type='client_credentials') + authn = 'Basic Y2xpZW50Mjp2ZXJ5c2VjcmV0=' + with pytest.raises(NotImplementedError): + self.provider.token_endpoint(request=areq.to_urlencoded(), authn=authn) + + def test_token_endpoint_password(self): + authreq = AuthorizationRequest(state="state", + redirect_uri="http://example.com/authz", + client_id="client1") + + _sdb = self.provider.sdb + sid = _sdb.access_token.key(user="sub", areq=authreq) + access_grant = _sdb.access_token(sid=sid) + _sdb[sid] = { + "oauth_state": "authz", + "sub": "sub", + "authzreq": "", + "client_id": "client1", + "code": access_grant, + "code_used": False, + "redirect_uri": "http://example.com/authz", + 'token_endpoint_auth_method': 'client_secret_basic', + } + areq = ROPCAccessTokenRequest(grant_type='password', username='client1', password='password') + authn = 'Basic Y2xpZW50Mjp2ZXJ5c2VjcmV0=' + with pytest.raises(NotImplementedError): + self.provider.token_endpoint(request=areq.to_urlencoded(), authn=authn) + + def test_token_endpoint_other(self): + authreq = AuthorizationRequest(state="state", + redirect_uri="http://example.com/authz", + client_id="client1") + + _sdb = self.provider.sdb + sid = _sdb.access_token.key(user="sub", areq=authreq) + access_grant = _sdb.access_token(sid=sid) + _sdb[sid] = { + "oauth_state": "authz", + "sub": "sub", + "authzreq": "", + "client_id": "client1", + "code": access_grant, + "code_used": False, + "redirect_uri": "http://example.com/authz", + 'token_endpoint_auth_method': 'client_secret_basic', + } + areq = Message(grant_type='some_other') + authn = 'Basic Y2xpZW50Mjp2ZXJ5c2VjcmV0=' + with pytest.raises(UnSupported): + self.provider.token_endpoint(request=areq.to_urlencoded(), authn=authn) + + def test_code_grant_type_used(self): + authreq = AuthorizationRequest(state="state", + redirect_uri="http://example.com/authz", + client_id='client1', + response_type="code", + scope=["openid"]) + + _sdb = self.provider.sdb + sid = _sdb.access_token.key(user="sub", areq=authreq) + access_grant = _sdb.access_token(sid=sid) + _sdb[sid] = { + "oauth_state": "authz", + "authn_event": '', + "authzreq": '', + "client_id": 'client1', + "code": access_grant, + "code_used": True, + "scope": ["openid"], + "redirect_uri": "http://example.com/authz", + } + + # Construct Access token request + areq = AccessTokenRequest(code=access_grant, + client_id='client1', + redirect_uri="http://example.com/authz", + client_secret='hemlighet', + grant_type='authorization_code') + + txt = areq.to_urlencoded() + + resp = self.provider.token_endpoint(request=txt) + atr = TokenErrorResponse().deserialize(resp.message, "json") + assert atr['error'] == "invalid_grant" + @pytest.mark.parametrize("response_types", [ ['token id_token', 'id_token'], ['id_token token'] diff --git a/tests/test_oic_provider.py b/tests/test_oic_provider.py index 80162dc16..4d996e2aa 100644 --- a/tests/test_oic_provider.py +++ b/tests/test_oic_provider.py @@ -33,6 +33,7 @@ from oic.oic.message import AuthorizationResponse from oic.oic.message import CheckSessionRequest from oic.oic.message import IdToken +from oic.oic.message import Message from oic.oic.message import OpenIDSchema from oic.oic.message import RefreshAccessTokenRequest from oic.oic.message import RegistrationRequest @@ -43,7 +44,6 @@ from oic.oic.provider import InvalidSectorIdentifier from oic.oic.provider import Provider from oic.utils.authn.authn_context import AuthnBroker -from oic.utils.authn.client import ClientSecretBasic from oic.utils.authn.client import verify_client from oic.utils.authn.user import UserAuthnMethod from oic.utils.authz import AuthzHandling @@ -418,7 +418,7 @@ def test_authenticated_none(self): parsed.path) == "http://localhost:8087/authz" assert "state" in parse_qs(parsed.query) - def test_token_endpoint(self): + def test_code_grant_type_ok(self): authreq = AuthorizationRequest(state="state", redirect_uri="http://example.com/authz", client_id=CLIENT_ID, @@ -439,22 +439,29 @@ def test_token_endpoint(self): "scope": ["openid"], "redirect_uri": "http://example.com/authz", } - _sdb.do_sub(sid, "client_salt") + _sdb.do_sub(sid, 'client_salt') # Construct Access token request areq = AccessTokenRequest(code=access_grant, client_id=CLIENT_ID, - redirect_uri="http://example.com/authz", + redirect_uri='http://example.com/authz', client_secret=CLIENT_SECRET, grant_type='authorization_code') + resp = self.provider.code_grant_type(areq) + atr = AccessTokenResponse().deserialize(resp.message, 'json') + assert _eq(atr.keys(), ['token_type', 'id_token', 'access_token', 'scope']) - txt = areq.to_urlencoded() - - resp = self.provider.token_endpoint(request=txt) - atr = AccessTokenResponse().deserialize(resp.message, "json") - assert _eq(atr.keys(), - ['token_type', 'id_token', 'access_token', 'scope']) + def test_code_grant_type_missing_code(self): + # Construct Access token request + areq = AccessTokenRequest(client_id=CLIENT_ID, + redirect_uri='http://example.com/authz', + client_secret=CLIENT_SECRET, + grant_type='authorization_code') + resp = self.provider.code_grant_type(areq) + atr = TokenErrorResponse().deserialize(resp.message, 'json') + assert atr['error'] == 'invalid_request' + assert atr['error_description'] == 'Missing code' - def test_token_endpoint_no_cache(self): + def test_code_grant_type_revoked(self): authreq = AuthorizationRequest(state="state", redirect_uri="http://example.com/authz", client_id=CLIENT_ID, @@ -471,31 +478,39 @@ def test_token_endpoint_no_cache(self): "authzreq": authreq.to_json(), "client_id": CLIENT_ID, "code": access_grant, - "code_used": False, + "revoked": True, "scope": ["openid"], "redirect_uri": "http://example.com/authz", } - _sdb.do_sub(sid, "client_salt") + _sdb.do_sub(sid, 'client_salt') # Construct Access token request areq = AccessTokenRequest(code=access_grant, client_id=CLIENT_ID, - redirect_uri="http://example.com/authz", + redirect_uri='http://example.com/authz', client_secret=CLIENT_SECRET, grant_type='authorization_code') + resp = self.provider.code_grant_type(areq) + atr = TokenErrorResponse().deserialize(resp.message, 'json') + assert atr['error'] == 'invalid_request' + assert atr['error_description'] == 'Token is revoked' - txt = areq.to_urlencoded() - - resp = self.provider.token_endpoint(request=txt) - assert resp.headers == [('Pragma', 'no-cache'), ('Cache-Control', 'no-store'), - ('Content-type', 'application/json')] + def test_code_grant_type_no_session(self): + # Construct Access token request + areq = AccessTokenRequest(code='some grant', client_id=CLIENT_ID, + redirect_uri='http://example.com/authz', + client_secret=CLIENT_SECRET, + grant_type='authorization_code') + resp = self.provider.code_grant_type(areq) + atr = TokenErrorResponse().deserialize(resp.message, 'json') + assert atr['error'] == 'invalid_request' + assert atr['error_description'] == 'Code is invalid' - def test_token_endpoint_refresh(self): + def test_code_grant_type_missing_redirect_uri(self): authreq = AuthorizationRequest(state="state", redirect_uri="http://example.com/authz", client_id=CLIENT_ID, response_type="code", - scope=["openid offline_access"], - prompt="consent") + scope=["openid"]) _sdb = self.provider.sdb sid = _sdb.access_token.key(user="sub", areq=authreq) @@ -508,26 +523,21 @@ def test_token_endpoint_refresh(self): "client_id": CLIENT_ID, "code": access_grant, "code_used": False, - "scope": ["openid", "offline_access"], + "scope": ["openid"], "redirect_uri": "http://example.com/authz", } - _sdb.do_sub(sid, "client_salt") + _sdb.do_sub(sid, 'client_salt') # Construct Access token request areq = AccessTokenRequest(code=access_grant, client_id=CLIENT_ID, - redirect_uri="http://example.com/authz", client_secret=CLIENT_SECRET, grant_type='authorization_code') + resp = self.provider.code_grant_type(areq) + atr = TokenErrorResponse().deserialize(resp.message, 'json') + assert atr['error'] == 'invalid_request' + assert atr['error_description'] == 'Missing redirect_uri' - txt = areq.to_urlencoded() - - resp = self.provider.token_endpoint(request=txt) - atr = AccessTokenResponse().deserialize(resp.message, "json") - assert _eq(atr.keys(), - ['token_type', 'id_token', 'access_token', 'scope', - 'refresh_token']) - - def test_token_endpoint_malformed(self): + def test_code_grant_type_used(self): authreq = AuthorizationRequest(state="state", redirect_uri="http://example.com/authz", client_id=CLIENT_ID, @@ -544,31 +554,29 @@ def test_token_endpoint_malformed(self): "authzreq": authreq.to_json(), "client_id": CLIENT_ID, "code": access_grant, - "code_used": False, + "code_used": True, "scope": ["openid"], "redirect_uri": "http://example.com/authz", } - _sdb.do_sub(sid, "client_salt") + _sdb.do_sub(sid, 'client_salt') # Construct Access token request - areq = AccessTokenRequest(code=access_grant[0:len(access_grant) - 1], - client_id=CLIENT_ID, - redirect_uri="http://example.com/authz", + areq = AccessTokenRequest(code=access_grant, client_id=CLIENT_ID, + redirect_uri='http://example.com/authz', client_secret=CLIENT_SECRET, grant_type='authorization_code') + resp = self.provider.code_grant_type(areq) + atr = TokenErrorResponse().deserialize(resp.message, 'json') + assert atr['error'] == 'access_denied' + assert atr['error_description'] == 'Access Code already used' - txt = areq.to_urlencoded() - - resp = self.provider.token_endpoint(request=txt) - atr = TokenErrorResponse().deserialize(resp.message, "json") - assert atr['error'] == "invalid_request" - - def test_token_endpoint_bad_code(self): + def test_code_grant_type_refresh(self): authreq = AuthorizationRequest(state="state", redirect_uri="http://example.com/authz", client_id=CLIENT_ID, response_type="code", - scope=["openid"]) + scope=["openid offline_access"], + prompt="consent") _sdb = self.provider.sdb sid = _sdb.access_token.key(user="sub", areq=authreq) @@ -581,94 +589,31 @@ def test_token_endpoint_bad_code(self): "client_id": CLIENT_ID, "code": access_grant, "code_used": False, - "scope": ["openid"], - "state": "state", + "scope": ["openid", "offline_access"], "redirect_uri": "http://example.com/authz", } _sdb.do_sub(sid, "client_salt") # Construct Access token request - areq = AccessTokenRequest(code='bad_code', - client_id=CLIENT_ID, + areq = AccessTokenRequest(code=access_grant, client_id=CLIENT_ID, redirect_uri="http://example.com/authz", client_secret=CLIENT_SECRET, - grant_type='authorization_code', - state="state") - - txt = areq.to_urlencoded() - - resp = self.provider.token_endpoint(request=txt) - atr = TokenErrorResponse().deserialize(resp.message, "json") - assert atr['error'] == "unauthorized_client" - - def test_token_endpoint_unauth(self): - state = 'state' - authreq = AuthorizationRequest(state=state, - redirect_uri="http://example.com/authz", - client_id="client_1") - - _sdb = self.provider.sdb - sid = _sdb.access_token.key(user="sub", areq=authreq) - access_grant = _sdb.access_token(sid=sid) - ae = AuthnEvent("user", "salt") - _sdb[sid] = { - "authn_event": ae.to_json(), - "oauth_state": "authz", - "authzreq": "", - "client_id": "client_1", - "code": access_grant, - "code_used": False, - "scope": ["openid"], - "redirect_uri": "http://example.com/authz", - 'state': state - } - _sdb.do_sub(sid, "client_salt") - - # Construct Access token request - areq = AccessTokenRequest(code=access_grant, - redirect_uri="http://example.com/authz", - client_id="client_1", - client_secret="secret", - state=state, grant_type='authorization_code') + resp = self.provider.code_grant_type(areq) + atr = AccessTokenResponse().deserialize(resp.message, "json") + assert _eq(atr.keys(), ['token_type', 'id_token', 'access_token', 'scope', 'refresh_token']) - txt = areq.to_urlencoded() - - resp = self.provider.token_endpoint(request=txt, remote_user="client2", - request_method="POST") - atr = TokenErrorResponse().deserialize(resp.message, "json") - assert atr["error"] == "unauthorized_client" - - def test_token_endpoint_auth(self): - state, location = self.cons.begin("openid", "code", - path="http://localhost:8087") - - resp = self.provider.authorization_endpoint( - request=urlparse(location).query) - - self.cons.parse_response(AuthorizationResponse, resp.message, - sformat="urlencoded") - - # Construct Access token request - areq = self.cons.construct_AccessTokenRequest( - redirect_uri="http://example.com/authz", - client_id="client_1", - client_secret='abcdefghijklmnop', - state=state) - - txt = areq.to_urlencoded() - self.cons.client_secret = 'drickyoughurt' - - csb = ClientSecretBasic(self.cons) - http_args = csb.construct(areq) - - resp = self.provider.token_endpoint(request=txt, remote_user="client2", - request_method="POST", - authn=http_args['headers'][ - 'Authorization']) + def test_client_credentials_grant_type(self): + resp = self.provider.client_credentials_grant_type(Message()) + parsed = ErrorResponse().from_json(resp.message) + assert parsed['error'] == 'invalid_request' + assert parsed['error_description'] == 'Unsupported grant_type' - atr = TokenErrorResponse().deserialize(resp.message, "json") - assert atr["token_type"] == 'Bearer' + def test_password_grant_type(self): + resp = self.provider.password_grant_type(Message()) + parsed = ErrorResponse().from_json(resp.message) + assert parsed['error'] == 'invalid_request' + assert parsed['error_description'] == 'Unsupported grant_type' def test_authz_endpoint(self): _state, location = self.cons.begin("openid", @@ -1662,7 +1607,7 @@ def test_id_token_RS512_sign(self): id_token = self._auth_with_id_token() assert id_token.jws_header['alg'] == "RS512" - def test_refresh_access_token_request(self): + def test_refresh_token_grant_type_ok(self): authreq = AuthorizationRequest(state="state", redirect_uri="http://example.com/authz", client_id=CLIENT_ID, @@ -1685,74 +1630,43 @@ def test_refresh_access_token_request(self): "redirect_uri": "http://example.com/authz", } _sdb.do_sub(sid, "client_salt") + info = _sdb.upgrade_to_token(access_grant, issue_refresh=True) - # Construct Access token request - areq = AccessTokenRequest(code=access_grant, client_id=CLIENT_ID, - redirect_uri="http://example.com/authz", - client_secret=CLIENT_SECRET, - grant_type='authorization_code') - - txt = areq.to_urlencoded() + rareq = RefreshAccessTokenRequest(grant_type="refresh_token", + refresh_token=info['refresh_token'], + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + scope=['openid']) - resp = self.provider.token_endpoint(request=txt) + resp = self.provider.refresh_token_grant_type(rareq) atr = AccessTokenResponse().deserialize(resp.message, "json") + assert atr['refresh_token'] is not None + assert atr['token_type'] == 'Bearer' + def test_refresh_token_grant_type_wrong_token(self): rareq = RefreshAccessTokenRequest(grant_type="refresh_token", - refresh_token=atr['refresh_token'], + refresh_token='some_other_refresh_token', client_id=CLIENT_ID, client_secret=CLIENT_SECRET, scope=['openid']) - resp = self.provider.token_endpoint(request=rareq.to_urlencoded()) - atr2 = AccessTokenResponse().deserialize(resp.message, "json") - assert atr2['access_token'] != atr['access_token'] - assert atr2['refresh_token'] == atr['refresh_token'] - assert atr2['token_type'] == 'Bearer' - - def test_refresh_access_token_no_cache(self): - authreq = AuthorizationRequest(state="state", - redirect_uri="http://example.com/authz", - client_id=CLIENT_ID, - response_type="code", - scope=["openid", 'offline_access'], - prompt='consent') - - _sdb = self.provider.sdb - sid = _sdb.access_token.key(user="sub", areq=authreq) - access_grant = _sdb.access_token(sid=sid) - ae = AuthnEvent("user", "salt") - _sdb[sid] = { - "oauth_state": "authz", - "authn_event": ae.to_json(), - "authzreq": authreq.to_json(), - "client_id": CLIENT_ID, - "code": access_grant, - "code_used": False, - "scope": ["openid", 'offline_access'], - "redirect_uri": "http://example.com/authz", - } - _sdb.do_sub(sid, "client_salt") - - # Construct Access token request - areq = AccessTokenRequest(code=access_grant, client_id=CLIENT_ID, - redirect_uri="http://example.com/authz", - client_secret=CLIENT_SECRET, - grant_type='authorization_code') - - txt = areq.to_urlencoded() - - resp = self.provider.token_endpoint(request=txt) - atr = AccessTokenResponse().deserialize(resp.message, "json") + resp = self.provider.refresh_token_grant_type(rareq) + atr = TokenErrorResponse().deserialize(resp.message, "json") + assert atr['error'] == 'invalid_request' + assert atr['error_description'] == 'Not a refresh token' + def test_refresh_token_grant_type_expired(self): + # Missing refresh_token also raises Expired rareq = RefreshAccessTokenRequest(grant_type="refresh_token", - refresh_token=atr['refresh_token'], + refresh_token='Refresh_some_other_refresh_token', client_id=CLIENT_ID, client_secret=CLIENT_SECRET, scope=['openid']) - resp = self.provider.token_endpoint(request=rareq.to_urlencoded()) - assert resp.headers == [('Pragma', 'no-cache'), ('Cache-Control', 'no-store'), - ('Content-type', 'application/json')] + resp = self.provider.refresh_token_grant_type(rareq) + atr = TokenErrorResponse().deserialize(resp.message, "json") + assert atr['error'] == 'invalid_request' + assert atr['error_description'] == 'Refresh token is expired' def test_authorization_endpoint_faulty_request_uri(self): bib = {"scope": ["openid"], From e5ff129db2e12453dad40431418ade68e035360d Mon Sep 17 00:00:00 2001 From: Tomas Pazderka Date: Wed, 6 Mar 2019 21:59:56 +0100 Subject: [PATCH 2/2] fixup! Synced implementation of token_endpoint --- src/oic/oauth2/provider.py | 20 +++++++++++++------- src/oic/oic/provider.py | 4 ++-- tests/test_oauth2_provider.py | 10 ++++++---- tests/test_oic_provider.py | 4 ++-- 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/oic/oauth2/provider.py b/src/oic/oauth2/provider.py index 37dfcd208..f6a703a0c 100644 --- a/src/oic/oauth2/provider.py +++ b/src/oic/oauth2/provider.py @@ -154,7 +154,7 @@ def re_authenticate(areq, authn): class Provider(object): endp = [AuthorizationEndpoint, TokenEndpoint] - # Define the message class that in token_enpdoint + # Define the message class that in token_endpoint atr_class = AccessTokenRequest def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn, @@ -800,7 +800,8 @@ def token_endpoint(self, request='', authn='', dtype='urlencoded', **kwargs): _info = self.sdb[areq["code"]] except KeyError: logger.error('Code not present in SessionDB') - error = TokenErrorResponse(error="unauthorized_client") + error = TokenErrorResponse(error="unauthorized_client", + error_description='Invalid code.') return Unauthorized(error.to_json(), content="application/json") resp = self.token_scope_check(areq, _info) @@ -809,12 +810,14 @@ def token_endpoint(self, request='', authn='', dtype='urlencoded', **kwargs): # If redirect_uri was in the initial authorization request verify that they match if "redirect_uri" in _info and areq["redirect_uri"] != _info["redirect_uri"]: logger.error('Redirect_uri mismatch') - error = TokenErrorResponse(error="unauthorized_client") + error = TokenErrorResponse(error="unauthorized_client", + error_description='Redirect_uris do not match.') return Unauthorized(error.to_json(), content="application/json") if 'state' in areq: if _info['state'] != areq['state']: logger.error('State value mismatch') - error = TokenErrorResponse(error="unauthorized_client") + error = TokenErrorResponse(error="unauthorized_client", + error_description='State values do not match.') return Unauthorized(error.to_json(), content="application/json") # Propagate the client_id further @@ -857,7 +860,8 @@ def refresh_token_grant_type(self, areq): RFC6749 section 6 """ - raise NotImplementedError('See oic.extension.provider.') + # This is not implemented here, please see oic.extension.provider. + return error_response('unsupported_grant_type', descr='Unsupported grant_type') def client_credentials_grant_type(self, areq): """ @@ -865,7 +869,8 @@ def client_credentials_grant_type(self, areq): RFC6749 section 4.4 """ - raise NotImplementedError('See oic.extension.provider.') + # This is not implemented here, please see oic.extension.provider. + return error_response('unsupported_grant_type', descr='Unsupported grant_type') def password_grant_type(self, areq): """ @@ -873,7 +878,8 @@ def password_grant_type(self, areq): RFC6749 section 4.3 """ - raise NotImplementedError('See oic.extension.provider.') + # This is not implemented here, please see oic.extension.provider. + return error_response('unsupported_grant_type', descr='Unsupported grant_type') def verify_endpoint(self, request="", cookie=None, **kwargs): _req = parse_qs(request) diff --git a/src/oic/oic/provider.py b/src/oic/oic/provider.py index aa83aaaba..ceabcdb37 100644 --- a/src/oic/oic/provider.py +++ b/src/oic/oic/provider.py @@ -1075,7 +1075,7 @@ def client_credentials_grant_type(self, areq): RFC6749 section 4.4 """ # Not supported in OpenID Connect - return error_response('invalid_request', descr='Unsupported grant_type') + return error_response('unsupported_grant_type', descr='Unsupported grant_type') def password_grant_type(self, areq): """ @@ -1084,7 +1084,7 @@ def password_grant_type(self, areq): RFC6749 section 4.3 """ # Not supported in OpenID Connect - return error_response('invalid_request', descr='Unsupported grant_type') + return error_response('unsupported_grant_type', descr='Unsupported grant_type') def _collect_user_info(self, session, userinfo_claims=None): """ diff --git a/tests/test_oauth2_provider.py b/tests/test_oauth2_provider.py index a31167f31..2808d8a08 100644 --- a/tests/test_oauth2_provider.py +++ b/tests/test_oauth2_provider.py @@ -513,8 +513,9 @@ def test_token_endpoint_client_credentials(self): } areq = CCAccessTokenRequest(grant_type='client_credentials') authn = 'Basic Y2xpZW50Mjp2ZXJ5c2VjcmV0=' - with pytest.raises(NotImplementedError): - self.provider.token_endpoint(request=areq.to_urlencoded(), authn=authn) + resp = self.provider.token_endpoint(request=areq.to_urlencoded(), authn=authn) + parsed = TokenErrorResponse().from_json(resp.message) + assert parsed['error'] == "unsupported_grant_type" def test_token_endpoint_password(self): authreq = AuthorizationRequest(state="state", @@ -536,8 +537,9 @@ def test_token_endpoint_password(self): } areq = ROPCAccessTokenRequest(grant_type='password', username='client1', password='password') authn = 'Basic Y2xpZW50Mjp2ZXJ5c2VjcmV0=' - with pytest.raises(NotImplementedError): - self.provider.token_endpoint(request=areq.to_urlencoded(), authn=authn) + resp = self.provider.token_endpoint(request=areq.to_urlencoded(), authn=authn) + parsed = TokenErrorResponse().from_json(resp.message) + assert parsed['error'] == "unsupported_grant_type" def test_token_endpoint_other(self): authreq = AuthorizationRequest(state="state", diff --git a/tests/test_oic_provider.py b/tests/test_oic_provider.py index 4d996e2aa..e2fbb8b98 100644 --- a/tests/test_oic_provider.py +++ b/tests/test_oic_provider.py @@ -606,13 +606,13 @@ def test_code_grant_type_refresh(self): def test_client_credentials_grant_type(self): resp = self.provider.client_credentials_grant_type(Message()) parsed = ErrorResponse().from_json(resp.message) - assert parsed['error'] == 'invalid_request' + assert parsed['error'] == 'unsupported_grant_type' assert parsed['error_description'] == 'Unsupported grant_type' def test_password_grant_type(self): resp = self.provider.password_grant_type(Message()) parsed = ErrorResponse().from_json(resp.message) - assert parsed['error'] == 'invalid_request' + assert parsed['error'] == 'unsupported_grant_type' assert parsed['error_description'] == 'Unsupported grant_type' def test_authz_endpoint(self):