diff --git a/docs/conf.py b/docs/conf.py index 54663e1d..3a85f595 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -18,9 +18,10 @@ def find_version(*file_paths) -> str: string inside. """ version_file = read(*file_paths) - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) - if version_match: - return version_match.group(1) + if version_match := re.search( + r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M + ): + return version_match[1] raise RuntimeError("Unable to find version string.") diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 9be50b20..4f42cd00 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import hashlib import hmac import json @@ -24,7 +25,7 @@ from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes - from cryptography.hazmat.primitives.asymmetric import padding + from cryptography.hazmat.primitives.asymmetric import ec, padding from cryptography.hazmat.primitives.asymmetric.ec import ( ECDSA, SECP256K1, @@ -114,24 +115,20 @@ def get_default_algorithms() -> dict[str, Algorithm]: } if has_crypto: - default_algorithms.update( - { - "RS256": RSAAlgorithm(RSAAlgorithm.SHA256), - "RS384": RSAAlgorithm(RSAAlgorithm.SHA384), - "RS512": RSAAlgorithm(RSAAlgorithm.SHA512), - "ES256": ECAlgorithm(ECAlgorithm.SHA256), - "ES256K": ECAlgorithm(ECAlgorithm.SHA256), - "ES384": ECAlgorithm(ECAlgorithm.SHA384), - "ES521": ECAlgorithm(ECAlgorithm.SHA512), - "ES512": ECAlgorithm( - ECAlgorithm.SHA512 - ), # Backward compat for #219 fix - "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), - "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), - "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512), - "EdDSA": OKPAlgorithm(), - } - ) + default_algorithms |= { + "RS256": RSAAlgorithm(RSAAlgorithm.SHA256), + "RS384": RSAAlgorithm(RSAAlgorithm.SHA384), + "RS512": RSAAlgorithm(RSAAlgorithm.SHA512), + "ES256": ECAlgorithm(ECAlgorithm.SHA256), + "ES256K": ECAlgorithm(ECAlgorithm.SHA256), + "ES384": ECAlgorithm(ECAlgorithm.SHA384), + "ES521": ECAlgorithm(ECAlgorithm.SHA512), + "ES512": ECAlgorithm(ECAlgorithm.SHA512), # Backward compat for #219 fix + "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), + "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), + "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512), + "EdDSA": OKPAlgorithm(), + } return default_algorithms @@ -153,15 +150,14 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes: raise NotImplementedError if ( - has_crypto - and isinstance(hash_alg, type) - and issubclass(hash_alg, hashes.HashAlgorithm) + not has_crypto + or not isinstance(hash_alg, type) + or not issubclass(hash_alg, hashes.HashAlgorithm) ): - digest = hashes.Hash(hash_alg(), backend=default_backend()) - digest.update(bytestr) - return bytes(digest.finalize()) - else: return bytes(hash_alg(bytestr).digest()) + digest = hashes.Hash(hash_alg(), backend=default_backend()) + digest.update(bytestr) + return bytes(digest.finalize()) @abstractmethod def prepare_key(self, key: Any) -> Any: @@ -282,10 +278,7 @@ def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str: "kty": "oct", } - if as_dict: - return jwk - else: - return json.dumps(jwk) + return jwk if as_dict else json.dumps(jwk) @staticmethod def from_jwk(jwk: str | JWKDict) -> bytes: @@ -296,8 +289,8 @@ def from_jwk(jwk: str | JWKDict) -> bytes: obj = jwk else: raise ValueError - except ValueError: - raise InvalidKeyError("Key is not valid JSON") + except ValueError as e: + raise InvalidKeyError("Key is not valid JSON") from e if obj.get("kty") != "oct": raise InvalidKeyError("Not an HMAC key") @@ -345,8 +338,10 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: except ValueError: try: return cast(RSAPublicKey, load_pem_public_key(key_bytes)) - except (ValueError, UnsupportedAlgorithm): - raise InvalidKeyError("Could not parse the provided public key.") + except (ValueError, UnsupportedAlgorithm) as e: + raise InvalidKeyError( + "Could not parse the provided public key." + ) from e @overload @staticmethod @@ -394,10 +389,7 @@ def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str: else: raise InvalidKeyError("Not a public or private key") - if as_dict: - return obj - else: - return json.dumps(obj) + return obj if as_dict else json.dumps(obj) @staticmethod def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: @@ -408,8 +400,8 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: obj = jwk else: raise ValueError - except ValueError: - raise InvalidKeyError("Key is not valid JSON") + except ValueError as e: + raise InvalidKeyError("Key is not valid JSON") from e if obj.get("kty") != "RSA": raise InvalidKeyError("Not an RSA key") @@ -494,12 +486,15 @@ class ECAlgorithm(Algorithm): def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: + def prepare_key(self, key: AllowedECKeys | str | bytes | dict) -> AllowedECKeys: if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): return key + if isinstance(key, dict): + return self._load_jwk(key) + if not isinstance(key, (bytes, str)): - raise TypeError("Expecting a PEM-formatted key.") + raise TypeError("Expecting a PEM-formatted key or JWK.") key_bytes = force_bytes(key) @@ -524,6 +519,38 @@ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: return crypto_key + def _load_jwk(self, jwk: dict) -> EllipticCurvePublicKey: + if jwk.get("kty") != "EC": + raise InvalidKeyError("Not an EC key") + + curve = self._get_curve(jwk["crv"]) + x = self._base64url_decode(jwk["x"]) + y = self._base64url_decode(jwk["y"]) + + public_numbers = ec.EllipticCurvePublicNumbers( + x=int.from_bytes(x, byteorder="big"), + y=int.from_bytes(y, byteorder="big"), + curve=curve, + ) + + return public_numbers.public_key() + + def _get_curve(self, crv: str) -> ec.EllipticCurve: + if crv == "P-256": + return ec.SECP256R1() + elif crv == "P-384": + return ec.SECP384R1() + elif crv == "P-521": + return ec.SECP521R1() + elif crv == "secp256k1": + return ec.SECP256K1() + else: + raise InvalidKeyError(f"Invalid curve: {crv}") + + def _base64url_decode(self, input: str) -> bytes: + input += "=" * (4 - len(input) % 4) + return base64.urlsafe_b64decode(input) + def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: der_sig = key.sign(msg, ECDSA(self.hash_alg())) @@ -590,10 +617,7 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: key_obj.private_numbers().private_value ).decode() - if as_dict: - return obj - else: - return json.dumps(obj) + return obj if as_dict else json.dumps(obj) @staticmethod def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: @@ -604,8 +628,8 @@ def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: obj = jwk else: raise ValueError - except ValueError: - raise InvalidKeyError("Key is not valid JSON") + except ValueError as e: + raise InvalidKeyError("Key is not valid JSON") from e if obj.get("kty") != "EC": raise InvalidKeyError("Not an Elliptic curve key") @@ -712,7 +736,7 @@ def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: key = load_pem_public_key(key_bytes) # type: ignore[assignment] elif "-----BEGIN PRIVATE" in key_str: key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] - elif key_str[0:4] == "ssh-": + elif key_str[:4] == "ssh-": key = load_ssh_public_key(key_bytes) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography @@ -792,10 +816,7 @@ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: "crv": crv, } - if as_dict: - return obj - else: - return json.dumps(obj) + return obj if as_dict else json.dumps(obj) if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): d = key.private_bytes( @@ -817,11 +838,7 @@ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: "crv": crv, } - if as_dict: - return obj - else: - return json.dumps(obj) - + return obj if as_dict else json.dumps(obj) raise InvalidKeyError("Not a public or private key") @staticmethod @@ -833,14 +850,14 @@ def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: obj = jwk else: raise ValueError - except ValueError: - raise InvalidKeyError("Key is not valid JSON") + except ValueError as e: + raise InvalidKeyError("Key is not valid JSON") from e if obj.get("kty") != "OKP": raise InvalidKeyError("Not an Octet Key Pair") curve = obj.get("crv") - if curve != "Ed25519" and curve != "Ed448": + if curve not in ["Ed25519", "Ed448"]: raise InvalidKeyError(f"Invalid curve: {curve}") if "x" not in obj: diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index 02f4679c..531dd25c 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -110,7 +110,7 @@ def __init__(self, keys: list[JWKDict]) -> None: # skip unusable keys continue - if len(self.keys) == 0: + if not self.keys: raise PyJWKSetError( "The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?" ) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 7089d9d9..ff799516 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -112,8 +112,6 @@ def encode( is_payload_detached: bool = False, sort_headers: bool = True, ) -> str: - segments = [] - # declare a new var to narrow the type for type checkers if algorithm is None: if isinstance(key, PyJWK): @@ -125,8 +123,7 @@ def encode( # Prefer headers values if present to function parameters. if headers: - headers_alg = headers.get("alg") - if headers_alg: + if headers_alg := headers.get("alg"): algorithm_ = headers["alg"] headers_b64 = headers.get("b64") @@ -138,7 +135,7 @@ def encode( if headers: self._validate_headers(headers) - header.update(headers) + header |= headers if not header["typ"]: del header["typ"] @@ -153,12 +150,8 @@ def encode( header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers ).encode() - segments.append(base64url_encode(json_header)) - - if is_payload_detached: - msg_payload = payload - else: - msg_payload = base64url_encode(payload) + segments = [base64url_encode(json_header)] + msg_payload = payload if is_payload_detached else base64url_encode(payload) segments.append(msg_payload) # Segments diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 91ad266f..9bd354a4 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -173,7 +173,7 @@ def _decode_payload(self, decoded: dict[str, Any]) -> Any: try: payload = json.loads(decoded["payload"]) except ValueError as e: - raise DecodeError(f"Invalid payload string: {e}") + raise DecodeError(f"Invalid payload string: {e}") from e if not isinstance(payload, dict): raise DecodeError("Invalid payload string: must be a json object") return payload @@ -268,8 +268,10 @@ def _validate_iat( ) -> None: try: iat = int(payload["iat"]) - except ValueError: - raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.") + except ValueError as e: + raise InvalidIssuedAtError( + "Issued At claim (iat) must be an integer." + ) from e if iat > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (iat)") @@ -281,8 +283,8 @@ def _validate_nbf( ) -> None: try: nbf = int(payload["nbf"]) - except ValueError: - raise DecodeError("Not Before claim (nbf) must be an integer.") + except ValueError as e: + raise DecodeError("Not Before claim (nbf) must be an integer.") from e if nbf > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (nbf)") @@ -295,8 +297,8 @@ def _validate_exp( ) -> None: try: exp = int(payload["exp"]) - except ValueError: - raise DecodeError("Expiration Time claim (exp) must be an integer.") + except ValueError as e: + raise DecodeError("Expiration Time claim (exp) must be an integer.") from e if exp <= (now - leeway): raise ExpiredSignatureError("Signature has expired") @@ -358,12 +360,13 @@ def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None: if "iss" not in payload: raise MissingRequiredClaimError("iss") - if isinstance(issuer, Sequence): - if payload["iss"] not in issuer: - raise InvalidIssuerError("Invalid issuer") - else: - if payload["iss"] != issuer: - raise InvalidIssuerError("Invalid issuer") + if ( + isinstance(issuer, Sequence) + and payload["iss"] not in issuer + or not isinstance(issuer, Sequence) + and payload["iss"] != issuer + ): + raise InvalidIssuerError("Invalid issuer") _jwt_global_obj = PyJWT() diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index f19b10ac..25bab95d 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -58,7 +58,7 @@ def fetch_data(self) -> Any: except (URLError, TimeoutError) as e: raise PyJWKClientConnectionError( f'Fail to fetch data from the url, err: "{e}"' - ) + ) from e else: return jwk_set finally: @@ -80,17 +80,15 @@ def get_jwk_set(self, refresh: bool = False) -> PyJWKSet: def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: jwk_set = self.get_jwk_set(refresh) - signing_keys = [ + if signing_keys := [ jwk_set_key for jwk_set_key in jwk_set.keys if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id - ] - - if not signing_keys: + ]: + return signing_keys + else: raise PyJWKClientError("The JWKS endpoint did not contain any signing keys") - return signing_keys - def get_signing_key(self, kid: str) -> PyJWK: signing_keys = self.get_signing_keys() signing_key = self.match_kid(signing_keys, kid) @@ -100,10 +98,10 @@ def get_signing_key(self, kid: str) -> PyJWK: signing_keys = self.get_signing_keys(refresh=True) signing_key = self.match_kid(signing_keys, kid) - if not signing_key: - raise PyJWKClientError( - f'Unable to find a signing key that matches: "{kid}"' - ) + if not signing_key: + raise PyJWKClientError( + f'Unable to find a signing key that matches: "{kid}"' + ) return signing_key @@ -114,11 +112,4 @@ def get_signing_key_from_jwt(self, token: str) -> PyJWK: @staticmethod def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]: - signing_key = None - - for key in signing_keys: - if key.key_id == kid: - signing_key = key - break - - return signing_key + return next((key for key in signing_keys if key.key_id == kid), None) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 337de96a..ee54bf3d 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -91,6 +91,19 @@ def test_hmac_jwk_should_parse_and_verify(self): signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key, signature) + def test_ec_should_accept_jwk(self): + algo = ECAlgorithm(ECAlgorithm.SHA256) + + jwk = { + "crv": "P-256", + "kty": "EC", + "x": "PY5pUvmWTEz5mCVir-Tyfi1M0q07_qaZSU_UAN3HBSI", + "y": "aH9ZAGpTidZjxNu2zKXeX9koNQX_BAtIBCa-h7YC_B0", + } + + pub_key = algo.prepare_key(jwk) + assert isinstance(pub_key, EllipticCurvePublicKey) + @pytest.mark.parametrize("as_dict", (False, True)) def test_hmac_to_jwk_returns_correct_values(self, as_dict): algo = HMACAlgorithm(HMACAlgorithm.SHA256)