Skip to content

Commit

Permalink
Add JWTPayload(dict) for extended verification
Browse files Browse the repository at this point in the history
The JWTPayload class allows PyJWT.decode() to expose header, signature,
signing_input, and compute_hash_digest() (based on header) without
changing the pyjwt API in a breaking way.
Merely making this info accessible to the client without specifying an
additional verification callback scheme is simpler for everyone.

Include doc on why JWTPayload is a good idea in a module docstring,
since it's a little unusual to subclass `dict`. The intent is to make
the JWT payload change as little as possible while still making it easy
to add more verification after the fact.

Add a simple test for `JWTPayload.compute_hash_digest()` and a test
for compute_hash_digest with cryptography (which is compared against a
manual hashlib usage).

Closes jpadilla#314, jpadilla#295
  • Loading branch information
sirosen committed May 14, 2020
1 parent 008490a commit e3222ad
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 11 deletions.
18 changes: 10 additions & 8 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,17 @@ def encode(

# Segments
signing_input = b".".join(segments)
try:
alg_obj = self._algorithms[algorithm]
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)
alg_obj = self.get_algo_by_name(algorithm)
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

segments.append(base64url_encode(signature))

return b".".join(segments)

def get_algo_by_name(self, algorithm):
try:
return self._algorithms[algorithm]
except KeyError:
if not has_crypto and algorithm in requires_cryptography:
raise NotImplementedError(
Expand All @@ -124,10 +130,6 @@ def encode(
else:
raise NotImplementedError("Algorithm not supported")

segments.append(base64url_encode(signature))

return b".".join(segments)

def decode(
self,
jwt, # type: str
Expand Down
5 changes: 3 additions & 2 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
InvalidIssuerError,
MissingRequiredClaimError,
)
from .jwt_payload import JWTPayload
from .utils import merge_dict

try:
Expand Down Expand Up @@ -91,7 +92,7 @@ def decode(
DeprecationWarning,
)

payload, _, _, _ = self._load(jwt)
payload, signing_input, header, signature = self._load(jwt)

if options is None:
options = {"verify_signature": verify}
Expand All @@ -113,7 +114,7 @@ def decode(
merged_options = merge_dict(self.options, options)
self._validate_claims(payload, merged_options, **kwargs)

return payload
return JWTPayload(self, payload, signing_input, header, signature)

def _validate_claims(
self, payload, options, audience=None, issuer=None, leeway=0, **kwargs
Expand Down
89 changes: 89 additions & 0 deletions jwt/jwt_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
= JWTPayload
A JWTPayload is the result of PyJWT.decode()
It
- is a dict (namely, the decoded payload)
- has signing_input, header, and signature as attributes
- exposes JWTPayload.compute_hash_digest(<string>)
which selects a hash algo (and implementation) based on the header and uses
it to compute a message digest
== Design Decision: Why JWTPayload?
This implementation path was chosen to handle a desire to support additional
verification of JWTs without changing the API of pyjwt to v2.0
Because JWTPayload inherits from dict, it behaves the same as the raw dict
objects that PyJWT.decode() used to return (prior to this addition). Unless you
check `type(PyJWT.decode()) is dict`, you likely won't see any change.
It exposes the information previously hidden by PyJWT.decode to allow complex
verification methods to be added to pyjwt client code (rather than baked into
pyjwt itself).
It also allows carefully selected methods (like compute_hash_digest) to be
exposed which are derived from these data.
"""
try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend

has_crypto = True
except ImportError:
has_crypto = False


class JWTPayload(dict):
"""
A decoded JWT payload.
When treated directly as a dict, represents the JWT Payload (which is
typically what clients want).
:ivar signing_input: The signing input as a bytestring
:ivar header: The JWT header as a **dict**
:ivar signature: The JWT signature as a string
"""

def __init__(
self,
jwt_api,
payload,
signing_input,
header,
signature,
*args,
**kwargs
):
super(JWTPayload, self).__init__(payload, *args, **kwargs)
self.signing_input = signing_input
self.header = header
self.signature = signature

self._jwt_api = jwt_api

def compute_hash_digest(self, bytestr):
"""
Given a bytestring, compute a hash digest of the bytestring and
return it, using the algorithm specified by the JWT header.
When `cryptography` is present, it will be used.
This method is necessary in order to support computation of the OIDC
at_hash claim.
"""
algorithm = self.header.get("alg")
alg_obj = self._jwt_api.get_algo_by_name(algorithm)
hash_alg = alg_obj.hash_alg

if has_crypto and (
isinstance(hash_alg, type)
and issubclass(hash_alg, hashes.HashAlgorithm)
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return digest.finalize()
else:
return hash_alg(bytestr).digest()
49 changes: 48 additions & 1 deletion tests/test_api_jwt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import json
import time
from calendar import timegm
Expand All @@ -16,10 +17,21 @@
InvalidIssuerError,
MissingRequiredClaimError,
)
from jwt.utils import force_bytes

from .test_api_jws import has_crypto
from .utils import utc_timestamp

try:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key,
load_ssh_public_key,
)

has_crypto = True
except ImportError:
has_crypto = False


@pytest.fixture
def jwt():
Expand Down Expand Up @@ -525,3 +537,38 @@ def test_decode_no_algorithms_verify_false(self, jwt, payload):
pass
else:
assert False, "Unexpected DeprecationWarning raised."

def test_decoded_payload_can_compute_hash(self, jwt, payload):
secret = "secret"
jwt_message = jwt.encode(payload, secret, algorithm="HS256")
decoded_payload = jwt.decode(jwt_message, secret)

assert (
decoded_payload.compute_hash_digest(b"abc")
== hashlib.sha256(b"abc").digest()
)

@pytest.mark.skipif(
not has_crypto, reason="Can't run without cryptography library"
)
def test_decoded_payload_can_compute_hash_rsa(self, jwt, payload):
with open("tests/keys/testkey_rsa", "r") as rsa_priv_file:
priv_rsakey = load_pem_private_key(
force_bytes(rsa_priv_file.read()),
password=None,
backend=default_backend(),
)

with open("tests/keys/testkey_rsa.pub", "r") as rsa_pub_file:
pub_rsakey = load_ssh_public_key(
force_bytes(rsa_pub_file.read()), backend=default_backend()
)
jwt_message = jwt.encode(payload, priv_rsakey, algorithm="RS256")
decoded_payload = jwt.decode(jwt_message, pub_rsakey)

# RSA-256 still means sha256 hashing, but using the cryptography
# provided value
assert (
decoded_payload.compute_hash_digest(b"abc")
== hashlib.sha256(b"abc").digest()
)

0 comments on commit e3222ad

Please sign in to comment.