Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Algorithm an abstract base class #845

Merged
merged 3 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib
import hmac
import json
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Type, Union

from .exceptions import InvalidKeyError
Expand Down Expand Up @@ -117,7 +118,7 @@ def get_default_algorithms() -> Dict[str, "Algorithm"]:
return default_algorithms


class Algorithm:
class Algorithm(ABC):
"""
The interface for an algorithm used to sign and verify tokens.
"""
Expand Down Expand Up @@ -148,40 +149,40 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
# variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
# that may still be poorly supported.

@abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
raise NotImplementedError

@abstractmethod
def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
raise NotImplementedError

@abstractmethod
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
"""
raise NotImplementedError

@staticmethod
@abstractmethod
def to_jwk(key_obj) -> JWKDict:
"""
Serializes a given RSA key into a JWK
"""
raise NotImplementedError

@staticmethod
@abstractmethod
def from_jwk(jwk: JWKDict):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
raise NotImplementedError


class NoneAlgorithm(Algorithm):
Expand All @@ -205,6 +206,14 @@ def sign(self, msg, key):
def verify(self, msg, key, sig):
return False

@staticmethod
def to_jwk(key_obj) -> JWKDict:
raise NotImplementedError()

@staticmethod
def from_jwk(jwk: JWKDict):
raise NotImplementedError()


class HMACAlgorithm(Algorithm):
"""
Expand Down Expand Up @@ -299,7 +308,7 @@ def prepare_key(self, key):
def to_jwk(key_obj):
obj = None

if getattr(key_obj, "private_numbers", None):
if hasattr(key_obj, "private_numbers"):
# Private key
numbers = key_obj.private_numbers()

Expand All @@ -316,7 +325,7 @@ def to_jwk(key_obj):
"qi": to_base64url_uint(numbers.iqmp).decode(),
}

elif getattr(key_obj, "verify", None):
elif hasattr(key_obj, "verify"):
# Public key
numbers = key_obj.public_numbers()

Expand Down Expand Up @@ -588,7 +597,7 @@ def sign(self, msg, key):
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size,
salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
Expand All @@ -600,7 +609,7 @@ def verify(self, msg, key, sig):
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size,
salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
Expand Down
8 changes: 4 additions & 4 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import warnings
from calendar import timegm
from collections.abc import Iterable, Mapping
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Type, Union

Expand Down Expand Up @@ -47,10 +47,10 @@ def encode(
json_encoder: Optional[Type[json.JSONEncoder]] = None,
sort_headers: bool = True,
) -> str:
# Check that we get a mapping
if not isinstance(payload, Mapping):
# Check that we get a dict
if not isinstance(payload, dict):
raise TypeError(
"Expecting a mapping object, as JWT only supports "
"Expecting a dict object, as JWT only supports "
"JSON objects as payloads."
)

Expand Down
46 changes: 11 additions & 35 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from jwt.algorithms import Algorithm, HMACAlgorithm, NoneAlgorithm, has_crypto
from jwt.algorithms import HMACAlgorithm, NoneAlgorithm, has_crypto
from jwt.exceptions import InvalidKeyError
from jwt.utils import base64url_decode

Expand All @@ -15,47 +15,23 @@


class TestAlgorithms:
def test_algorithm_should_throw_exception_if_prepare_key_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.prepare_key("test")

def test_algorithm_should_throw_exception_if_sign_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.sign(b"message", "key")

def test_algorithm_should_throw_exception_if_verify_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.verify(b"message", "key", b"signature")

def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.from_jwk({"val": "ue"})

def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self):
algo = Algorithm()
def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
algo = NoneAlgorithm()

with pytest.raises(NotImplementedError):
algo.to_jwk("value")
with pytest.raises(InvalidKeyError):
algo.prepare_key("123")

def test_algorithm_should_throw_exception_if_compute_hash_digest_not_impl(self):
algo = Algorithm()
def test_none_algorithm_should_throw_exception_on_to_jwk(self):
algo = NoneAlgorithm()

with pytest.raises(NotImplementedError):
algo.compute_hash_digest(b"value")
algo.to_jwk("dummy") # Using a dummy argument as is it not relevant

def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
def test_none_algorithm_should_throw_exception_on_from_jwk(self):
algo = NoneAlgorithm()

with pytest.raises(InvalidKeyError):
algo.prepare_key("123")
with pytest.raises(NotImplementedError):
algo.from_jwk({}) # Using a dummy argument as is it not relevant

def test_hmac_should_reject_nonstring_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from jwt.algorithms import Algorithm, has_crypto
from jwt.algorithms import NoneAlgorithm, has_crypto
from jwt.api_jws import PyJWS
from jwt.exceptions import (
DecodeError,
Expand Down Expand Up @@ -39,10 +39,10 @@ def payload():

class TestJWS:
def test_register_algo_does_not_allow_duplicate_registration(self, jws):
jws.register_algorithm("AAA", Algorithm())
jws.register_algorithm("AAA", NoneAlgorithm())

with pytest.raises(ValueError):
jws.register_algorithm("AAA", Algorithm())
jws.register_algorithm("AAA", NoneAlgorithm())

def test_register_algo_rejects_non_algorithm_obj(self, jws):
with pytest.raises(TypeError):
Expand Down