From db04d6080239f2e0ac96c65335d05ab4e04f1ed0 Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Wed, 29 Jun 2022 14:48:25 +0000 Subject: [PATCH] Expose get_algorithm_by_name as new method Looking up an algorithm by name is used internally for signature generation. This encapsulates that functionality in a dedicated method and adds it to the public API. No new tests are needed to exercise the functionality. Rationale: 1. Inside of PyJWS, this improves the code. The KeyError handler is better scoped and the signing code reads more directly. 2. This is part of the path to supporting OIDC at_hash validation as a use-case (see: #295, #296, #314). This is arguably sufficient to consider that use-case supported and close it. However, it is an improvement and step in the right direction in either case. --- CHANGELOG.rst | 2 ++ jwt/__init__.py | 2 ++ jwt/api_jws.py | 32 +++++++++++++++++++++----------- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c11ce06ac..81d421e23 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,6 +19,8 @@ Fixed Added ~~~~~ - Add to_jwk static method to ECAlgorithm by @leonsmith in https://github.com/jpadilla/pyjwt/pull/732 +- Add ``get_algorithm_by_name`` as a method of ``PyJWS`` objects, and expose + the global PyJWS method as part of the public API `v2.4.0 `__ ----------------------------------------------------------------------- diff --git a/jwt/__init__.py b/jwt/__init__.py index 6b3f8ab16..a96cc6eee 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -1,6 +1,7 @@ from .api_jwk import PyJWK, PyJWKSet from .api_jws import ( PyJWS, + get_algorithm_by_name, get_unverified_header, register_algorithm, unregister_algorithm, @@ -51,6 +52,7 @@ "get_unverified_header", "register_algorithm", "unregister_algorithm", + "get_algorithm_by_name", # Exceptions "DecodeError", "ExpiredSignatureError", diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 16ec846cb..e56968698 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -73,6 +73,23 @@ def get_algorithms(self): """ return list(self._valid_algs) + def get_algorithm_by_name(self, alg_name: str) -> Algorithm: + """ + For a given string name, return the matching Algorithm object. + + Example usage: + + >>> jws_obj.get_algorithm_by_name("RS256") + """ + try: + return self._algorithms[alg_name] + except KeyError as e: + if not has_crypto and alg_name in requires_cryptography: + raise NotImplementedError( + f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?" + ) from e + raise NotImplementedError("Algorithm not supported") from e + def encode( self, payload: bytes, @@ -128,17 +145,9 @@ def encode( # Segments signing_input = b".".join(segments) - try: - alg_obj = self._algorithms[algorithm] - key = alg_obj.prepare_key(key) - signature = alg_obj.sign(signing_input, key) - - except KeyError as e: - if not has_crypto and algorithm in requires_cryptography: - raise NotImplementedError( - f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?" - ) from e - raise NotImplementedError("Algorithm not supported") from e + alg_obj = self.get_algorithm_by_name(algorithm) + key = alg_obj.prepare_key(key) + signature = alg_obj.sign(signing_input, key) segments.append(base64url_encode(signature)) @@ -286,4 +295,5 @@ def _validate_kid(self, kid): decode = _jws_global_obj.decode register_algorithm = _jws_global_obj.register_algorithm unregister_algorithm = _jws_global_obj.unregister_algorithm +get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name get_unverified_header = _jws_global_obj.get_unverified_header