Skip to content

Commit

Permalink
improve typing (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidism authored Apr 13, 2024
2 parents 0f15cf1 + 69a3bca commit bc88e94
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 154 deletions.
8 changes: 5 additions & 3 deletions src/itsdangerous/_json.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

import json as _json
import typing as _t
import typing as t


class _CompactJSON:
"""Wrapper around json module that strips whitespace."""

@staticmethod
def loads(payload: _t.Union[str, bytes]) -> _t.Any:
def loads(payload: str | bytes) -> t.Any:
return _json.loads(payload)

@staticmethod
def dumps(obj: _t.Any, **kwargs: _t.Any) -> str:
def dumps(obj: t.Any, **kwargs: t.Any) -> str:
kwargs.setdefault("ensure_ascii", False)
kwargs.setdefault("separators", (",", ":"))
return _json.dumps(obj, **kwargs)
14 changes: 7 additions & 7 deletions src/itsdangerous/encoding.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
from __future__ import annotations

import base64
import string
import struct
import typing as _t
import typing as t

from .exc import BadData

_t_str_bytes = _t.Union[str, bytes]


def want_bytes(
s: _t_str_bytes, encoding: str = "utf-8", errors: str = "strict"
s: str | bytes, encoding: str = "utf-8", errors: str = "strict"
) -> bytes:
if isinstance(s, str):
s = s.encode(encoding, errors)

return s


def base64_encode(string: _t_str_bytes) -> bytes:
def base64_encode(string: str | bytes) -> bytes:
"""Base64 encode a string of bytes or text. The resulting bytes are
safe to use in URLs.
"""
string = want_bytes(string)
return base64.urlsafe_b64encode(string).rstrip(b"=")


def base64_decode(string: _t_str_bytes) -> bytes:
def base64_decode(string: str | bytes) -> bytes:
"""Base64 decode a URL-safe string of bytes or text. The result is
bytes.
"""
Expand All @@ -43,7 +43,7 @@ def base64_decode(string: _t_str_bytes) -> bytes:

_int64_struct = struct.Struct(">Q")
_int_to_bytes = _int64_struct.pack
_bytes_to_int = _t.cast("_t.Callable[[bytes], _t.Tuple[int]]", _int64_struct.unpack)
_bytes_to_int = t.cast("t.Callable[[bytes], tuple[int]]", _int64_struct.unpack)


def int_to_bytes(num: int) -> bytes:
Expand Down
29 changes: 14 additions & 15 deletions src/itsdangerous/exc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import typing as _t
from datetime import datetime
from __future__ import annotations

_t_opt_any = _t.Optional[_t.Any]
_t_opt_exc = _t.Optional[Exception]
import typing as t
from datetime import datetime


class BadData(Exception):
Expand All @@ -23,15 +22,15 @@ def __str__(self) -> str:
class BadSignature(BadData):
"""Raised if a signature does not match."""

def __init__(self, message: str, payload: _t_opt_any = None):
def __init__(self, message: str, payload: t.Any | None = None):
super().__init__(message)

#: The payload that failed the signature test. In some
#: situations you might still want to inspect this, even if
#: you know it was tampered with.
#:
#: .. versionadded:: 0.14
self.payload: _t_opt_any = payload
self.payload: t.Any | None = payload


class BadTimeSignature(BadSignature):
Expand All @@ -42,8 +41,8 @@ class BadTimeSignature(BadSignature):
def __init__(
self,
message: str,
payload: _t_opt_any = None,
date_signed: _t.Optional[datetime] = None,
payload: t.Any | None = None,
date_signed: datetime | None = None,
):
super().__init__(message, payload)

Expand Down Expand Up @@ -75,19 +74,19 @@ class BadHeader(BadSignature):
def __init__(
self,
message: str,
payload: _t_opt_any = None,
header: _t_opt_any = None,
original_error: _t_opt_exc = None,
payload: t.Any | None = None,
header: t.Any | None = None,
original_error: Exception | None = None,
):
super().__init__(message, payload)

#: If the header is actually available but just malformed it
#: might be stored here.
self.header: _t_opt_any = header
self.header: t.Any | None = header

#: If available, the error that indicates why the payload was
#: not valid. This might be ``None``.
self.original_error: _t_opt_exc = original_error
self.original_error: Exception | None = original_error


class BadPayload(BadData):
Expand All @@ -99,9 +98,9 @@ class BadPayload(BadData):
.. versionadded:: 0.15
"""

def __init__(self, message: str, original_error: _t_opt_exc = None):
def __init__(self, message: str, original_error: Exception | None = None):
super().__init__(message)

#: If available, the error that indicates why the payload was
#: not valid. This might be ``None``.
self.original_error: _t_opt_exc = original_error
self.original_error: Exception | None = original_error
106 changes: 52 additions & 54 deletions src/itsdangerous/serializer.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,17 @@
from __future__ import annotations

import collections.abc as cabc
import json
import typing as _t
import typing as t

from .encoding import want_bytes
from .exc import BadPayload
from .exc import BadSignature
from .signer import _make_keys_list
from .signer import Signer

_t_str_bytes = _t.Union[str, bytes]
_t_opt_str_bytes = _t.Optional[_t_str_bytes]
_t_kwargs = _t.Dict[str, _t.Any]
_t_opt_kwargs = _t.Optional[_t_kwargs]
_t_signer = _t.Type[Signer]
_t_fallbacks = _t.List[_t.Union[_t_kwargs, _t.Tuple[_t_signer, _t_kwargs], _t_signer]]
_t_load_unsafe = _t.Tuple[bool, _t.Any]
_t_secret_key = _t.Union[_t.Iterable[_t_str_bytes], _t_str_bytes]


def is_text_serializer(serializer: _t.Any) -> bool:
def is_text_serializer(serializer: t.Any) -> bool:
"""Checks whether a serializer generates text or binary."""
return isinstance(serializer.dumps({}), str)

Expand Down Expand Up @@ -77,31 +71,36 @@ class Serializer:
#: The default serialization module to use to serialize data to a
#: string internally. The default is :mod:`json`, but can be changed
#: to any object that provides ``dumps`` and ``loads`` methods.
default_serializer: _t.Any = json
default_serializer: t.Any = json

#: The default ``Signer`` class to instantiate when signing data.
#: The default is :class:`itsdangerous.signer.Signer`.
default_signer: _t_signer = Signer
default_signer: type[Signer] = Signer

#: The default fallback signers to try when unsigning fails.
default_fallback_signers: _t_fallbacks = []
default_fallback_signers: list[
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
] = []

def __init__(
self,
secret_key: _t_secret_key,
salt: _t_opt_str_bytes = b"itsdangerous",
serializer: _t.Any = None,
serializer_kwargs: _t_opt_kwargs = None,
signer: _t.Optional[_t_signer] = None,
signer_kwargs: _t_opt_kwargs = None,
fallback_signers: _t.Optional[_t_fallbacks] = None,
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None = b"itsdangerous",
serializer: t.Any = None,
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
fallback_signers: list[
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
]
| None = None,
):
#: The list of secret keys to try for verifying signatures, from
#: oldest to newest. The newest (last) key is used for signing.
#:
#: This allows a key rotation system to keep a list of allowed
#: keys and remove expired ones.
self.secret_keys: _t.List[bytes] = _make_keys_list(secret_key)
self.secret_keys: list[bytes] = _make_keys_list(secret_key)

if salt is not None:
salt = want_bytes(salt)
Expand All @@ -112,20 +111,22 @@ def __init__(
if serializer is None:
serializer = self.default_serializer

self.serializer: _t.Any = serializer
self.serializer: t.Any = serializer
self.is_text_serializer: bool = is_text_serializer(serializer)

if signer is None:
signer = self.default_signer

self.signer: _t_signer = signer
self.signer_kwargs: _t_kwargs = signer_kwargs or {}
self.signer: type[Signer] = signer
self.signer_kwargs: dict[str, t.Any] = signer_kwargs or {}

if fallback_signers is None:
fallback_signers = list(self.default_fallback_signers or ())
fallback_signers = list(self.default_fallback_signers)

self.fallback_signers: _t_fallbacks = fallback_signers
self.serializer_kwargs: _t_kwargs = serializer_kwargs or {}
self.fallback_signers: list[
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
] = fallback_signers
self.serializer_kwargs: dict[str, t.Any] = serializer_kwargs or {}

@property
def secret_key(self) -> bytes:
Expand All @@ -134,41 +135,40 @@ def secret_key(self) -> bytes:
"""
return self.secret_keys[-1]

def load_payload(
self, payload: bytes, serializer: _t.Optional[_t.Any] = None
) -> _t.Any:
def load_payload(self, payload: bytes, serializer: t.Any | None = None) -> t.Any:
"""Loads the encoded object. This function raises
:class:`.BadPayload` if the payload is not valid. The
``serializer`` parameter can be used to override the serializer
stored on the class. The encoded ``payload`` should always be
bytes.
"""
if serializer is None:
serializer = self.serializer
use_serializer = self.serializer
is_text = self.is_text_serializer
else:
use_serializer = serializer
is_text = is_text_serializer(serializer)

try:
if is_text:
return serializer.loads(payload.decode("utf-8"))
return use_serializer.loads(payload.decode("utf-8"))

return serializer.loads(payload)
return use_serializer.loads(payload)
except Exception as e:
raise BadPayload(
"Could not load the payload because an exception"
" occurred on unserializing the data.",
original_error=e,
) from e

def dump_payload(self, obj: _t.Any) -> bytes:
def dump_payload(self, obj: t.Any) -> bytes:
"""Dumps the encoded object. The return value is always bytes.
If the internal serializer returns text, the value will be
encoded as UTF-8.
"""
return want_bytes(self.serializer.dumps(obj, **self.serializer_kwargs))

def make_signer(self, salt: _t_opt_str_bytes = None) -> Signer:
def make_signer(self, salt: str | bytes | None = None) -> Signer:
"""Creates a new instance of the signer to be used. The default
implementation uses the :class:`.Signer` base class.
"""
Expand All @@ -177,7 +177,7 @@ def make_signer(self, salt: _t_opt_str_bytes = None) -> Signer:

return self.signer(self.secret_keys, salt=salt, **self.signer_kwargs)

def iter_unsigners(self, salt: _t_opt_str_bytes = None) -> _t.Iterator[Signer]:
def iter_unsigners(self, salt: str | bytes | None = None) -> cabc.Iterator[Signer]:
"""Iterates over all signers to be tried for unsigning. Starts
with the configured signer, then constructs each signer
specified in ``fallback_signers``.
Expand All @@ -199,7 +199,7 @@ def iter_unsigners(self, salt: _t_opt_str_bytes = None) -> _t.Iterator[Signer]:
for secret_key in self.secret_keys:
yield fallback(secret_key, salt=salt, **kwargs)

def dumps(self, obj: _t.Any, salt: _t_opt_str_bytes = None) -> _t_str_bytes:
def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> str | bytes:
"""Returns a signed string serialized with the internal
serializer. The return value can be either a byte or unicode
string depending on the format of the internal serializer.
Expand All @@ -212,17 +212,15 @@ def dumps(self, obj: _t.Any, salt: _t_opt_str_bytes = None) -> _t_str_bytes:

return rv

def dump(
self, obj: _t.Any, f: _t.IO[_t.Any], salt: _t_opt_str_bytes = None
) -> None:
def dump(self, obj: t.Any, f: t.IO[t.Any], salt: str | bytes | None = None) -> None:
"""Like :meth:`dumps` but dumps into a file. The file handle has
to be compatible with what the internal serializer expects.
"""
f.write(self.dumps(obj, salt))

def loads(
self, s: _t_str_bytes, salt: _t_opt_str_bytes = None, **kwargs: _t.Any
) -> _t.Any:
self, s: str | bytes, salt: str | bytes | None = None, **kwargs: t.Any
) -> t.Any:
"""Reverse of :meth:`dumps`. Raises :exc:`.BadSignature` if the
signature validation fails.
"""
Expand All @@ -235,15 +233,15 @@ def loads(
except BadSignature as err:
last_exception = err

raise _t.cast(BadSignature, last_exception)
raise t.cast(BadSignature, last_exception)

def load(self, f: _t.IO[_t.Any], salt: _t_opt_str_bytes = None) -> _t.Any:
def load(self, f: t.IO[t.Any], salt: str | bytes | None = None) -> t.Any:
"""Like :meth:`loads` but loads from a file."""
return self.loads(f.read(), salt)

def loads_unsafe(
self, s: _t_str_bytes, salt: _t_opt_str_bytes = None
) -> _t_load_unsafe:
self, s: str | bytes, salt: str | bytes | None = None
) -> tuple[bool, t.Any]:
"""Like :meth:`loads` but without verifying the signature. This
is potentially very dangerous to use depending on how your
serializer works. The return value is ``(signature_valid,
Expand All @@ -261,11 +259,11 @@ def loads_unsafe(

def _loads_unsafe_impl(
self,
s: _t_str_bytes,
salt: _t_opt_str_bytes,
load_kwargs: _t_opt_kwargs = None,
load_payload_kwargs: _t_opt_kwargs = None,
) -> _t_load_unsafe:
s: str | bytes,
salt: str | bytes | None,
load_kwargs: dict[str, t.Any] | None = None,
load_payload_kwargs: dict[str, t.Any] | None = None,
) -> tuple[bool, t.Any]:
"""Low level helper function to implement :meth:`loads_unsafe`
in serializer subclasses.
"""
Expand All @@ -290,8 +288,8 @@ def _loads_unsafe_impl(
return False, None

def load_unsafe(
self, f: _t.IO[_t.Any], salt: _t_opt_str_bytes = None
) -> _t_load_unsafe:
self, f: t.IO[t.Any], salt: str | bytes | None = None
) -> tuple[bool, t.Any]:
"""Like :meth:`loads_unsafe` but loads from a file.
.. versionadded:: 0.15
Expand Down
Loading

0 comments on commit bc88e94

Please sign in to comment.