Skip to content
This repository was archived by the owner on Jun 1, 2023. It is now read-only.

Needed to complete certification. #42

Merged
merged 9 commits into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/oidcmsg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__author__ = "Roland Hedberg"
__version__ = "1.4.0"
__version__ = "1.4.1"

import os
from typing import Dict
Expand Down
3 changes: 2 additions & 1 deletion src/oidcmsg/impexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List
from typing import Optional

from cryptojwt import as_unicode
from cryptojwt.utils import as_bytes
from cryptojwt.utils import importer
from cryptojwt.utils import qualified_name
Expand All @@ -25,7 +26,7 @@ def __init__(self):
def dump_attr(self, cls, item, exclude_attributes: Optional[List[str]] = None) -> dict:
if cls in [None, 0, "", [], {}, bool, b'']:
if cls == b'':
val = as_bytes(item)
val = as_unicode(item)
else:
val = item
elif cls == "DICT_TYPE":
Expand Down
68 changes: 39 additions & 29 deletions src/oidcmsg/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ def from_dict(self, dictionary, **kwargs):
self._dict[key] = val
continue

self._add_value(skey, vtyp, key, val, _deser, null_allowed)
self._add_value(skey, vtyp, key, val, _deser, null_allowed, sformat="dict")
return self

def _add_value(self, skey, vtyp, key, val, _deser, null_allowed):
def _add_value(self, skey, vtyp, key, val, _deser, null_allowed, sformat="urlencoded"):
"""
Main method for adding a value to the instance. Does all the
checking on type of value and if among allowed values.
Expand Down Expand Up @@ -350,7 +350,7 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed):
self._dict[skey] = [val]
elif _deser:
try:
self._dict[skey] = _deser(val, sformat="urlencoded")
self._dict[skey] = _deser(val, sformat=sformat)
except Exception as exc:
raise DecodeError(ERRTXT % (key, exc))
else:
Expand Down Expand Up @@ -402,16 +402,6 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed):
except Exception as exc:
raise DecodeError(ERRTXT % (key, exc))
else:
# if isinstance(val, str):
# self._dict[skey] = val
# elif isinstance(val, list):
# if len(val) == 1:
# self._dict[skey] = val[0]
# elif not len(val):
# pass
# else:
# raise TooManyValues(key)
# else:
self._dict[skey] = val
elif vtyp is int:
try:
Expand Down Expand Up @@ -468,6 +458,28 @@ def to_jwt(self, key=None, algorithm="", lev=0, lifetime=0):
_jws = JWS(self.to_json(lev), alg=algorithm)
return _jws.sign_compact(key)

def _gather_keys(self, keyjar, jwt, header, **kwargs):
key = []

if keyjar:
_keys = keyjar.get_jwt_verify_keys(jwt, **kwargs)
if not _keys:
keyjar.update()
_keys = keyjar.get_jwt_verify_keys(jwt, **kwargs)
key.extend(_keys)

if "alg" in header and header["alg"] != "none":
if not key:
if keyjar:
keyjar.update()
key = keyjar.get_jwt_verify_keys(jwt, **kwargs)
if not key:
raise MissingSigningKey("alg=%s" % header["alg"])
else:
raise MissingSigningKey("alg=%s" % header["alg"])

return key

def from_jwt(self, txt, keyjar, verify=True, **kwargs):
"""
Given a signed and/or encrypted JWT, verify its correctness and then
Expand Down Expand Up @@ -515,7 +527,6 @@ def from_jwt(self, txt, keyjar, verify=True, **kwargs):
jso = _jwt.payload()
_header = _jwt.headers

key = []
# if "sender" in kwargs:
# key.extend(keyjar.get_verify_key(owner=kwargs["sender"]))

Expand All @@ -524,21 +535,13 @@ def from_jwt(self, txt, keyjar, verify=True, **kwargs):
if _header["alg"] == "none":
pass
elif verify:
if keyjar:
key.extend(keyjar.get_jwt_verify_keys(_jwt, **kwargs))
key = self._gather_keys(keyjar, _jwt, _header, **kwargs)

if "alg" in _header and _header["alg"] != "none":
if not key:
raise MissingSigningKey("alg=%s" % _header["alg"])
if not key:
raise MissingSigningKey("alg=%s" % _header["alg"])

logger.debug("Found signing key.")
try:
_verifier.verify_compact(txt, key)
except NoSuitableSigningKeys:
if keyjar:
keyjar.update()
key = keyjar.get_jwt_verify_keys(_jwt, **kwargs)
_verifier.verify_compact(txt, key)
_verifier.verify_compact(txt, key)

self.jws_header = _jwt.headers
else:
Expand Down Expand Up @@ -850,8 +853,12 @@ def add_non_standard(msg1, msg2):


def list_serializer(vals, sformat="urlencoded", lev=0):
if isinstance(vals, str) or not isinstance(vals, list):
if isinstance(vals, str) and sformat == "dict":
return [vals]

if not isinstance(vals, list):
raise ValueError("Expected list: %s" % vals)

if sformat == "urlencoded":
return " ".join(vals)
else:
Expand All @@ -864,8 +871,11 @@ def list_deserializer(val, sformat="urlencoded"):
return val.split(" ")
elif isinstance(val, list) and len(val) == 1:
return val[0].split(" ")
else:
return val
elif sformat == "dict":
if isinstance(val, str):
val = [val]

