diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 32d77506..281b669d 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -64,7 +64,6 @@ def do_response( client_id: Optional[str] = "", **kwargs, ) -> dict: - if "error" in kwargs and kwargs["error"]: return Endpoint.do_response(self, response_args, request, **kwargs) @@ -126,44 +125,35 @@ def process_request(self, request=None, **kwargs): return self.error_cls(error="invalid_token", error_description="Invalid Token") _grant = _session_info["grant"] - token = _grant.get_token(request["access_token"]) - # should be an access token - if token and token.token_class != "access_token": - return self.error_cls(error="invalid_token", error_description="Wrong type of token") + access_token = _grant.get_token(request["access_token"]) - # And it should be valid - if token.is_active() is False: + # there must be a token + if not access_token: return self.error_cls(error="invalid_token", error_description="Invalid Token") - allowed = True - _auth_event = _grant.authentication_event - # if the authentication is still active or offline_access is granted. - if not _auth_event["valid_until"] >= utc_time_sans_frac(): - logger.debug( - "authentication not valid: {} > {}".format( - datetime.fromtimestamp(_auth_event["valid_until"]), - datetime.fromtimestamp(utc_time_sans_frac()), - ) - ) - allowed = False + # the token must be an access_token + if access_token.token_class != "access_token": + return self.error_cls(error="invalid_token", error_description="Wrong type of token") - # This has to be made more finegrained. - # if "offline_access" in session["authn_req"]["scope"]: - # pass + # the access_token must be valid + if access_token.is_active() is False: + return self.error_cls(error="invalid_token", error_description="Invalid Token") + + # the access_token must contain the openid scope + if "openid" not in access_token.scope: + return self.error_cls(error="invalid_token", error_description="Invalid Token") _cntxt = self.upstream_get("context") - if allowed: - _claims_restriction = _cntxt.claims_interface.get_claims( - _session_info["branch_id"], scopes=token.scope, claims_release_point="userinfo" - ) - info = _cntxt.claims_interface.get_user_claims( - _session_info["user_id"], - claims_restriction=_claims_restriction, - client_id=_session_info["client_id"] - ) - info["sub"] = _grant.sub - if _grant.add_acr_value("userinfo"): - info["acr"] = _grant.authentication_event["authn_info"] + _claims_restriction = _cntxt.claims_interface.get_claims( + _session_info["branch_id"], scopes=access_token.scope, claims_release_point="userinfo" + ) + info = _cntxt.claims_interface.get_user_claims( + _session_info["user_id"], claims_restriction=_claims_restriction, + client_id=_session_info["client_id"] + ) + info["sub"] = _grant.sub + if _grant.add_acr_value("userinfo"): + info["acr"] = _grant.authentication_event["authn_info"] extra_claims = kwargs.get("extra_claims") if extra_claims: @@ -173,7 +163,7 @@ def process_request(self, request=None, **kwargs): self.config["policy"] = _cntxt.cdb[request["client_id"]]["userinfo"]["policy"] if "policy" in self.config: - info = self._enforce_policy(request, info, token, self.config) + info = self._enforce_policy(request, info, access_token, self.config) return {"response_args": info, "client_id": _session_info["client_id"]} @@ -213,7 +203,7 @@ def parse_request(self, request, http_info=None, **kwargs): def _enforce_policy(self, request, response_info, token, config): policy = config["policy"] callable = policy["function"] - kwargs = policy.get("kwargs", {}) + kwargs = policy.get("kwargs") or {} if isinstance(callable, str): try: