From 999ce7af1f0502a224d33338e6ded8bc1e81cb2f Mon Sep 17 00:00:00 2001 From: Daverball Date: Sun, 14 Apr 2024 09:38:57 +0200 Subject: [PATCH] Improve generic typing further --- src/itsdangerous/_json.py | 6 +- src/itsdangerous/serializer.py | 104 +++++++++++++++++++++++++++------ src/itsdangerous/timed.py | 3 +- 3 files changed, 92 insertions(+), 21 deletions(-) diff --git a/src/itsdangerous/_json.py b/src/itsdangerous/_json.py index ba4a8e4..fc23fea 100644 --- a/src/itsdangerous/_json.py +++ b/src/itsdangerous/_json.py @@ -8,11 +8,11 @@ class _CompactJSON: """Wrapper around json module that strips whitespace.""" @staticmethod - def loads(s: str | bytes) -> t.Any: - return _json.loads(s) + def loads(payload: str | bytes) -> t.Any: + return _json.loads(payload) @staticmethod - def dumps(obj: t.Any, *args: t.Any, **kwargs: t.Any) -> str: + def dumps(obj: t.Any, **kwargs: t.Any) -> str: kwargs.setdefault("ensure_ascii", False) kwargs.setdefault("separators", (",", ":")) return _json.dumps(obj, **kwargs) diff --git a/src/itsdangerous/serializer.py b/src/itsdangerous/serializer.py index 99bef69..5ddf387 100644 --- a/src/itsdangerous/serializer.py +++ b/src/itsdangerous/serializer.py @@ -10,18 +10,36 @@ from .signer import _make_keys_list from .signer import Signer - -class _PDataSerializer(t.Protocol[t.AnyStr]): - def loads(self, s: t.AnyStr) -> t.Any: ... - def dumps(self, obj: t.Any, *args: t.Any, **kwargs: t.Any) -> t.AnyStr: ... - - -def is_text_serializer(serializer: _PDataSerializer[t.Any]) -> bool: +if t.TYPE_CHECKING: + import typing_extensions as te + + # This should be either be str or bytes. To avoid having to specify the + # bound type, it falls back to a union if structural matching fails. + _TSerialized = te.TypeVar( + "_TSerialized", bound=t.Union[str, bytes], default=t.Union[str, bytes] + ) +else: + # Still available at runtime on Python < 3.13, but without the default. + _TSerialized = t.TypeVar("_TSerialized", bound=t.Union[str, bytes]) + + +class _PDataSerializer(t.Protocol[_TSerialized]): + def loads(self, payload: _TSerialized, /) -> t.Any: ... + # A signature with additional arguments is not handled correctly by type + # checkers right now, so an overload is used below for serializers that + # don't match this strict protocol. + def dumps(self, obj: t.Any, /) -> _TSerialized: ... + + +# Use TypeIs once it's available in typing_extensions or 3.13. +def is_text_serializer( + serializer: _PDataSerializer[t.Any], +) -> te.TypeGuard[_PDataSerializer[str]]: """Checks whether a serializer generates text or binary.""" return isinstance(serializer.dumps({}), str) -class Serializer(t.Generic[t.AnyStr]): +class Serializer(t.Generic[_TSerialized]): """A serializer wraps a :class:`~itsdangerous.signer.Signer` to enable serializing and securely signing data other than bytes. It can unsign to verify that the data hasn't been changed. @@ -76,7 +94,7 @@ class Serializer(t.Generic[t.AnyStr]): #: The default serialization module to use to serialize data to a #: string internally. The default is :mod:`json`, but can be changed #: to any object that provides ``dumps`` and ``loads`` methods. - default_serializer: _PDataSerializer[t.Any] = json # pyright: ignore + default_serializer: _PDataSerializer[t.Any] = json #: The default ``Signer`` class to instantiate when signing data. #: The default is :class:`itsdangerous.signer.Signer`. @@ -87,14 +105,64 @@ class Serializer(t.Generic[t.AnyStr]): dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] ] = [] - # Tell type checkers that the default type is Serializer[str] if no - # data serializer is provided. + # Serializer[str] if no data serializer is provided, or if it returns str. @t.overload def __init__( self: Serializer[str], secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], salt: str | bytes | None = b"itsdangerous", - serializer: None = None, + serializer: None | _PDataSerializer[str] = None, + serializer_kwargs: dict[str, t.Any] | None = None, + signer: type[Signer] | None = None, + signer_kwargs: dict[str, t.Any] | None = None, + fallback_signers: list[ + dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] + ] + | None = None, + ): ... + + # Serializer[bytes] with a bytes data serializer positional argument. + @t.overload + def __init__( + self: Serializer[bytes], + secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], + salt: str | bytes | None, + serializer: _PDataSerializer[bytes], + serializer_kwargs: dict[str, t.Any] | None = None, + signer: type[Signer] | None = None, + signer_kwargs: dict[str, t.Any] | None = None, + fallback_signers: list[ + dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] + ] + | None = None, + ): ... + + # Serializer[bytes] with a bytes data serializer keyword argument. + @t.overload + def __init__( + self: Serializer[bytes], + secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], + salt: str | bytes | None = b"itsdangerous", + *, + serializer: _PDataSerializer[bytes], + serializer_kwargs: dict[str, t.Any] | None = None, + signer: type[Signer] | None = None, + signer_kwargs: dict[str, t.Any] | None = None, + fallback_signers: list[ + dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] + ] + | None = None, + ): ... + + # Fall back with a positional argument. If the strict signature of + # _PDataSerializer doesn't match, fall back to a union, requiring the user + # to specify the type. + @t.overload + def __init__( + self, + secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], + salt: str | bytes | None, + serializer: t.Any, serializer_kwargs: dict[str, t.Any] | None = None, signer: type[Signer] | None = None, signer_kwargs: dict[str, t.Any] | None = None, @@ -104,12 +172,14 @@ def __init__( | None = None, ): ... + # Fall back with a keyword argument. @t.overload def __init__( - self: Serializer[t.AnyStr], + self, secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], salt: str | bytes | None = b"itsdangerous", - serializer: _PDataSerializer[t.AnyStr] = ..., + *, + serializer: t.Any, serializer_kwargs: dict[str, t.Any] | None = None, signer: type[Signer] | None = None, signer_kwargs: dict[str, t.Any] | None = None, @@ -123,7 +193,7 @@ def __init__( self, secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], salt: str | bytes | None = b"itsdangerous", - serializer: _PDataSerializer[t.AnyStr] | None = None, + serializer: t.Any | None = None, serializer_kwargs: dict[str, t.Any] | None = None, signer: type[Signer] | None = None, signer_kwargs: dict[str, t.Any] | None = None, @@ -148,7 +218,7 @@ def __init__( if serializer is None: serializer = self.default_serializer - self.serializer: _PDataSerializer[t.AnyStr] = serializer + self.serializer: _PDataSerializer[_TSerialized] = serializer self.is_text_serializer: bool = is_text_serializer(serializer) if signer is None: @@ -238,7 +308,7 @@ def iter_unsigners(self, salt: str | bytes | None = None) -> cabc.Iterator[Signe for secret_key in self.secret_keys: yield fallback(secret_key, salt=salt, **kwargs) - def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> t.AnyStr: + def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> _TSerialized: """Returns a signed string serialized with the internal serializer. The return value can be either a byte or unicode string depending on the format of the internal serializer. diff --git a/src/itsdangerous/timed.py b/src/itsdangerous/timed.py index 1bf2fd3..7384375 100644 --- a/src/itsdangerous/timed.py +++ b/src/itsdangerous/timed.py @@ -14,6 +14,7 @@ from .exc import BadSignature from .exc import BadTimeSignature from .exc import SignatureExpired +from .serializer import _TSerialized from .serializer import Serializer from .signer import Signer @@ -166,7 +167,7 @@ def validate(self, signed_value: str | bytes, max_age: int | None = None) -> boo return False -class TimedSerializer(Serializer[t.AnyStr]): +class TimedSerializer(Serializer[_TSerialized]): """Uses :class:`TimestampSigner` instead of the default :class:`.Signer`. """