From 2e8c10e97794b5fe581831853854efe743bf8ef2 Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Tue, 9 Jan 2018 21:18:57 -0500 Subject: [PATCH] Add JWTPayload(dict) for extended verification The JWTPayload class allows PyJWT.decode() to expose header, signature, signing_input, and compute_hash_digest() (based on header) without changing the pyjwt API in a breaking way. Merely making this info accessible to the client (without specifying an additional verification callback scheme) is simpler for everyone. Include doc on why JWTPayload is a good idea (in a docstring), since it's a little unusual to subclass `dict`. The intent is to make the JWT payload change as little as possible while still making it easy to add more verification after the fact. Add a simple test for `JWTPayload.compute_hash_digest()` Closes #314, #295 --- jwt/api_jws.py | 23 +++++++------ jwt/api_jwt.py | 3 +- jwt/jwt_payload.py | 78 +++++++++++++++++++++++++++++++++++++++++++ tests/test_api_jwt.py | 9 +++++ 4 files changed, 101 insertions(+), 12 deletions(-) create mode 100644 jwt/jwt_payload.py diff --git a/jwt/api_jws.py b/jwt/api_jws.py index b796fa76b..67a0691a1 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -99,11 +99,17 @@ def encode(self, payload, key, algorithm='HS256', headers=None, # Segments signing_input = b'.'.join(segments) - try: - alg_obj = self._algorithms[algorithm] - key = alg_obj.prepare_key(key) - signature = alg_obj.sign(signing_input, key) + alg_obj = self.get_algo_by_name(algorithm) + key = alg_obj.prepare_key(key) + signature = alg_obj.sign(signing_input, key) + + segments.append(base64url_encode(signature)) + + return b'.'.join(segments) + def get_algo_by_name(self, algorithm): + try: + return self._algorithms[algorithm] except KeyError: if not has_crypto and algorithm in requires_cryptography: raise NotImplementedError( @@ -113,10 +119,6 @@ def encode(self, payload, key, algorithm='HS256', headers=None, else: raise NotImplementedError('Algorithm not supported') - 
segments.append(base64url_encode(signature)) - - return b'.'.join(segments) - def decode(self, jws, key='', verify=True, algorithms=None, options=None, **kwargs): @@ -138,8 +140,7 @@ def decode(self, jws, key='', verify=True, algorithms=None, options=None, 'Please use verify_signature in options instead.', DeprecationWarning, stacklevel=2) elif verify_signature: - self._verify_signature(payload, signing_input, header, signature, - key, algorithms) + self._verify_signature(payload, signing_input, header, signature, key, algorithms) return payload @@ -191,7 +192,7 @@ def _load(self, jwt): except (TypeError, binascii.Error): raise DecodeError('Invalid crypto padding') - return (payload, signing_input, header, signature) + return payload, signing_input, header, signature def _verify_signature(self, payload, signing_input, header, signature, key='', algorithms=None): diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index edef77018..5292e80c7 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta from .api_jws import PyJWS +from .jwt_payload import JWTPayload from .algorithms import Algorithm, get_default_algorithms # NOQA from .compat import string_types from .exceptions import ( @@ -88,7 +89,7 @@ def decode(self, jwt, key='', verify=True, algorithms=None, options=None, merged_options = merge_dict(self.options, options) self._validate_claims(payload, merged_options, **kwargs) - return payload + return JWTPayload(self, payload, signing_input, header, signature) def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs): diff --git a/jwt/jwt_payload.py b/jwt/jwt_payload.py new file mode 100644 index 000000000..61d3b1dc0 --- /dev/null +++ b/jwt/jwt_payload.py @@ -0,0 +1,78 @@ +""" += JWTPayload + +A JWTPayload is the result of PyJWT.decode() + +It +- is a dict (namely, the decoded payload) +- has signing_input, header, and signature as attributes +- exposes JWTPayload.compute_hash_digest() + 
which selects a hash algo (and implementation) based on the header and uses + it to compute a message digest + + +== Design Decision: Why JWTPayload? + +This implementation path was chosen to handle a desire to support additional +verification of JWTs without changing the API of pyjwt to v2.0 + +Because JWTPayload inherits from dict, it behaves the same as the raw dict +objects that PyJWT.decode() used to return (prior to this addition). Unless you +check `type(PyJWT.decode()) is dict`, you likely won't see any change. + +It exposes the information previously hidden by PyJWT.decode to allow complex +verification methods to be added to pyjwt client code (rather than baked into +pyjwt itself). + +It also allows carefully selected methods (like compute_hash_digest) to be +exposed which are derived from these data. +""" +try: + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.backends import default_backend + has_crypto = True +except ImportError: + has_crypto = False + + +class JWTPayload(dict): + """ + A decoded JWT payload. + When treated directly as a dict, represents the JWT Payload (which is + typically what clients want). + + :ivar signing_input: The signing input as a bytestring + :ivar header: The JWT header as a **dict** + :ivar signature: The JWT signature as a string + """ + def __init__(self, jwt_api, payload, signing_input, header, signature, + *args, **kwargs): + super(JWTPayload, self).__init__(payload, *args, **kwargs) + self.signing_input = signing_input + self.header = header + self.signature = signature + + self._jwt_api = jwt_api + + def compute_hash_digest(self, bytestr): + """ + Given a bytestring, compute a hash digest of the bytestring and + return it, using the algorithm specified by the JWT header. + + When `cryptography` is present, it will be used. + + This method is necessary in order to support computation of the OIDC + at_hash claim. 
+ """ + algorithm = self.header.get('alg') + alg_obj = self._jwt_api.get_algo_by_name(algorithm) + hash_alg = alg_obj.hash_alg + + if has_crypto and ( + isinstance(hash_alg, type) and + issubclass(hash_alg, hashes.HashAlgorithm)): + digest = hashes.Hash(hash_alg(), backend=default_backend()) + digest.update(bytestr) + return digest.finalize() + else: + return hash_alg(bytestr).digest() diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 58b47f2c9..9c26c96cb 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -1,5 +1,6 @@ import json +import hashlib import time from calendar import timegm from datetime import datetime, timedelta @@ -511,3 +512,11 @@ def test_decode_no_algorithms_verify_false(self, jwt, payload): pass else: assert False, "Unexpected DeprecationWarning raised." + + def test_decoded_payload_can_compute_hash(self, jwt, payload): + secret = 'secret' + jwt_message = jwt.encode(payload, secret, algorithm='HS256') + decoded_payload = jwt.decode(jwt_message, secret) + + assert (decoded_payload.compute_hash_digest(b'abc') == + hashlib.sha256(b'abc').digest())