diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 9504c9f6..fc0fa04f 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -110,11 +110,17 @@ def encode( # Segments signing_input = b".".join(segments) - try: - alg_obj = self._algorithms[algorithm] - key = alg_obj.prepare_key(key) - signature = alg_obj.sign(signing_input, key) + alg_obj = self.get_algo_by_name(algorithm) + key = alg_obj.prepare_key(key) + signature = alg_obj.sign(signing_input, key) + + segments.append(base64url_encode(signature)) + + return b".".join(segments) + def get_algo_by_name(self, algorithm): + try: + return self._algorithms[algorithm] except KeyError: if not has_crypto and algorithm in requires_cryptography: raise NotImplementedError( @@ -124,10 +130,6 @@ def encode( else: raise NotImplementedError("Algorithm not supported") - segments.append(base64url_encode(signature)) - - return b".".join(segments) - def decode( self, jwt, # type: str diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 125b57a5..f447b607 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -15,6 +15,7 @@ InvalidIssuerError, MissingRequiredClaimError, ) +from .jwt_payload import JWTPayload from .utils import merge_dict try: @@ -91,7 +92,7 @@ def decode( DeprecationWarning, ) - payload, _, _, _ = self._load(jwt) + payload, signing_input, header, signature = self._load(jwt) if options is None: options = {"verify_signature": verify} @@ -113,7 +114,7 @@ def decode( merged_options = merge_dict(self.options, options) self._validate_claims(payload, merged_options, **kwargs) - return payload + return JWTPayload(self, payload, signing_input, header, signature) def _validate_claims( self, payload, options, audience=None, issuer=None, leeway=0, **kwargs diff --git a/jwt/jwt_payload.py b/jwt/jwt_payload.py new file mode 100644 index 00000000..5135994d --- /dev/null +++ b/jwt/jwt_payload.py @@ -0,0 +1,89 @@ +""" += JWTPayload + +A JWTPayload is the result of PyJWT.decode() + +It +- is a dict (namely, the decoded payload) +- has signing_input, header, and signature as attributes +- exposes JWTPayload.compute_hash_digest() + which selects a hash algo (and implementation) based on the header and uses + it to compute a message digest + + +== Design Decision: Why JWTPayload? + +This implementation path was chosen to handle a desire to support additional +verification of JWTs without changing the API of pyjwt to v2.0 + +Because JWTPayload inherits from dict, it behaves the same as the raw dict +objects that PyJWT.decode() used to return (prior to this addition). Unless you +check `type(PyJWT.decode()) is dict`, you likely won't see any change. + +It exposes the information previously hidden by PyJWT.decode to allow complex +verification methods to be added to pyjwt client code (rather than baked into +pyjwt itself). + +It also allows carefully selected methods (like compute_hash_digest) to be +exposed which are derived from these data. +""" +try: + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.backends import default_backend + + has_crypto = True +except ImportError: + has_crypto = False + + +class JWTPayload(dict): + """ + A decoded JWT payload. + When treated directly as a dict, represents the JWT Payload (which is + typically what clients want). + + :ivar signing_input: The signing input as a bytestring + :ivar header: The JWT header as a **dict** + :ivar signature: The JWT signature as a string + """ + + def __init__( + self, + jwt_api, + payload, + signing_input, + header, + signature, + *args, + **kwargs + ): + super(JWTPayload, self).__init__(payload, *args, **kwargs) + self.signing_input = signing_input + self.header = header + self.signature = signature + + self._jwt_api = jwt_api + + def compute_hash_digest(self, bytestr): + """ + Given a bytestring, compute a hash digest of the bytestring and + return it, using the algorithm specified by the JWT header. + + When `cryptography` is present, it will be used. + + This method is necessary in order to support computation of the OIDC + at_hash claim. + """ + algorithm = self.header.get("alg") + alg_obj = self._jwt_api.get_algo_by_name(algorithm) + hash_alg = alg_obj.hash_alg + + if has_crypto and ( + isinstance(hash_alg, type) + and issubclass(hash_alg, hashes.HashAlgorithm) + ): + digest = hashes.Hash(hash_alg(), backend=default_backend()) + digest.update(bytestr) + return digest.finalize() + else: + return hash_alg(bytestr).digest() diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index c1759e35..de1737d8 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -1,3 +1,4 @@ +import hashlib import json import time from calendar import timegm @@ -16,10 +17,21 @@ InvalidIssuerError, MissingRequiredClaimError, ) +from jwt.utils import force_bytes -from .test_api_jws import has_crypto from .utils import utc_timestamp +try: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives.serialization import ( + load_pem_private_key, + load_ssh_public_key, + ) + + has_crypto = True +except ImportError: + has_crypto = False + @pytest.fixture def jwt(): @@ -525,3 +537,38 @@ def test_decode_no_algorithms_verify_false(self, jwt, payload): pass else: assert False, "Unexpected DeprecationWarning raised." + + def test_decoded_payload_can_compute_hash(self, jwt, payload): + secret = "secret" + jwt_message = jwt.encode(payload, secret, algorithm="HS256") + decoded_payload = jwt.decode(jwt_message, secret) + + assert ( + decoded_payload.compute_hash_digest(b"abc") + == hashlib.sha256(b"abc").digest() + ) + + @pytest.mark.skipif( + not has_crypto, reason="Can't run without cryptography library" + ) + def test_decoded_payload_can_compute_hash_rsa(self, jwt, payload): + with open("tests/keys/testkey_rsa", "r") as rsa_priv_file: + priv_rsakey = load_pem_private_key( + force_bytes(rsa_priv_file.read()), + password=None, + backend=default_backend(), + ) + + with open("tests/keys/testkey_rsa.pub", "r") as rsa_pub_file: + pub_rsakey = load_ssh_public_key( + force_bytes(rsa_pub_file.read()), backend=default_backend() + ) + jwt_message = jwt.encode(payload, priv_rsakey, algorithm="RS256") + decoded_payload = jwt.decode(jwt_message, pub_rsakey) + + # RSA-256 still means sha256 hashing, but using the cryptography + # provided value + assert ( + decoded_payload.compute_hash_digest(b"abc") + == hashlib.sha256(b"abc").digest() + )