diff --git a/jose/backends/cryptography_backend.py b/jose/backends/cryptography_backend.py index abd24260..f392b098 100644 --- a/jose/backends/cryptography_backend.py +++ b/jose/backends/cryptography_backend.py @@ -1,3 +1,4 @@ +import re import math import warnings @@ -22,6 +23,68 @@ _binding = None +# Based on https://github.com/jpadilla/pyjwt/commit/9c528670c455b8d948aff95ed50e22940d1ad3fc +# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252 +_PEMS = { + b"CERTIFICATE", + b"TRUSTED CERTIFICATE", + b"PRIVATE KEY", + b"PUBLIC KEY", + b"ENCRYPTED PRIVATE KEY", + b"OPENSSH PRIVATE KEY", + b"DSA PRIVATE KEY", + b"RSA PRIVATE KEY", + b"RSA PUBLIC KEY", + b"EC PRIVATE KEY", + b"DH PARAMETERS", + b"NEW CERTIFICATE REQUEST", + b"CERTIFICATE REQUEST", + b"SSH2 PUBLIC KEY", + b"SSH2 ENCRYPTED PRIVATE KEY", + b"X509 CRL", +} + + +_PEM_RE = re.compile( + b"----[- ]BEGIN (" + + b"|".join(_PEMS) + + b""")[- ]----\r? +.+?\r? +----[- ]END \\1[- ]----\r?\n?""", + re.DOTALL, +) + + +def is_pem_format(key): + return bool(_PEM_RE.search(key)) + + +# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 +_CERT_SUFFIX = b"-cert-v01@openssh.com" +_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)") +_SSH_KEY_FORMATS = [ + b"ssh-ed25519", + b"ssh-rsa", + b"ssh-dss", + b"ecdsa-sha2-nistp256", + b"ecdsa-sha2-nistp384", + b"ecdsa-sha2-nistp521", +] + + +def is_ssh_key(key): + if any(string_value in key for string_value in _SSH_KEY_FORMATS): + return True + + ssh_pubkey_match = _SSH_PUBKEY_RC.match(key) + if ssh_pubkey_match: + key_type = ssh_pubkey_match.group(1) + if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]: + return True + + return False + + def get_random_bytes(num_bytes): """ Get random bytes @@ -552,14 +615,7 @@ def __init__(self, key, algorithm): if isinstance(key, str): key = key.encode("utf-8") - invalid_strings = [ - b"-----BEGIN PUBLIC KEY-----", - b"-----BEGIN RSA PUBLIC KEY-----", - b"-----BEGIN CERTIFICATE-----", - b"ssh-rsa", - ] - - if any(string_value in key for string_value in invalid_strings): + if is_pem_format(key) or is_ssh_key(key): raise JWKError( "The specified key is an asymmetric key or x509 certificate and" " should not be used as an HMAC secret." diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 8c2e262f..378504f8 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -5,7 +5,7 @@ import pytest from jose import jws, jwt -from jose.exceptions import JWTError +from jose.exceptions import JWTError, JWKError @pytest.fixture @@ -522,3 +522,36 @@ def test_require(self, claims, key, claim, value): new_claims[claim] = value token = jwt.encode(new_claims, key) jwt.decode(token, key, options=options, audience=str(value)) + + def test_CVE_2024_33663(self): + """Test based on https://github.com/mpdavis/python-jose/issues/346""" + from Crypto.PublicKey import ECC + from Crypto.Hash import HMAC, SHA256 + + # ----- SETUP ----- + # generate an asymmetric ECC keypair + # !! signing should only be possible with the private key !! + KEY = ECC.generate(curve='P-256') + + # PUBLIC KEY, AVAILABLE TO USER + # CAN BE RECOVERED THROUGH E.G. PUBKEY RECOVERY WITH TWO SIGNATURES: + # https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm#Public_key_recovery + # https://github.com/FlorianPicca/JWT-Key-Recovery + PUBKEY = KEY.public_key().export_key(format='OpenSSH').encode() + + # ---- CLIENT SIDE ----- + # without knowing the private key, a valid token can be constructed + # YIKES!! + + b64 = lambda x:base64.urlsafe_b64encode(x).replace(b'=',b'') + payload = b64(b'{"alg":"HS256"}') + b'.' + b64(b'{"pwned":true}') + hasher = HMAC.new(PUBKEY, digestmod=SHA256) + hasher.update(payload) + evil_token = payload + b'.' + b64(hasher.digest()) + + # ---- SERVER SIDE ----- + # verify and decode the token using the public key, as is custom + # algorithm field is left unspecified + # but the library will happily still verify without warning, trusting the user-controlled alg field of the token header + with pytest.raises(JWKError): + data = jwt.decode(evil_token, PUBKEY)