Skip to content

Commit

Permalink
Add JWTPayload(dict) for extended verification
Browse files Browse the repository at this point in the history
The JWTPayload class allows PyJWT.decode() to expose header, signature,
signing_input, and compute_hash_digest() (based on header) without
changing the pyjwt API in a breaking way.
Merely making this info accessible to the client (without specify an
additional verification callback scheme) is simpler for everyone.

Include doc on why JWTPayload is a good idea (in a docstring), since
it's a little unusual to subclass `dict`. The intent is to make the JWT
payload change as little as possible while still making it easy to add
more verification after the fact.

Add a simple test for `JWTPayload.compute_hash_digest()`

Closes #314, #295
  • Loading branch information
sirosen committed Jan 10, 2018
1 parent 0c80a71 commit 87370b4
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 13 deletions.
23 changes: 12 additions & 11 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,17 @@ def encode(self, payload, key, algorithm='HS256', headers=None,

# Segments
signing_input = b'.'.join(segments)
try:
alg_obj = self._algorithms[algorithm]
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)
alg_obj = self.get_algo_by_name(algorithm)
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

segments.append(base64url_encode(signature))

return b'.'.join(segments)

def get_algo_by_name(self, algorithm):
try:
return self._algorithms[algorithm]
except KeyError:
if not has_crypto and algorithm in requires_cryptography:
raise NotImplementedError(
Expand All @@ -113,10 +119,6 @@ def encode(self, payload, key, algorithm='HS256', headers=None,
else:
raise NotImplementedError('Algorithm not supported')

segments.append(base64url_encode(signature))

return b'.'.join(segments)

def decode(self, jws, key='', verify=True, algorithms=None, options=None,
**kwargs):

Expand All @@ -138,8 +140,7 @@ def decode(self, jws, key='', verify=True, algorithms=None, options=None,
'Please use verify_signature in options instead.',
DeprecationWarning, stacklevel=2)
elif verify_signature:
self._verify_signature(payload, signing_input, header, signature,
key, algorithms)
self._verify_signature(payload, signing_input, header, signature, key, algorithms)

return payload

Expand Down Expand Up @@ -191,7 +192,7 @@ def _load(self, jwt):
except (TypeError, binascii.Error):
raise DecodeError('Invalid crypto padding')

return (payload, signing_input, header, signature)
return payload, signing_input, header, signature

def _verify_signature(self, payload, signing_input, header, signature,
key='', algorithms=None):
Expand Down
3 changes: 2 additions & 1 deletion jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime, timedelta

from .api_jws import PyJWS
from .jwt_payload import JWTPayload
from .algorithms import Algorithm, get_default_algorithms # NOQA
from .compat import string_types
from .exceptions import (
Expand Down Expand Up @@ -88,7 +89,7 @@ def decode(self, jwt, key='', verify=True, algorithms=None, options=None,
merged_options = merge_dict(self.options, options)
self._validate_claims(payload, merged_options, **kwargs)

return payload
return JWTPayload(self, payload, signing_input, header, signature)

def _validate_claims(self, payload, options, audience=None, issuer=None,
leeway=0, **kwargs):
Expand Down
78 changes: 78 additions & 0 deletions jwt/jwt_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
= JWTPayload
A JWTPayload is the result of PyJWT.decode()
It
- is a dict (namely, the decoded payload)
- has signing_input, header, and signature as attributes
- exposes JWTPayload.compute_hash_digest(<string>)
which selects a hash algo (and implementation) based on the header and uses
it to compute a message digest
== Design Decision: Why JWTPayload?
This implementation path was chosen to handle a desire to support additional
verification of JWTs without changing the API of pyjwt to v2.0
Because JWTPayload inherits from dict, it behaves the same as the raw dict
objects that PyJWT.decode() used to return (prior to this addition). Unless you
check `type(PyJWT.decode()) is dict`, you likely won't see any change.
It exposes the information previously hidden by PyJWT.decode to allow complex
verification methods to be added to pyjwt client code (rather than baked into
pyjwt itself).
It also allows carefully selected methods (like compute_hash_digest) to be
exposed which are derived from these data.
"""
try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
has_crypto = True
except ImportError:
has_crypto = False


class JWTPayload(dict):
"""
A decoded JWT payload.
When treated directly as a dict, represents the JWT Payload (which is
typically what clients want).
:ivar signing_input: The signing input as a bytestring
:ivar header: The JWT header as a **dict**
:ivar signature: The JWT signature as a string
"""
def __init__(self, jwt_api, payload, signing_input, header, signature,
*args, **kwargs):
super(JWTPayload, self).__init__(payload, *args, **kwargs)
self.signing_input = signing_input
self.header = header
self.signature = signature

self._jwt_api = jwt_api

def compute_hash_digest(self, bytestr):
"""
Given a bytestring, compute a hash digest of the bytestring and
return it, using the algorithm specified by the JWT header.
When `cryptography` is present, it will be used.
This method is necessary in order to support computation of the OIDC
at_hash claim.
"""
algorithm = self.header.get('alg')
alg_obj = self._jwt_api.get_algo_by_name(algorithm)
hash_alg = alg_obj.hash_alg

if has_crypto and (
isinstance(hash_alg, type) and
issubclass(hash_alg, hashes.HashAlgorithm)):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return digest.finalize()
else:
return hash_alg(bytestr).digest()
10 changes: 9 additions & 1 deletion tests/test_api_jwt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import hashlib
import json
import time
from calendar import timegm
Expand Down Expand Up @@ -511,3 +511,11 @@ def test_decode_no_algorithms_verify_false(self, jwt, payload):
pass
else:
assert False, "Unexpected DeprecationWarning raised."

def test_decoded_payload_can_compute_hash(self, jwt, payload):
secret = 'secret'
jwt_message = jwt.encode(payload, secret, algorithm='HS256')
decoded_payload = jwt.decode(jwt_message, secret)

assert (decoded_payload.compute_hash_digest(b'abc') ==
hashlib.sha256(b'abc').digest())

0 comments on commit 87370b4

Please sign in to comment.