return val


def sp_sep_list_serializer(vals, sformat="urlencoded", lev=0):
Expand Down
12 changes: 2 additions & 10 deletions src/oidcmsg/oidc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ class RegistrationRequest(Message):
# "client_id": SINGLE_OPTIONAL_STRING,
# "client_secret": SINGLE_OPTIONAL_STRING,
# "access_token": SINGLE_OPTIONAL_STRING,
"post_logout_redirect_uris": OPTIONAL_LIST_OF_STRINGS,
"post_logout_redirect_uri": SINGLE_OPTIONAL_STRING,
"frontchannel_logout_uri": SINGLE_OPTIONAL_STRING,
"frontchannel_logout_session_required": SINGLE_OPTIONAL_BOOLEAN,
"backchannel_logout_uri": SINGLE_OPTIONAL_STRING,
Expand Down Expand Up @@ -771,14 +771,6 @@ def pack(self, alg="", **kwargs):
else:
self.pack_init()

# if 'jti' in self.c_param:
# try:
# _jti = kwargs['jti']
# except KeyError:
# _jti = uuid.uuid4().hex
#
# self['jti'] = _jti

def to_jwt(self, key=None, algorithm="", lev=0, lifetime=0):
self.pack(alg=algorithm, lifetime=lifetime)
return Message.to_jwt(self, key=key, algorithm=algorithm, lev=lev)
Expand All @@ -797,7 +789,7 @@ def verify(self, **kwargs):
# check that I'm among the recipients
if kwargs["client_id"] not in self["aud"]:
raise NotForMe(
"{} not in aud:{}".format(kwargs["client_id"], self["aud"]), self
'"{}" not in {}'.format(kwargs["client_id"], self["aud"]), self
)

# Then azp has to be present and be one of the aud values
Expand Down
2 changes: 1 addition & 1 deletion src/oidcmsg/oidc/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,6 @@ def verify(self, **kwargs):
return False

self[verified_claim_name("logout_token")] = idt
logger.info("Verified ID Token: {}".format(idt.to_dict()))
logger.info("Verified Logout Token: {}".format(idt.to_dict()))

return True
29 changes: 25 additions & 4 deletions tests/test_06_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,23 @@
from urllib.parse import parse_qs
from urllib.parse import urlencode

import pytest
from cryptojwt.exception import BadSignature
from cryptojwt.exception import UnsupportedAlgorithm
from cryptojwt.jws.exception import SignerAlgError
from cryptojwt.jws.utils import left_hash
from cryptojwt.jwt import JWT
from cryptojwt.key_bundle import KeyBundle
from cryptojwt.key_jar import KeyJar
import pytest

from oidcmsg import proper_path
from oidcmsg import time_util
from oidcmsg.exception import MessageException
from oidcmsg.exception import MissingRequiredAttribute
from oidcmsg.exception import NotAllowedValue
from oidcmsg.exception import OidcMsgError
from oidcmsg.oauth2 import ResponseMessage
from oidcmsg.oauth2 import ROPCAccessTokenRequest
from oidcmsg.oidc import JRD
from oidcmsg.oauth2 import ResponseMessage
from oidcmsg.oidc import AccessTokenRequest
from oidcmsg.oidc import AccessTokenResponse
from oidcmsg.oidc import AddressClaim
Expand All @@ -38,6 +37,7 @@
from oidcmsg.oidc import EXPError
from oidcmsg.oidc import IATError
from oidcmsg.oidc import IdToken
from oidcmsg.oidc import JRD
from oidcmsg.oidc import Link
from oidcmsg.oidc import OpenIDSchema
from oidcmsg.oidc import ProviderConfigurationResponse
Expand Down Expand Up @@ -661,7 +661,7 @@ def test_deserialize(self):
"client_secret_expires_at": 1577858400,
"registration_access_token": "this.is.an.access.token.value.ffx83",
"registration_client_uri": "https://server.example.com/connect/register?client_id"
"=s6BhdRkqt3",
"=s6BhdRkqt3",
"token_endpoint_auth_method": "client_secret_basic",
"application_type": "web",
"redirect_uris": [
Expand Down Expand Up @@ -1601,3 +1601,24 @@ def test_correct_sign_alg():
client_id="554295ce3770612820620000",
allowed_sign_alg="HS256",
)


def test_ID_Token_space_in_id():
idt = IdToken(**{
"at_hash": "buCCujNN632UIV8-VbKhgw",
"sub": "user-subject-1234531",
"aud": "client_ifCttPphtLxtPWd20602 ^.+/",
"iss": "https://www.certification.openid.net/test/a/idpy/",
"exp": 1632495959,
"nonce": "B88En9UpdHkQZMQXK9U3KHzV",
"iat": 1632495659
})

assert idt["aud"] == ["client_ifCttPphtLxtPWd20602 ^.+/"]

idt = IdToken(**{'at_hash': 'rgMbiR-Dj11dQjxhCyLkOw', 'sub': 'user-subject-1234531',
'aud': 'client_dVCwIQuSKklinFP70742;#__$',
'iss': 'https://www.certification.openid.net/test/a/idpy/', 'exp': 1632639462,
'nonce': 'hUT3RhSooxC9CilrD8al6bGx', 'iat': 1632639162})

assert idt["aud"] == ["client_dVCwIQuSKklinFP70742;#__$"]