diff --git a/authlib/integrations/base_client/sync_openid.py b/authlib/integrations/base_client/sync_openid.py index ac51907a..1611e24d 100644 --- a/authlib/integrations/base_client/sync_openid.py +++ b/authlib/integrations/base_client/sync_openid.py @@ -33,15 +33,8 @@ def parse_id_token(self, token, nonce, claims_options=None, leeway=120): """Return an instance of UserInfo from token's ``id_token``.""" if 'id_token' not in token: return None - - def load_key(header, _): - jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set()) - try: - return jwk_set.find_by_kid(header.get('kid')) - except ValueError: - # re-try with new jwk set - jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True)) - return jwk_set.find_by_kid(header.get('kid')) + + load_key = self.create_load_key() claims_params = dict( nonce=nonce, @@ -75,3 +68,15 @@ def load_key(header, _): claims.validate(leeway=leeway) return UserInfo(claims) + + def create_load_key(self): + def load_key(header, _): + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set()) + try: + return jwk_set.find_by_kid(header.get('kid')) + except ValueError: + # re-try with new jwk set + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True)) + return jwk_set.find_by_kid(header.get('kid')) + + return load_key