Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JWTPayload(dict) for extended verification #322

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,17 @@ def encode(

# Segments
signing_input = b".".join(segments)
try:
alg_obj = self._algorithms[algorithm]
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)
alg_obj = self.get_algo_by_name(algorithm)
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

segments.append(base64url_encode(signature))

return b".".join(segments)

def get_algo_by_name(self, algorithm):
try:
return self._algorithms[algorithm]
except KeyError:
if not has_crypto and algorithm in requires_cryptography:
raise NotImplementedError(
Expand All @@ -124,10 +130,6 @@ def encode(
else:
raise NotImplementedError("Algorithm not supported")

segments.append(base64url_encode(signature))

return b".".join(segments)

def decode(
self,
jwt, # type: str
Expand Down
5 changes: 3 additions & 2 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
InvalidIssuerError,
MissingRequiredClaimError,
)
from .jwt_payload import JWTPayload
from .utils import merge_dict

try:
Expand Down Expand Up @@ -91,7 +92,7 @@ def decode(
DeprecationWarning,
)

payload, _, _, _ = self._load(jwt)
payload, signing_input, header, signature = self._load(jwt)

if options is None:
options = {"verify_signature": verify}
Expand All @@ -113,7 +114,7 @@ def decode(
merged_options = merge_dict(self.options, options)
self._validate_claims(payload, merged_options, **kwargs)

return payload
return JWTPayload(self, payload, signing_input, header, signature)

def _validate_claims(
self, payload, options, audience=None, issuer=None, leeway=0, **kwargs
Expand Down
89 changes: 89 additions & 0 deletions jwt/jwt_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
= JWTPayload

A JWTPayload is the result of PyJWT.decode()

It
- is a dict (namely, the decoded payload)
- has signing_input, header, and signature as attributes
- exposes JWTPayload.compute_hash_digest(<string>)
which selects a hash algo (and implementation) based on the header and uses
it to compute a message digest


== Design Decision: Why JWTPayload?

This implementation path was chosen to handle a desire to support additional
verification of JWTs without changing the API of pyjwt to v2.0

Because JWTPayload inherits from dict, it behaves the same as the raw dict
objects that PyJWT.decode() used to return (prior to this addition). Unless you
check `type(PyJWT.decode()) is dict`, you likely won't see any change.

It exposes the information previously hidden by PyJWT.decode to allow complex
verification methods to be added to pyjwt client code (rather than baked into
pyjwt itself).

It also allows carefully selected methods (like compute_hash_digest) to be
exposed which are derived from these data.
"""
try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend

has_crypto = True
except ImportError:
has_crypto = False


class JWTPayload(dict):
"""
A decoded JWT payload.
When treated directly as a dict, represents the JWT Payload (which is
typically what clients want).

:ivar signing_input: The signing input as a bytestring
:ivar header: The JWT header as a **dict**
:ivar signature: The JWT signature as a string
"""

def __init__(
self,
jwt_api,
payload,
signing_input,
header,
signature,
*args,
**kwargs
):
super(JWTPayload, self).__init__(payload, *args, **kwargs)
self.signing_input = signing_input
self.header = header
self.signature = signature

self._jwt_api = jwt_api

def compute_hash_digest(self, bytestr):
"""
Given a bytestring, compute a hash digest of the bytestring and
return it, using the algorithm specified by the JWT header.

When `cryptography` is present, it will be used.

This method is necessary in order to support computation of the OIDC
at_hash claim.
"""
algorithm = self.header.get("alg")
alg_obj = self._jwt_api.get_algo_by_name(algorithm)
hash_alg = alg_obj.hash_alg

if has_crypto and (
isinstance(hash_alg, type)
and issubclass(hash_alg, hashes.HashAlgorithm)
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return digest.finalize()
else:
return hash_alg(bytestr).digest()
49 changes: 48 additions & 1 deletion tests/test_api_jwt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import json
import time
from calendar import timegm
Expand All @@ -16,10 +17,21 @@
InvalidIssuerError,
MissingRequiredClaimError,
)
from jwt.utils import force_bytes

from .test_api_jws import has_crypto
from .utils import utc_timestamp

try:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key,
load_ssh_public_key,
)

has_crypto = True
except ImportError:
has_crypto = False


@pytest.fixture
def jwt():
Expand Down Expand Up @@ -525,3 +537,38 @@ def test_decode_no_algorithms_verify_false(self, jwt, payload):
pass
else:
assert False, "Unexpected DeprecationWarning raised."

def test_decoded_payload_can_compute_hash(self, jwt, payload):
secret = "secret"
jwt_message = jwt.encode(payload, secret, algorithm="HS256")
decoded_payload = jwt.decode(jwt_message, secret)

assert (
decoded_payload.compute_hash_digest(b"abc")
== hashlib.sha256(b"abc").digest()
)

@pytest.mark.skipif(
not has_crypto, reason="Can't run without cryptography library"
)
def test_decoded_payload_can_compute_hash_rsa(self, jwt, payload):
with open("tests/keys/testkey_rsa", "r") as rsa_priv_file:
priv_rsakey = load_pem_private_key(
force_bytes(rsa_priv_file.read()),
password=None,
backend=default_backend(),
)

with open("tests/keys/testkey_rsa.pub", "r") as rsa_pub_file:
pub_rsakey = load_ssh_public_key(
force_bytes(rsa_pub_file.read()), backend=default_backend()
)
jwt_message = jwt.encode(payload, priv_rsakey, algorithm="RS256")
decoded_payload = jwt.decode(jwt_message, pub_rsakey)

# RSA-256 still means sha256 hashing, but using the cryptography
# provided value
assert (
decoded_payload.compute_hash_digest(b"abc")
== hashlib.sha256(b"abc").digest()
)