Skip to content

Commit

Permalink
Add support for the OIDC at_hash claim
Browse files Browse the repository at this point in the history
Use PyJWT to compute the at_hash value for OpenID Connect:
http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken

This makes more sense in PyJWT than its client code because of the tight
coupling between the chosen signing algorithm and the computation of the
at_hash. Any client code would have to jump through hoops to get this to
work nicely based on the algorithm being fed to PyJWT.

Closes #295

Primary changes:

Add support for access_token=... as a param to PyJWT.encode and
PyJWT.decode . On encode, the at_hash claim is computed and added to the
payload. On decode, unpacks the at_hash value, raising a missing claim
error if its missing, and compares it to a freshly computed at_hash.
Raises a new error type if they don't match.
Does not use the verification options dict, as it's redundant with the
caller supplying access_token in this case.

Supporting changes:
- Add tests for the above
- Let PyJWT and PyJWS get an algorithm object from a string as a method
- Add a method, compute_at_hash, to PyJWT objects
- PyJWT._validate_claims now takes the header as an arg (needed to get
  algo)
  • Loading branch information
sirosen committed Oct 2, 2017
1 parent 72bb76c commit 3c08972
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 17 deletions.
27 changes: 15 additions & 12 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ def get_algorithms(self):
"""
return list(self._valid_algs)

def get_algo_by_name(self, algorithm):
try:
return self._algorithms[algorithm]
except KeyError:
if not has_crypto and algorithm in requires_cryptography:
raise NotImplementedError(
"Algorithm '%s' could not be found. Do you have cryptography "
"installed?" % algorithm
)
else:
raise NotImplementedError('Algorithm not supported')

def encode(self, payload, key, algorithm='HS256', headers=None,
json_encoder=None):
segments = []
Expand Down Expand Up @@ -97,19 +109,10 @@ def encode(self, payload, key, algorithm='HS256', headers=None,

# Segments
signing_input = b'.'.join(segments)
try:
alg_obj = self._algorithms[algorithm]
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

except KeyError:
if not has_crypto and algorithm in requires_cryptography:
raise NotImplementedError(
"Algorithm '%s' could not be found. Do you have cryptography "
"installed?" % algorithm
)
else:
raise NotImplementedError('Algorithm not supported')
alg_obj = self.get_algo_by_name(algorithm)
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

segments.append(base64url_encode(signature))

Expand Down
76 changes: 71 additions & 5 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,18 @@
from .compat import string_types, timedelta_total_seconds
from .exceptions import (
DecodeError, ExpiredSignatureError, ImmatureSignatureError,
InvalidAccessTokenHashError,
InvalidAudienceError, InvalidIssuedAtError,
InvalidIssuerError, MissingRequiredClaimError
)
from .utils import merge_dict
from .utils import merge_dict, base64url_encode

try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
has_crypto = True
except ImportError:
has_crypto = False


class PyJWT(PyJWS):
Expand All @@ -34,7 +42,7 @@ def _get_default_options():
}

def encode(self, payload, key, algorithm='HS256', headers=None,
json_encoder=None):
json_encoder=None, access_token=None):
# Check that we get a mapping
if not isinstance(payload, Mapping):
raise TypeError('Expecting a mapping object, as JWT only supports '
Expand All @@ -46,6 +54,10 @@ def encode(self, payload, key, algorithm='HS256', headers=None,
if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(payload[time_claim].utctimetuple())

# OIDC ID Token may have at_hash additional claim
if access_token is not None:
payload['at_hash'] = self.compute_at_hash(access_token, algorithm)

json_payload = json.dumps(
payload,
separators=(',', ':'),
Expand Down Expand Up @@ -87,12 +99,52 @@ def decode(self, jwt, key='', verify=True, algorithms=None, options=None,

if verify:
merged_options = merge_dict(self.options, options)
self._validate_claims(payload, merged_options, **kwargs)
self._validate_claims(payload, header, merged_options, **kwargs)

return payload

def _validate_claims(self, payload, options, audience=None, issuer=None,
leeway=0, **kwargs):
def compute_at_hash(self, access_token, algorithm='HS256'):
"""
Computes the at_hash claim for JWTs used in OpenID Connect.
The at_hash is based on the hashing algorithm used to sign the JWT, and
is specified here:
http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
:return: at_hash of the access_token as a string
**Parameters**
``access_token`` (string)
The access token to hash
``algorithm``
An algorithm object from jwt.algorithms. Its expected behavior is
determined by has_crypto.
If has_crypto=False, it must have a callable hash_alg member which
provides digest(), like the hashlib variants
If has_crypto=True, it *may* be a hashlib style hashing function,
or it may be a cryptography hashing algorithm
"""
alg_obj = self.get_algo_by_name(algorithm)
hash_alg = alg_obj.hash_alg

def get_digest(bytestr):
if has_crypto and (
isinstance(hash_alg, type) and
issubclass(hash_alg, hashes.HashAlgorithm)):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return digest.finalize()
else:
return hash_alg(bytestr).digest()

digest = get_digest(access_token.encode('utf-8'))
truncated = digest[:(len(digest) / 2)]
return base64url_encode(truncated).decode('utf-8')

def _validate_claims(self, payload, header, options, audience=None,
issuer=None, leeway=0, access_token=None, **kwargs):

if 'verify_expiration' in kwargs:
options['verify_exp'] = kwargs.get('verify_expiration', True)
Expand All @@ -119,6 +171,9 @@ def _validate_claims(self, payload, options, audience=None, issuer=None,
if 'exp' in payload and options.get('verify_exp'):
self._validate_exp(payload, now, leeway)

if access_token:
self._validate_at_hash(payload, header, access_token)

if options.get('verify_iss'):
self._validate_iss(payload, issuer)

Expand Down Expand Up @@ -190,6 +245,17 @@ def _validate_iss(self, payload, issuer):
if payload['iss'] != issuer:
raise InvalidIssuerError('Invalid issuer')

def _validate_at_hash(self, payload, header, access_token):
try:
at_hash = payload['at_hash']
except KeyError:
raise MissingRequiredClaimError('at_hash')

alg = header.get('alg')

if at_hash != self.compute_at_hash(access_token, alg):
raise InvalidAccessTokenHashError("at_hash doesn't match")


_jwt_global_obj = PyJWT()
encode = _jwt_global_obj.encode
Expand Down
4 changes: 4 additions & 0 deletions jwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class InvalidIssuedAtError(InvalidTokenError):
pass


class InvalidAccessTokenHashError(InvalidTokenError):
pass


class ImmatureSignatureError(InvalidTokenError):
pass

Expand Down
45 changes: 45 additions & 0 deletions tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from jwt.exceptions import (
DecodeError, ExpiredSignatureError, ImmatureSignatureError,
InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError,
InvalidAccessTokenHashError,
MissingRequiredClaimError
)

Expand Down Expand Up @@ -58,6 +59,14 @@ def test_load_verify_valid_jwt(self, jwt):

assert decoded_payload == example_payload

def test_verify_fails_missing_athash(self, jwt, payload):
secret = 'secret'
jwt_message = jwt.encode(payload, secret)

with pytest.raises(MissingRequiredClaimError) as exc:
jwt.decode(jwt_message, key=secret, access_token='foobar')
assert 'at_hash' in str(exc.value)

def test_decode_invalid_payload_string(self, jwt):
example_jwt = (
'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.aGVsb'
Expand Down Expand Up @@ -118,6 +127,19 @@ def test_decode_with_invalid_aud_list_member_throws_exception(self, jwt):
exception = context.value
assert str(exception) == 'Invalid claim format in token'

def test_decode_with_wrong_access_token_throws_exception(self, jwt, payload):
secret = 'secret'
jwt_message = jwt.encode(payload, secret, access_token='foobar')

with pytest.raises(InvalidAccessTokenHashError):
jwt.decode(jwt_message, key=secret, access_token='foobar2')

def test_decode_with_no_access_token_skips_at_hash(self, jwt, payload):
secret = 'secret'
jwt_message = jwt.encode(payload, secret, access_token='foobar')

jwt.decode(jwt_message, key=secret)

def test_encode_bad_type(self, jwt):

types = ['string', tuple(), list(), 42, set()]
Expand Down Expand Up @@ -495,3 +517,26 @@ def test_decode_no_algorithms_verify_false(self, jwt, payload):
pass
else:
assert False, "Unexpected DeprecationWarning raised."

@pytest.mark.skipif(not has_crypto,
reason="Can't run without cryptography library")
def test_at_hashes_match(self, jwt):
"""
Check that HS256 and RS256 at_hash values match.
These are different implementations of the at_hash computation, one in
terms of hashlib.sha256 and one in terms of cryptography..hashes.SHA256
They should produce identical values, and both should evaluate
successfully.
Checks
- Evaluation works for both methods (doesn't crash)
- Evaluation of cryptography hash digests is accurate (use of
finalize()). Assumes that the simpler hashlib evaluation is
"obviously correct"
"""
# this is just garbage to feed in
token = "abc123" * 20

assert (jwt.compute_at_hash(token, 'HS256') ==
jwt.compute_at_hash(token, 'RS256'))

0 comments on commit 3c08972

Please sign in to comment.