diff --git a/src/saml2/response.py b/src/saml2/response.py index 26963a04e..968639434 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -383,7 +383,7 @@ def status_ok(self): logger.info(msg) raise err_cls(msg) - def issue_instant_ok(self): + def issue_instant_ok(self, issue_instant): """ Check that the response was issued at a reasonable time """ upper = time_util.shift_time(time_util.time_in_a_while(days=1), self.timeslack).timetuple() @@ -391,9 +391,65 @@ def issue_instant_ok(self): -self.timeslack).timetuple() # print("issue_instant: %s" % self.response.issue_instant) # print("%s < x < %s" % (lower, upper)) - issued_at = str_to_time(self.response.issue_instant) + issued_at = str_to_time(issue_instant) return lower < issued_at < upper + def issuer_ok(self): + """ Check if the issuer have a valid Format, additional check may be implemented""" + if self.response.issuer and self.response.issuer.format: + if self.response.issuer.format != saml.NAMEID_FORMAT_ENTITY: + return False + return True + + def assertion_ok(self): + """ Check assertions attributes, additional checks may be implemented + + Commented check may be too pedantic + """ + valid = True + if hasattr(self.response, 'assertion'): + # breakpoint() + for ass in self.response.assertion: + # Version + if ass.version and ass.version != '2.0': + raise VersionMismatch(f'{ass.version}') + # IssueInstant + # if hasattr(ass, 'issue_instant') and not self.issue_instant_ok(ass.issue_instant): + # breakpoint() + # raise Exception('Invalid Issue Instant') + # NameQualifier + # if not hasattr(ass.subject.name_id, 'name_qualifier') or \ + # not ass.subject.name_id.name_qualifier: + # raise Exception('Not a valid subject.name_id.name_qualifier') + if hasattr(ass.subject.name_id, 'format'): + if not ass.subject.name_id.format: + raise Exception('Not a valid subject.name_id.format') + + if ass.subject.name_id.format not in dict(saml.NAMEID_FORMATS_SAML2).values(): + msg = 'Not a valid subject.name_id.format: {}' + raise Exception(msg.format(ass.subject.name_id.format)) + + # subject confirmation + for subject_confirmation in ass.subject.subject_confirmation: + if not hasattr(subject_confirmation, 'subject_confirmation_data') or \ + not getattr(subject_confirmation, 'subject_confirmation_data', None): + msg = 'subject_confirmation_data not present' + raise Exception(msg) + + if not subject_confirmation.subject_confirmation_data.in_response_to: + raise Exception('subject.subject_confirmation_data in response -> null data') + + # TODO: match to the recipient + # if self.recipient != subject_confirmation.subject_confirmation_data.recipient: + # msg = 'subject_confirmation_data.recipient not valid: {}' + # raise Exception(msg.format(subject_confirmation_data.recipient)) + + if not hasattr(subject_confirmation.subject_confirmation_data, 'not_on_or_after') or \ + not getattr(subject_confirmation.subject_confirmation_data, 'not_on_or_after', None): + raise Exception('subject.subject_confirmation_data not_on_or_after not valid') + + return valid + def _verify(self): if self.request_id and self.in_response_to and \ self.in_response_to != self.request_id: @@ -416,7 +472,14 @@ def _verify(self): logger.error("%s not in %s", self.response.destination, self.return_addrs) return None - valid = self.issue_instant_ok() and self.status_ok() + valid = all( + ( + self.issue_instant_ok(self.response.issue_instant), + self.issuer_ok(), + self.status_ok(), + self.assertion_ok() + ) + ) return valid def loads(self, xmldata, decode=True, origxml=None): @@ -1116,7 +1179,7 @@ def session_info(self): raise StatusInvalidAuthnResponseStatement( "The Authn Response Statement is not valid" ) - + def __str__(self): return self.xmlstr