1
1
import hashlib
2
2
import hmac
3
3
import json
4
+ from typing import Any , Dict , Union
4
5
5
6
from .exceptions import InvalidKeyError
7
+ from .types import JWKDict
6
8
from .utils import (
7
9
base64url_decode ,
8
10
base64url_encode ,
20
22
from cryptography .exceptions import InvalidSignature
21
23
from cryptography .hazmat .backends import default_backend
22
24
from cryptography .hazmat .primitives import hashes
23
- from cryptography .hazmat .primitives .asymmetric import ec , padding
25
+ from cryptography .hazmat .primitives .asymmetric import padding
24
26
from cryptography .hazmat .primitives .asymmetric .ec import (
27
+ ECDSA ,
28
+ SECP256K1 ,
29
+ SECP256R1 ,
30
+ SECP384R1 ,
31
+ SECP521R1 ,
32
+ EllipticCurve ,
25
33
EllipticCurvePrivateKey ,
34
+ EllipticCurvePrivateNumbers ,
26
35
EllipticCurvePublicKey ,
36
+ EllipticCurvePublicNumbers ,
27
37
)
28
38
from cryptography .hazmat .primitives .asymmetric .ed448 import (
29
39
Ed448PrivateKey ,
73
83
}
74
84
75
85
76
- def get_default_algorithms ():
86
+ def get_default_algorithms () -> Dict [ str , "Algorithm" ] :
77
87
"""
78
88
Returns the algorithms that are implemented by the library.
79
89
"""
@@ -130,40 +140,44 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
130
140
):
131
141
digest = hashes .Hash (hash_alg (), backend = default_backend ())
132
142
digest .update (bytestr )
133
- return digest .finalize ()
143
+ return bytes ( digest .finalize () )
134
144
else :
135
- return hash_alg (bytestr ).digest ()
145
+ return bytes ( hash_alg (bytestr ).digest () )
136
146
137
- def prepare_key (self , key ):
147
+ # TODO: all key-related `Any`s in this class should optimally be made
148
+ # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
149
+ # that may still be poorly supported.
150
+
151
+ def prepare_key (self , key : Any ) -> Any :
138
152
"""
139
153
Performs necessary validation and conversions on the key and returns
140
154
the key value in the proper format for sign() and verify().
141
155
"""
142
156
raise NotImplementedError
143
157
144
- def sign (self , msg , key ) :
158
+ def sign (self , msg : bytes , key : Any ) -> bytes :
145
159
"""
146
160
Returns a digital signature for the specified message
147
161
using the specified key value.
148
162
"""
149
163
raise NotImplementedError
150
164
151
- def verify (self , msg , key , sig ) :
165
+ def verify (self , msg : bytes , key : Any , sig : bytes ) -> bool :
152
166
"""
153
167
Verifies that the specified digital signature is valid
154
168
for the specified message and key values.
155
169
"""
156
170
raise NotImplementedError
157
171
158
172
@staticmethod
159
- def to_jwk (key_obj ):
173
+ def to_jwk (key_obj ) -> JWKDict :
160
174
"""
161
175
Serializes a given RSA key into a JWK
162
176
"""
163
177
raise NotImplementedError
164
178
165
179
@staticmethod
166
- def from_jwk (jwk ):
180
+ def from_jwk (jwk : JWKDict ):
167
181
"""
168
182
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
169
183
"""
@@ -202,7 +216,7 @@ class HMACAlgorithm(Algorithm):
202
216
SHA384 = hashlib .sha384
203
217
SHA512 = hashlib .sha512
204
218
205
- def __init__ (self , hash_alg ):
219
+ def __init__ (self , hash_alg ) -> None :
206
220
self .hash_alg = hash_alg
207
221
208
222
def prepare_key (self , key ):
@@ -242,7 +256,7 @@ def from_jwk(jwk):
242
256
243
257
return base64url_decode (obj ["k" ])
244
258
245
- def sign (self , msg , key ) :
259
+ def sign (self , msg : bytes , key : bytes ) -> bytes :
246
260
return hmac .new (key , msg , self .hash_alg ).digest ()
247
261
248
262
def verify (self , msg , key , sig ):
@@ -261,7 +275,7 @@ class RSAAlgorithm(Algorithm):
261
275
SHA384 = hashes .SHA384
262
276
SHA512 = hashes .SHA512
263
277
264
- def __init__ (self , hash_alg ):
278
+ def __init__ (self , hash_alg ) -> None :
265
279
self .hash_alg = hash_alg
266
280
267
281
def prepare_key (self , key ):
@@ -271,16 +285,15 @@ def prepare_key(self, key):
271
285
if not isinstance (key , (bytes , str )):
272
286
raise TypeError ("Expecting a PEM-formatted key." )
273
287
274
- key = force_bytes (key )
288
+ key_bytes = force_bytes (key )
275
289
276
290
try :
277
- if key .startswith (b"ssh-rsa" ):
278
- key = load_ssh_public_key (key )
291
+ if key_bytes .startswith (b"ssh-rsa" ):
292
+ return load_ssh_public_key (key_bytes )
279
293
else :
280
- key = load_pem_private_key (key , password = None )
294
+ return load_pem_private_key (key_bytes , password = None )
281
295
except ValueError :
282
- key = load_pem_public_key (key )
283
- return key
296
+ return load_pem_public_key (key_bytes )
284
297
285
298
@staticmethod
286
299
def to_jwk (key_obj ):
@@ -383,12 +396,10 @@ def from_jwk(jwk):
383
396
return numbers .private_key ()
384
397
elif "n" in obj and "e" in obj :
385
398
# Public key
386
- numbers = RSAPublicNumbers (
399
+ return RSAPublicNumbers (
387
400
from_base64url_uint (obj ["e" ]),
388
401
from_base64url_uint (obj ["n" ]),
389
- )
390
-
391
- return numbers .public_key ()
402
+ ).public_key ()
392
403
else :
393
404
raise InvalidKeyError ("Not a public or private key" )
394
405
@@ -412,7 +423,7 @@ class ECAlgorithm(Algorithm):
412
423
SHA384 = hashes .SHA384
413
424
SHA512 = hashes .SHA512
414
425
415
- def __init__ (self , hash_alg ):
426
+ def __init__ (self , hash_alg ) -> None :
416
427
self .hash_alg = hash_alg
417
428
418
429
def prepare_key (self , key ):
@@ -422,18 +433,18 @@ def prepare_key(self, key):
422
433
if not isinstance (key , (bytes , str )):
423
434
raise TypeError ("Expecting a PEM-formatted key." )
424
435
425
- key = force_bytes (key )
436
+ key_bytes = force_bytes (key )
426
437
427
438
# Attempt to load key. We don't know if it's
428
439
# a Signing Key or a Verifying Key, so we try
429
440
# the Verifying Key first.
430
441
try :
431
- if key .startswith (b"ecdsa-sha2-" ):
432
- key = load_ssh_public_key (key )
442
+ if key_bytes .startswith (b"ecdsa-sha2-" ):
443
+ key = load_ssh_public_key (key_bytes )
433
444
else :
434
- key = load_pem_public_key (key )
445
+ key = load_pem_public_key (key_bytes )
435
446
except ValueError :
436
- key = load_pem_private_key (key , password = None )
447
+ key = load_pem_private_key (key_bytes , password = None )
437
448
438
449
# Explicit check the key to prevent confusing errors from cryptography
439
450
if not isinstance (key , (EllipticCurvePrivateKey , EllipticCurvePublicKey )):
@@ -444,7 +455,7 @@ def prepare_key(self, key):
444
455
return key
445
456
446
457
def sign (self , msg , key ):
447
- der_sig = key .sign (msg , ec . ECDSA (self .hash_alg ()))
458
+ der_sig = key .sign (msg , ECDSA (self .hash_alg ()))
448
459
449
460
return der_to_raw_signature (der_sig , key .curve )
450
461
@@ -457,7 +468,7 @@ def verify(self, msg, key, sig):
457
468
try :
458
469
if isinstance (key , EllipticCurvePrivateKey ):
459
470
key = key .public_key ()
460
- key .verify (der_sig , msg , ec . ECDSA (self .hash_alg ()))
471
+ key .verify (der_sig , msg , ECDSA (self .hash_alg ()))
461
472
return True
462
473
except InvalidSignature :
463
474
return False
@@ -472,13 +483,13 @@ def to_jwk(key_obj):
472
483
else :
473
484
raise InvalidKeyError ("Not a public or private key" )
474
485
475
- if isinstance (key_obj .curve , ec . SECP256R1 ):
486
+ if isinstance (key_obj .curve , SECP256R1 ):
476
487
crv = "P-256"
477
- elif isinstance (key_obj .curve , ec . SECP384R1 ):
488
+ elif isinstance (key_obj .curve , SECP384R1 ):
478
489
crv = "P-384"
479
- elif isinstance (key_obj .curve , ec . SECP521R1 ):
490
+ elif isinstance (key_obj .curve , SECP521R1 ):
480
491
crv = "P-521"
481
- elif isinstance (key_obj .curve , ec . SECP256K1 ):
492
+ elif isinstance (key_obj .curve , SECP256K1 ):
482
493
crv = "secp256k1"
483
494
else :
484
495
raise InvalidKeyError (f"Invalid curve: { key_obj .curve } " )
@@ -498,7 +509,9 @@ def to_jwk(key_obj):
498
509
return json .dumps (obj )
499
510
500
511
@staticmethod
501
- def from_jwk (jwk ):
512
+ def from_jwk (
513
+ jwk : Any ,
514
+ ) -> Union [EllipticCurvePublicKey , EllipticCurvePrivateKey ]:
502
515
try :
503
516
if isinstance (jwk , str ):
504
517
obj = json .loads (jwk )
@@ -519,32 +532,34 @@ def from_jwk(jwk):
519
532
y = base64url_decode (obj .get ("y" ))
520
533
521
534
curve = obj .get ("crv" )
535
+ curve_obj : EllipticCurve
536
+
522
537
if curve == "P-256" :
523
538
if len (x ) == len (y ) == 32 :
524
- curve_obj = ec . SECP256R1 ()
539
+ curve_obj = SECP256R1 ()
525
540
else :
526
541
raise InvalidKeyError ("Coords should be 32 bytes for curve P-256" )
527
542
elif curve == "P-384" :
528
543
if len (x ) == len (y ) == 48 :
529
- curve_obj = ec . SECP384R1 ()
544
+ curve_obj = SECP384R1 ()
530
545
else :
531
546
raise InvalidKeyError ("Coords should be 48 bytes for curve P-384" )
532
547
elif curve == "P-521" :
533
548
if len (x ) == len (y ) == 66 :
534
- curve_obj = ec . SECP521R1 ()
549
+ curve_obj = SECP521R1 ()
535
550
else :
536
551
raise InvalidKeyError ("Coords should be 66 bytes for curve P-521" )
537
552
elif curve == "secp256k1" :
538
553
if len (x ) == len (y ) == 32 :
539
- curve_obj = ec . SECP256K1 ()
554
+ curve_obj = SECP256K1 ()
540
555
else :
541
556
raise InvalidKeyError (
542
557
"Coords should be 32 bytes for curve secp256k1"
543
558
)
544
559
else :
545
560
raise InvalidKeyError (f"Invalid curve: { curve } " )
546
561
547
- public_numbers = ec . EllipticCurvePublicNumbers (
562
+ public_numbers = EllipticCurvePublicNumbers (
548
563
x = int .from_bytes (x , byteorder = "big" ),
549
564
y = int .from_bytes (y , byteorder = "big" ),
550
565
curve = curve_obj ,
@@ -559,7 +574,7 @@ def from_jwk(jwk):
559
574
"D should be {} bytes for curve {}" , len (x ), curve
560
575
)
561
576
562
- return ec . EllipticCurvePrivateNumbers (
577
+ return EllipticCurvePrivateNumbers (
563
578
int .from_bytes (d , byteorder = "big" ), public_numbers
564
579
).private_key ()
565
580
@@ -600,7 +615,7 @@ class OKPAlgorithm(Algorithm):
600
615
This class requires ``cryptography>=2.6`` to be installed.
601
616
"""
602
617
603
- def __init__ (self , ** kwargs ):
618
+ def __init__ (self , ** kwargs ) -> None :
604
619
pass
605
620
606
621
def prepare_key (self , key ):
0 commit comments