Skip to content

Commit 2306f68

Browse files
authored
Make mypy configuration stricter and improve typing (#830)
* PyJWS._verify_signature: raise early KeyError if header is missing alg * Make Mypy configuration stricter * Improve typing in jwt.utils * Improve typing in jwt.help * Improve typing in jwt.exceptions * Improve typing in jwt.api_jwk * Improve typing in jwt.api_jws * Improve typing & clean up imports in jwt.algorithms * Correct JWS.decode rettype to any (payload could be something else) * Update typing in api_jwt * Improve typing in jwks_client * Improve typing in docs/conf.py * Fix (benign) mistyping in test_advisory * Fix misc type complaints in tests
1 parent fb9b311 commit 2306f68

14 files changed

+187
-116
lines changed

docs/conf.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sphinx_rtd_theme
55

66

7-
def read(*parts):
7+
def read(*parts) -> str:
88
"""
99
Build an absolute path from *parts* and and return the contents of the
1010
resulting file. Assume UTF-8 encoding.
@@ -14,7 +14,7 @@ def read(*parts):
1414
return f.read()
1515

1616

17-
def find_version(*file_paths):
17+
def find_version(*file_paths) -> str:
1818
"""
1919
Build a path from *file_paths* and search for a ``__version__``
2020
string inside.

jwt/algorithms.py

+57-42
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import hashlib
22
import hmac
33
import json
4+
from typing import Any, Dict, Union
45

56
from .exceptions import InvalidKeyError
7+
from .types import JWKDict
68
from .utils import (
79
base64url_decode,
810
base64url_encode,
@@ -20,10 +22,18 @@
2022
from cryptography.exceptions import InvalidSignature
2123
from cryptography.hazmat.backends import default_backend
2224
from cryptography.hazmat.primitives import hashes
23-
from cryptography.hazmat.primitives.asymmetric import ec, padding
25+
from cryptography.hazmat.primitives.asymmetric import padding
2426
from cryptography.hazmat.primitives.asymmetric.ec import (
27+
ECDSA,
28+
SECP256K1,
29+
SECP256R1,
30+
SECP384R1,
31+
SECP521R1,
32+
EllipticCurve,
2533
EllipticCurvePrivateKey,
34+
EllipticCurvePrivateNumbers,
2635
EllipticCurvePublicKey,
36+
EllipticCurvePublicNumbers,
2737
)
2838
from cryptography.hazmat.primitives.asymmetric.ed448 import (
2939
Ed448PrivateKey,
@@ -73,7 +83,7 @@
7383
}
7484

7585

76-
def get_default_algorithms():
86+
def get_default_algorithms() -> Dict[str, "Algorithm"]:
7787
"""
7888
Returns the algorithms that are implemented by the library.
7989
"""
@@ -130,40 +140,44 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
130140
):
131141
digest = hashes.Hash(hash_alg(), backend=default_backend())
132142
digest.update(bytestr)
133-
return digest.finalize()
143+
return bytes(digest.finalize())
134144
else:
135-
return hash_alg(bytestr).digest()
145+
return bytes(hash_alg(bytestr).digest())
136146

137-
def prepare_key(self, key):
147+
# TODO: all key-related `Any`s in this class should optimally be made
148+
# variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
149+
# that may still be poorly supported.
150+
151+
def prepare_key(self, key: Any) -> Any:
138152
"""
139153
Performs necessary validation and conversions on the key and returns
140154
the key value in the proper format for sign() and verify().
141155
"""
142156
raise NotImplementedError
143157

144-
def sign(self, msg, key):
158+
def sign(self, msg: bytes, key: Any) -> bytes:
145159
"""
146160
Returns a digital signature for the specified message
147161
using the specified key value.
148162
"""
149163
raise NotImplementedError
150164

151-
def verify(self, msg, key, sig):
165+
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
152166
"""
153167
Verifies that the specified digital signature is valid
154168
for the specified message and key values.
155169
"""
156170
raise NotImplementedError
157171

158172
@staticmethod
159-
def to_jwk(key_obj):
173+
def to_jwk(key_obj) -> JWKDict:
160174
"""
161175
Serializes a given RSA key into a JWK
162176
"""
163177
raise NotImplementedError
164178

165179
@staticmethod
166-
def from_jwk(jwk):
180+
def from_jwk(jwk: JWKDict):
167181
"""
168182
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
169183
"""
@@ -202,7 +216,7 @@ class HMACAlgorithm(Algorithm):
202216
SHA384 = hashlib.sha384
203217
SHA512 = hashlib.sha512
204218

205-
def __init__(self, hash_alg):
219+
def __init__(self, hash_alg) -> None:
206220
self.hash_alg = hash_alg
207221

208222
def prepare_key(self, key):
@@ -242,7 +256,7 @@ def from_jwk(jwk):
242256

243257
return base64url_decode(obj["k"])
244258

245-
def sign(self, msg, key):
259+
def sign(self, msg: bytes, key: bytes) -> bytes:
246260
return hmac.new(key, msg, self.hash_alg).digest()
247261

248262
def verify(self, msg, key, sig):
@@ -261,7 +275,7 @@ class RSAAlgorithm(Algorithm):
261275
SHA384 = hashes.SHA384
262276
SHA512 = hashes.SHA512
263277

264-
def __init__(self, hash_alg):
278+
def __init__(self, hash_alg) -> None:
265279
self.hash_alg = hash_alg
266280

267281
def prepare_key(self, key):
@@ -271,16 +285,15 @@ def prepare_key(self, key):
271285
if not isinstance(key, (bytes, str)):
272286
raise TypeError("Expecting a PEM-formatted key.")
273287

274-
key = force_bytes(key)
288+
key_bytes = force_bytes(key)
275289

276290
try:
277-
if key.startswith(b"ssh-rsa"):
278-
key = load_ssh_public_key(key)
291+
if key_bytes.startswith(b"ssh-rsa"):
292+
return load_ssh_public_key(key_bytes)
279293
else:
280-
key = load_pem_private_key(key, password=None)
294+
return load_pem_private_key(key_bytes, password=None)
281295
except ValueError:
282-
key = load_pem_public_key(key)
283-
return key
296+
return load_pem_public_key(key_bytes)
284297

285298
@staticmethod
286299
def to_jwk(key_obj):
@@ -383,12 +396,10 @@ def from_jwk(jwk):
383396
return numbers.private_key()
384397
elif "n" in obj and "e" in obj:
385398
# Public key
386-
numbers = RSAPublicNumbers(
399+
return RSAPublicNumbers(
387400
from_base64url_uint(obj["e"]),
388401
from_base64url_uint(obj["n"]),
389-
)
390-
391-
return numbers.public_key()
402+
).public_key()
392403
else:
393404
raise InvalidKeyError("Not a public or private key")
394405

@@ -412,7 +423,7 @@ class ECAlgorithm(Algorithm):
412423
SHA384 = hashes.SHA384
413424
SHA512 = hashes.SHA512
414425

415-
def __init__(self, hash_alg):
426+
def __init__(self, hash_alg) -> None:
416427
self.hash_alg = hash_alg
417428

418429
def prepare_key(self, key):
@@ -422,18 +433,18 @@ def prepare_key(self, key):
422433
if not isinstance(key, (bytes, str)):
423434
raise TypeError("Expecting a PEM-formatted key.")
424435

425-
key = force_bytes(key)
436+
key_bytes = force_bytes(key)
426437

427438
# Attempt to load key. We don't know if it's
428439
# a Signing Key or a Verifying Key, so we try
429440
# the Verifying Key first.
430441
try:
431-
if key.startswith(b"ecdsa-sha2-"):
432-
key = load_ssh_public_key(key)
442+
if key_bytes.startswith(b"ecdsa-sha2-"):
443+
key = load_ssh_public_key(key_bytes)
433444
else:
434-
key = load_pem_public_key(key)
445+
key = load_pem_public_key(key_bytes)
435446
except ValueError:
436-
key = load_pem_private_key(key, password=None)
447+
key = load_pem_private_key(key_bytes, password=None)
437448

438449
# Explicit check the key to prevent confusing errors from cryptography
439450
if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
@@ -444,7 +455,7 @@ def prepare_key(self, key):
444455
return key
445456

446457
def sign(self, msg, key):
447-
der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))
458+
der_sig = key.sign(msg, ECDSA(self.hash_alg()))
448459

