From 10d0f09b5ee1babeadbb7166359ba9698e109cba Mon Sep 17 00:00:00 2001 From: Pedro Aguiar Date: Thu, 1 Aug 2024 17:01:19 -0400 Subject: [PATCH] Improve type annotations for the option parameter in decode methods. The options parameter currently accepts any key-value pair so long as the key is of type string (`options: dict[str, Any] | None = None`) and overwrites the default values with the arguments, as opposed to being limited to a strict set of keys and values (see [TypedDict](https://typing.readthedocs.io/en/latest/spec/typeddict.html#class-based-syntax)). Fixes: #869 *** Here's a minimal reproducible example you can try at the [MyPy Playground](https://mypy-play.net/?mypy=latest&python=3.12). ```python from typing import TypedDict, List class JwtOptions(TypedDict): verify_signature: bool verify_exp: bool verify_nbf: bool verify_iat: bool verify_aud: bool verify_iss: bool require: List[str] class PyJWT: def __init__(self, options: JwtOptions | None = None) -> None: if options is None: self.options: JwtOptions = self._get_default_options() @staticmethod def _get_default_options() -> JwtOptions: return { "verify_signature": True, "verify_exp": True, "verify_nbf": True, "verify_iat": True, "verify_aud": True, "verify_iss": True, "require": [], } def decode(self, options: JwtOptions) -> None: pass pyJWT = PyJWT() pyJWT.decode(options={ 'verify_signature': True, 'verify_exp': True, 'verify_nbf': True, 'verify_iat': True, 'verify_aud': True, 'verify_iss': True, 'required': [] # misspelled key }) ``` Signed-off-by: Pedro Aguiar --- jwt/api_jws.py | 20 +++++++++++--------- jwt/api_jwt.py | 33 +++++++++++++++++---------------- jwt/jwks_client.py | 1 + jwt/types.py | 18 +++++++++++++++++- tests/test_api_jws.py | 4 ---- 5 files changed, 46 insertions(+), 30 deletions(-) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 5822ebf6..92c90063 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -18,6 +18,7 @@ InvalidSignatureError, InvalidTokenError, ) +from .types import JwsOptions from .utils import base64url_decode, base64url_encode from .warnings import RemovedInPyjwt3Warning @@ -31,7 +32,7 @@ class PyJWS: def __init__( self, algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwsOptions | None = None, ) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( @@ -44,11 +45,12 @@ def __init__( del self._algorithms[key] if options is None: - options = {} - self.options = {**self._get_default_options(), **options} + self.options = self._get_default_options() + else: + self.options = options @staticmethod - def _get_default_options() -> dict[str, bool]: + def _get_default_options() -> JwsOptions: return {"verify_signature": True} def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None: @@ -175,7 +177,7 @@ def decode_complete( jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwsOptions | None = None, detached_payload: bytes | None = None, **kwargs, ) -> dict[str, Any]: @@ -187,9 +189,9 @@ def decode_complete( RemovedInPyjwt3Warning, ) if options is None: - options = {} - merged_options = {**self.options, **options} - verify_signature = merged_options["verify_signature"] + options = self.options + + verify_signature = options["verify_signature"] if verify_signature and not algorithms and not isinstance(key, PyJWK): raise DecodeError( @@ -220,7 +222,7 @@ def decode( jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwsOptions | None = None, detached_payload: bytes | None = None, **kwargs, ) -> Any: diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 7a07c336..56b7b54c 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -17,6 +17,7 @@ InvalidIssuerError, MissingRequiredClaimError, ) +from .types import JwtOptions from .warnings import RemovedInPyjwt3Warning if TYPE_CHECKING: @@ -25,13 +26,14 @@ class PyJWT: - def __init__(self, options: dict[str, Any] | None = None) -> None: + def __init__(self, options: JwtOptions | None = None) -> None: if options is None: - options = {} - self.options: dict[str, Any] = {**self._get_default_options(), **options} + self.options = self._get_default_options() + else: + self.options = options @staticmethod - def _get_default_options() -> dict[str, bool | list[str]]: + def _get_default_options() -> JwtOptions: return { "verify_signature": True, "verify_exp": True, @@ -103,7 +105,7 @@ def decode_complete( jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwtOptions | None = None, # deprecated arg, remove in pyjwt3 verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 @@ -115,7 +117,7 @@ def decode_complete( leeway: float | timedelta = 0, # kwargs **kwargs: Any, - ) -> dict[str, Any]: + ) -> Any: if kwargs: warnings.warn( "passing additional kwargs to decode_complete() is deprecated " @@ -123,8 +125,14 @@ def decode_complete( f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, ) - options = dict(options or {}) # shallow-copy or initialize an empty dict - options.setdefault("verify_signature", True) + if options is None: + options = self.options + if options["verify_signature"] is False: + options.setdefault("verify_exp", False) + options.setdefault("verify_nbf", False) + options.setdefault("verify_iat", False) + options.setdefault("verify_aud", False) + options.setdefault("verify_iss", False) # If the user has set the legacy `verify` argument, and it doesn't match # what the relevant `options` entry for the argument is, inform the user @@ -137,13 +145,6 @@ def decode_complete( category=DeprecationWarning, ) - if not options["verify_signature"]: - options.setdefault("verify_exp", False) - options.setdefault("verify_nbf", False) - options.setdefault("verify_iat", False) - options.setdefault("verify_aud", False) - options.setdefault("verify_iss", False) - if options["verify_signature"] and not algorithms: raise DecodeError( 'It is required that you pass in a value for the "algorithms" argument when calling decode().' @@ -188,7 +189,7 @@ def decode( jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwtOptions | None = None, # deprecated arg, remove in pyjwt3 verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index f19b10ac..10ce7271 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -9,6 +9,7 @@ from .api_jwt import decode_complete as decode_token from .exceptions import PyJWKClientConnectionError, PyJWKClientError from .jwk_set_cache import JWKSetCache +from .types import JwtOptions class PyJWKClient: diff --git a/jwt/types.py b/jwt/types.py index 7d993520..58e78922 100644 --- a/jwt/types.py +++ b/jwt/types.py @@ -1,5 +1,21 @@ -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, List, NotRequired, TypedDict JWKDict = Dict[str, Any] HashlibHash = Callable[..., Any] + + +# TODO: Make fields mandatory in PyJWT3 +# See: https://peps.python.org/pep-0589/#inheritance +class JwtOptionsEncode(TypedDict): + verify_signature: NotRequired[bool] + verify_exp: NotRequired[bool] + verify_nbf: NotRequired[bool] + verify_iat: NotRequired[bool] + verify_aud: NotRequired[bool] + verify_iss: NotRequired[bool] + require: NotRequired[List[str]] + + +class JwsOptions(TypedDict): + verify_signature: NotRequired[bool] diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index c764f09f..35e9b2ef 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -83,10 +83,6 @@ def test_non_object_options_dont_persist(self, jws, payload): assert jws.options["verify_signature"] - def test_options_must_be_dict(self): - pytest.raises(TypeError, PyJWS, options=object()) - pytest.raises((TypeError, ValueError), PyJWS, options=("something")) - def test_encode_decode(self, jws, payload): secret = "secret" jws_message = jws.encode(payload, secret, algorithm="HS256")