Skip to content

Commit

Permalink
Add compute_hash_digest to Algorithm objects
Browse files Browse the repository at this point in the history
`Algorithm.compute_hash_digest` is defined as a method which inspects
the object to see that it has the requisite attributes, `hash_alg`.

If `hash_alg` is not set, then the method raises a
NotImplementedError. This applies to classes like NoneAlgorithm.

If `hash_alg` is set, then it is checked for
```
has_crypto  # is cryptography available?
and isinstance(hash_alg, type)
and issubclass(hash_alg, hashes.HashAlgorithm)
```
to see which API for computing a digest is appropriate --
`hashlib` vs `cryptography.hazmat.primitives.hashes`.

These checks could be avoided at runtime if it were necessary to
optimize further (e.g. attach compute_hash_digest methods to classes
with a class decorator) but this is not clearly a worthwhile
optimization. Such perf tuning is intentionally omitted for now.
  • Loading branch information
sirosen committed Jul 3, 2022
1 parent 2dab210 commit f493dab
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ Added
- Add to_jwk static method to ECAlgorithm by @leonsmith in https://github.com/jpadilla/pyjwt/pull/732
- Add ``get_algorithm_by_name`` as a method of ``PyJWS`` objects, and expose
the global PyJWS method as part of the public API
- Add ``compute_hash_digest`` as a method of ``Algorithm`` objects, which uses
the underlying hash algorithm to compute a digest. If there is no appropriate
hash algorithm, a ``NotImplementedError`` will be raised

`v2.4.0 <https://github.com/jpadilla/pyjwt/compare/2.3.0...2.4.0>`__
-----------------------------------------------------------------------
Expand Down
23 changes: 23 additions & 0 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
try:
import cryptography.exceptions
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, padding
from cryptography.hazmat.primitives.asymmetric.ec import (
Expand Down Expand Up @@ -111,6 +112,28 @@ class Algorithm:
The interface for an algorithm used to sign and verify tokens.
"""

def compute_hash_digest(self, bytestr: bytes) -> bytes:
"""
Compute a hash digest using the specified algorithm's hash algorithm.
If there is no hash algorithm, raises a NotImplementedError.
"""
# lookup self.hash_alg if defined in a way that mypy can understand
hash_alg = getattr(self, "hash_alg", None)
if hash_alg is None:
raise NotImplementedError

if (
has_crypto
and isinstance(hash_alg, type)
and issubclass(hash_alg, hashes.HashAlgorithm)
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return digest.finalize()
else:
return hash_alg(bytestr).digest()

def prepare_key(self, key):
"""
Performs necessary validation and conversions on the key and returns
Expand Down
23 changes: 23 additions & 0 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self):
with pytest.raises(NotImplementedError):
algo.to_jwk("value")

def test_algorithm_should_throw_exception_if_compute_hash_digest_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.compute_hash_digest(b"value")

def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
algo = NoneAlgorithm()

Expand Down Expand Up @@ -1054,3 +1060,20 @@ def test_okp_ed448_to_jwk_works_with_from_jwk(self):
signature_2 = algo.sign(b"Hello World!", priv_key_2)
assert algo.verify(b"Hello World!", pub_key_2, signature_1)
assert algo.verify(b"Hello World!", pub_key_2, signature_2)

@crypto_required
def test_rsa_can_compute_digest(self):
# this is the well-known sha256 hash of "foo"
foo_hash = base64.b64decode(b"LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=")

algo = RSAAlgorithm(RSAAlgorithm.SHA256)
computed_hash = algo.compute_hash_digest(b"foo")
assert computed_hash == foo_hash

def test_hmac_can_compute_digest(self):
# this is the well-known sha256 hash of "foo"
foo_hash = base64.b64decode(b"LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=")

algo = HMACAlgorithm(HMACAlgorithm.SHA256)
computed_hash = algo.compute_hash_digest(b"foo")
assert computed_hash == foo_hash

0 comments on commit f493dab

Please sign in to comment.