449460
return der_to_raw_signature(der_sig, key.curve)
450461

@@ -457,7 +468,7 @@ def verify(self, msg, key, sig):
457468
try:
458469
if isinstance(key, EllipticCurvePrivateKey):
459470
key = key.public_key()
460-
key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
471+
key.verify(der_sig, msg, ECDSA(self.hash_alg()))
461472
return True
462473
except InvalidSignature:
463474
return False
@@ -472,13 +483,13 @@ def to_jwk(key_obj):
472483
else:
473484
raise InvalidKeyError("Not a public or private key")
474485

475-
if isinstance(key_obj.curve, ec.SECP256R1):
486+
if isinstance(key_obj.curve, SECP256R1):
476487
crv = "P-256"
477-
elif isinstance(key_obj.curve, ec.SECP384R1):
488+
elif isinstance(key_obj.curve, SECP384R1):
478489
crv = "P-384"
479-
elif isinstance(key_obj.curve, ec.SECP521R1):
490+
elif isinstance(key_obj.curve, SECP521R1):
480491
crv = "P-521"
481-
elif isinstance(key_obj.curve, ec.SECP256K1):
492+
elif isinstance(key_obj.curve, SECP256K1):
482493
crv = "secp256k1"
483494
else:
484495
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
@@ -498,7 +509,9 @@ def to_jwk(key_obj):
498509
return json.dumps(obj)
499510

500511
@staticmethod
501-
def from_jwk(jwk):
512+
def from_jwk(
513+
jwk: Any,
514+
) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]:
502515
try:
503516
if isinstance(jwk, str):
504517
obj = json.loads(jwk)
@@ -519,32 +532,34 @@ def from_jwk(jwk):
519532
y = base64url_decode(obj.get("y"))
520533

