diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 5822ebf6..92c90063 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -18,6 +18,7 @@ InvalidSignatureError, InvalidTokenError, ) +from .types import JwsOptions from .utils import base64url_decode, base64url_encode from .warnings import RemovedInPyjwt3Warning @@ -31,7 +32,7 @@ class PyJWS: def __init__( self, algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwsOptions | None = None, ) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( @@ -44,11 +45,12 @@ def __init__( del self._algorithms[key] if options is None: - options = {} - self.options = {**self._get_default_options(), **options} + self.options = self._get_default_options() + else: + self.options = options @staticmethod - def _get_default_options() -> dict[str, bool]: + def _get_default_options() -> JwsOptions: return {"verify_signature": True} def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None: @@ -175,7 +177,7 @@ def decode_complete( jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwsOptions | None = None, detached_payload: bytes | None = None, **kwargs, ) -> dict[str, Any]: @@ -187,9 +189,9 @@ def decode_complete( RemovedInPyjwt3Warning, ) if options is None: - options = {} - merged_options = {**self.options, **options} - verify_signature = merged_options["verify_signature"] + options = self.options + + verify_signature = options["verify_signature"] if verify_signature and not algorithms and not isinstance(key, PyJWK): raise DecodeError( @@ -220,7 +222,7 @@ def decode( jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwsOptions | None = None, detached_payload: bytes | None = None, **kwargs, ) -> Any: diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 7a07c336..56b7b54c 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -17,6 +17,7 @@ InvalidIssuerError, MissingRequiredClaimError, ) +from .types import JwtOptions from .warnings import RemovedInPyjwt3Warning if TYPE_CHECKING: @@ -25,13 +26,14 @@ class PyJWT: - def __init__(self, options: dict[str, Any] | None = None) -> None: + def __init__(self, options: JwtOptions | None = None) -> None: if options is None: - options = {} - self.options: dict[str, Any] = {**self._get_default_options(), **options} + self.options = self._get_default_options() + else: + self.options = options @staticmethod - def _get_default_options() -> dict[str, bool | list[str]]: + def _get_default_options() -> JwtOptions: return { "verify_signature": True, "verify_exp": True, @@ -103,7 +105,7 @@ def decode_complete( jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwtOptions | None = None, # deprecated arg, remove in pyjwt3 verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 @@ -115,7 +117,7 @@ def decode_complete( leeway: float | timedelta = 0, # kwargs **kwargs: Any, - ) -> dict[str, Any]: + ) -> Any: if kwargs: warnings.warn( "passing additional kwargs to decode_complete() is deprecated " @@ -123,8 +125,14 @@ def decode_complete( f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, ) - options = dict(options or {}) # shallow-copy or initialize an empty dict - options.setdefault("verify_signature", True) + if options is None: + options = self.options + if options["verify_signature"] is False: + options.setdefault("verify_exp", False) + options.setdefault("verify_nbf", False) + options.setdefault("verify_iat", False) + options.setdefault("verify_aud", False) + options.setdefault("verify_iss", False) # If the user has set the legacy `verify` argument, and it doesn't match # what the relevant `options` entry for the argument is, inform the user @@ -137,13 +145,6 @@ def decode_complete( category=DeprecationWarning, ) - if not options["verify_signature"]: - options.setdefault("verify_exp", False) - options.setdefault("verify_nbf", False) - options.setdefault("verify_iat", False) - options.setdefault("verify_aud", False) - options.setdefault("verify_iss", False) - if options["verify_signature"] and not algorithms: raise DecodeError( 'It is required that you pass in a value for the "algorithms" argument when calling decode().' @@ -188,7 +189,7 @@ def decode( jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, - options: dict[str, Any] | None = None, + options: JwtOptions | None = None, # deprecated arg, remove in pyjwt3 verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index f19b10ac..10ce7271 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -9,6 +9,7 @@ from .api_jwt import decode_complete as decode_token from .exceptions import PyJWKClientConnectionError, PyJWKClientError from .jwk_set_cache import JWKSetCache +from .types import JwtOptions class PyJWKClient: diff --git a/jwt/types.py b/jwt/types.py index 7d993520..58e78922 100644 --- a/jwt/types.py +++ b/jwt/types.py @@ -1,5 +1,21 @@ -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, List, NotRequired, TypedDict JWKDict = Dict[str, Any] HashlibHash = Callable[..., Any] + + +# TODO: Make fields mandatory in PyJWT3 +# See: https://peps.python.org/pep-0589/#inheritance +class JwtOptionsEncode(TypedDict): + verify_signature: NotRequired[bool] + verify_exp: NotRequired[bool] + verify_nbf: NotRequired[bool] + verify_iat: NotRequired[bool] + verify_aud: NotRequired[bool] + verify_iss: NotRequired[bool] + require: NotRequired[List[str]] + + +class JwsOptions(TypedDict): + verify_signature: NotRequired[bool] diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index c764f09f..35e9b2ef 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -83,10 +83,6 @@ def test_non_object_options_dont_persist(self, jws, payload): assert jws.options["verify_signature"] - def test_options_must_be_dict(self): - pytest.raises(TypeError, PyJWS, options=object()) - pytest.raises((TypeError, ValueError), PyJWS, options=("something")) - def test_encode_decode(self, jws, payload): secret = "secret" jws_message = jws.encode(payload, secret, algorithm="HS256")