From a9fe345ba1bb4ef516673b1acbd01db80ccdf658 Mon Sep 17 00:00:00 2001 From: Ivan Kanakarakis Date: Tue, 14 Feb 2023 15:01:56 +0200 Subject: [PATCH] Update behaviour of subject-id requirements entity attribute When the subject-id requiment is "any", both the subject-id and pairwise-id should be processsed. Signed-off-by: Ivan Kanakarakis --- src/saml2/assertion.py | 11 ++++--- src/saml2/mdstore.py | 70 +++++++++++++++++++++++++++------------- tests/test_30_mdstore.py | 23 ++++++++++--- 3 files changed, 71 insertions(+), 33 deletions(-) diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py index 46733f934..54d96172c 100644 --- a/src/saml2/assertion.py +++ b/src/saml2/assertion.py @@ -556,11 +556,12 @@ def restrict(self, ava, sp_entity_id, metadata=None): metadata_store = metadata or self.metadata_store spec = metadata_store.attribute_requirement(sp_entity_id) or {} if metadata_store else {} - required_attributes = spec.get("required", []) - optional_attributes = spec.get("optional", []) - required_subject_id = metadata_store.subject_id_requirement(sp_entity_id) if metadata_store else None - if required_subject_id and required_subject_id not in required_attributes: - required_attributes.append(required_subject_id) + required_attributes = spec.get("required") or [] + optional_attributes = spec.get("optional") or [] + requirements_subject_id = metadata_store.subject_id_requirement(sp_entity_id) if metadata_store else [] + for r in requirements_subject_id: + if r not in required_attributes: + required_attributes.extend(r) return self.filter( ava, sp_entity_id, diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py index 7519a20e4..4126e49e0 100644 --- a/src/saml2/mdstore.py +++ b/src/saml2/mdstore.py @@ -200,14 +200,15 @@ def all_locations(srvs): return values -def attribute_requirement(entity, index=None): +def attribute_requirement(entity_descriptor, index=None): res = {"required": [], "optional": []} - for acs in entity["attribute_consuming_service"]: + acss = entity_descriptor.get("attribute_consuming_service") or [] + for acs in acss: if index is not None and acs["index"] != index: continue for attr in acs["requested_attribute"]: - if "is_required" in attr and attr["is_required"] == "true": + if attr.get("is_required") == "true": res["required"].append(attr) else: res["optional"].append(attr) @@ -676,24 +677,26 @@ def service(self, entity_id, typ, service, binding=None): return res def attribute_requirement(self, entity_id, index=None): - """Returns what attributes the SP requires and which are optional + """ + Returns what attributes the SP requires and which are optional if any such demands are registered in the Metadata. + In case the metadata have multiple SPSSODescriptor elements, + the sum of the required and optional attributes is returned. + :param entity_id: The entity id of the SP :param index: which of the attribute consumer services its all about if index=None then return all attributes expected by all attribute_consuming_services. - :return: 2-tuple, list of required and list of optional attributes + :return: dict of required and optional list of attributes """ res = {"required": [], "optional": []} - try: - for sp in self[entity_id]["spsso_descriptor"]: - _res = attribute_requirement(sp, index) - res["required"].extend(_res["required"]) - res["optional"].extend(_res["optional"]) - except KeyError: - return None + sp_descriptors = self[entity_id].get("spsso_descriptor") or [] + for sp_desc in sp_descriptors: + _res = attribute_requirement(sp_desc, index) + res["required"].extend(_res.get("required") or []) + res["optional"].extend(_res.get("optional") or []) return res @@ -1297,35 +1300,56 @@ def discovery_response(self, entity_id, binding=None, _="spsso"): ) def attribute_requirement(self, entity_id, index=None): - for _md in self.metadata.values(): - if entity_id in _md: - return _md.attribute_requirement(entity_id, index) + for md_source in self.metadata.values(): + if entity_id in md_source: + return md_source.attribute_requirement(entity_id, index) def subject_id_requirement(self, entity_id): try: entity_attributes = self.entity_attributes(entity_id) except KeyError: - return None + return [] - if "urn:oasis:names:tc:SAML:profiles:subject-id:req" in entity_attributes: - subject_id_req = entity_attributes["urn:oasis:names:tc:SAML:profiles:subject-id:req"][0] - if subject_id_req == "any" or subject_id_req == "pairwise-id": - return { + subject_id_reqs = entity_attributes.get("urn:oasis:names:tc:SAML:profiles:subject-id:req") or [] + subject_id_req = next(iter(subject_id_reqs), None) + if subject_id_req == "any": + return [ + { + "__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute", + "name": "urn:oasis:names:tc:SAML:attribute:pairwise-id", + "name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + "friendly_name": "pairwise-id", + "is_required": "true", + }, + { + "__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute", + "name": "urn:oasis:names:tc:SAML:attribute:subject-id", + "name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + "friendly_name": "subject-id", + "is_required": "true", + } + ] + elif subject_id_req == "pairwise-id": + return [ + { "__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute", "name": "urn:oasis:names:tc:SAML:attribute:pairwise-id", "name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", "friendly_name": "pairwise-id", "is_required": "true", } - elif subject_id_req == "subject-id": - return { + ] + elif subject_id_req == "subject-id": + return [ + { "__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute", "name": "urn:oasis:names:tc:SAML:attribute:subject-id", "name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", "friendly_name": "subject-id", "is_required": "true", } - return None + ] + return [] def keys(self): res = [] diff --git a/tests/test_30_mdstore.py b/tests/test_30_mdstore.py index 013a6062f..6f6632014 100644 --- a/tests/test_30_mdstore.py +++ b/tests/test_30_mdstore.py @@ -664,11 +664,24 @@ def test_subject_id_requirement(): mds = MetadataStore(ATTRCONV, sec_config, disable_ssl_certificate_validation=True) mds.imp(METADATACONF["17"]) required_subject_id = mds.subject_id_requirement(entity_id="https://esi-coco.example.edu/saml2/metadata/") - assert required_subject_id["__class__"] == "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute" - assert required_subject_id["name"] == "urn:oasis:names:tc:SAML:attribute:pairwise-id" - assert required_subject_id["name_format"] == "urn:oasis:names:tc:SAML:2.0:attrname-format:uri" - assert required_subject_id["friendly_name"] == "pairwise-id" - assert required_subject_id["is_required"] == "true" + expected = [ + { + "__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute", + "name": "urn:oasis:names:tc:SAML:attribute:pairwise-id", + "name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + "friendly_name": "pairwise-id", + "is_required": "true", + }, + { + "__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute", + "name": "urn:oasis:names:tc:SAML:attribute:subject-id", + "name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + "friendly_name": "subject-id", + "is_required": "true", + }, + ] + assert required_subject_id + assert all(e in expected for e in required_subject_id) def test_extension():