521534
curve = obj.get("crv")
535+
curve_obj: EllipticCurve
536+
522537
if curve == "P-256":
523538
if len(x) == len(y) == 32:
524-
curve_obj = ec.SECP256R1()
539+
curve_obj = SECP256R1()
525540
else:
526541
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
527542
elif curve == "P-384":
528543
if len(x) == len(y) == 48:
529-
curve_obj = ec.SECP384R1()
544+
curve_obj = SECP384R1()
530545
else:
531546
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
532547
elif curve == "P-521":
533548
if len(x) == len(y) == 66:
534-
curve_obj = ec.SECP521R1()
549+
curve_obj = SECP521R1()
535550
else:
536551
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
537552
elif curve == "secp256k1":
538553
if len(x) == len(y) == 32:
539-
curve_obj = ec.SECP256K1()
554+
curve_obj = SECP256K1()
540555
else:
541556
raise InvalidKeyError(
542557
"Coords should be 32 bytes for curve secp256k1"
543558
)
544559
else:
545560
raise InvalidKeyError(f"Invalid curve: {curve}")
546561

547-
public_numbers = ec.EllipticCurvePublicNumbers(
562+
public_numbers = EllipticCurvePublicNumbers(
548563
x=int.from_bytes(x, byteorder="big"),
549564
y=int.from_bytes(y, byteorder="big"),
550565
curve=curve_obj,
@@ -559,7 +574,7 @@ def from_jwk(jwk):
559574
"D should be {} bytes for curve {}", len(x), curve
560575
)
561576

562-
return ec.EllipticCurvePrivateNumbers(
577+
return EllipticCurvePrivateNumbers(
563578
int.from_bytes(d, byteorder="big"), public_numbers
564579
).private_key()
565580

@@ -600,7 +615,7 @@ class OKPAlgorithm(Algorithm):
600615
This class requires ``cryptography>=2.6`` to be installed.
601616
"""
602617

603-
def __init__(self, **kwargs):
618+
def __init__(self, **kwargs) -> None:
604619
pass
605620

606621
def prepare_key(self, key):

jwt/api_jwk.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
import json
44
import time
5+
from typing import Any, Optional
56

67
from .algorithms import get_default_algorithms
78
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
9+
from .types import JWKDict
810

911

1012
class PyJWK:
11-
def __init__(self, jwk_data, algorithm=None):
13+
def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None:
1214
self._algorithms = get_default_algorithms()
1315
self._jwk_data = jwk_data
1416

@@ -55,29 +57,29 @@ def __init__(self, jwk_data, algorithm=None):
5557
self.key = self.Algorithm.from_jwk(self._jwk_data)
5658

5759
@staticmethod
58-
def from_dict(obj, algorithm=None):
60+
def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK":
5961
return PyJWK(obj, algorithm)
6062

6163
@staticmethod
62-
def from_json(data, algorithm=None):
64+
def from_json(data: str, algorithm: None = None) -> "PyJWK":
6365
obj = json.loads(data)
6466
return PyJWK.from_dict(obj, algorithm)
6567

6668
@property
67-
def key_type(self):
69+
def key_type(self) -> str:
6870
return self._jwk_data.get("kty", None)
6971

7072
@property
71-
def key_id(self):
73+
def key_id(self) -> str:
7274
return self._jwk_data.get("kid", None)
7375

7476
@property
75-
def public_key_use(self):
77+
def public_key_use(self) -> Optional[str]:
7678
return self._jwk_data.get("use", None)
7779

7880

7981
class PyJWKSet:
80-
def __init__(self, keys: list[dict]) -> None:
82+
def __init__(self, keys: list[JWKDict]) -> None:
8183
self.keys = []
8284

8385
if not keys:
@@ -97,16 +99,16 @@ def __init__(self, keys: list[dict]) -> None:
9799
raise PyJWKSetError("The JWK Set did not contain any usable keys")
98100

99101
@staticmethod
100-
def from_dict(obj):
102+
def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
101103
keys = obj.get("keys", [])
102104
return PyJWKSet(keys)
103105

104106
@staticmethod
105-
def from_json(data):
107+
def from_json(data: str) -> "PyJWKSet":
106108
obj = json.loads(data)
107109
return PyJWKSet.from_dict(obj)
108110

109-
def __getitem__(self, kid):
111+
def __getitem__(self, kid: str) -> "PyJWK":
110112
for key in self.keys:
111113
if key.key_id == kid:
112114
return key
@@ -118,8 +120,8 @@ def __init__(self, jwk_set: PyJWKSet):
118120
self.jwk_set = jwk_set
119121
self.timestamp = time.monotonic()
120122

121-
def get_jwk_set(self):
123+
def get_jwk_set(self) -> PyJWKSet:
122124
return self.jwk_set
123125

124-
def get_timestamp(self):
126+
def get_timestamp(self) -> float:
125127
return self.timestamp

0 commit comments

Comments
 (0)