Skip to content

Commit

Permalink
Type annotate public API
Browse files Browse the repository at this point in the history
  • Loading branch information
unmade committed Sep 20, 2023
1 parent d562605 commit 31349d8
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 87 deletions.
1 change: 1 addition & 0 deletions cashews/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __call__(
KeyTemplate = str
KeyOrTemplate = Union[KeyTemplate, Key]
Value = Any
Default = TypeVar("Default")
Tag = str
Tags = Iterable[Tag]
Exceptions = Union[Type[Exception], Iterable[Type[Exception]], None]
Expand Down
34 changes: 22 additions & 12 deletions cashews/backends/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from abc import ABCMeta, abstractmethod
from contextlib import asynccontextmanager
from contextvars import ContextVar
from typing import Any, AsyncIterator, Iterable, Mapping
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Iterable, Mapping, overload

from cashews._typing import Callback, Key, Value
from cashews.commands import ALL, Command
from cashews.exceptions import CacheBackendInteractionError, LockedError

if TYPE_CHECKING:
from cashews._typing import Callback, Default, Key, Value

NOT_EXIST = -2
UNLIMITED = -1

Expand Down Expand Up @@ -40,15 +42,23 @@ async def set(
...

@abstractmethod
async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None) -> None:
...

@abstractmethod
async def set_raw(self, key: Key, value: Value, **kwargs: Any):
async def set_raw(self, key: Key, value: Value, **kwargs: Any) -> None:
...

@overload
async def get(self, key: Key, default: Default) -> Value | Default:
...

@overload
async def get(self, key: Key, default: None = None) -> Value | None:
...

@abstractmethod
async def get(self, key: Key, default: Value | None = None) -> Value:
async def get(self, key: Key, default: Default | None = None) -> Value | Default | None:
...

@abstractmethod
Expand Down Expand Up @@ -80,11 +90,11 @@ async def delete(self, key: Key) -> bool:
...

@abstractmethod
async def delete_many(self, *keys: Key):
async def delete_many(self, *keys: Key) -> None:
...

@abstractmethod
async def delete_match(self, pattern: str):
async def delete_match(self, pattern: str) -> None:
...

@abstractmethod
Expand All @@ -110,11 +120,11 @@ async def slice_incr(
...

@abstractmethod
async def set_add(self, key: Key, *values: str, expire: float | None = None):
async def set_add(self, key: Key, *values: str, expire: float | None = None) -> None:
...

@abstractmethod
async def set_remove(self, key: Key, *values: str):
async def set_remove(self, key: Key, *values: str) -> None:
...

@abstractmethod
Expand All @@ -140,7 +150,7 @@ async def ping(self, message: bytes | None = None) -> bytes:
...

@abstractmethod
async def clear(self):
async def clear(self) -> None:
...

async def set_lock(self, key: Key, value: Value, expire: float) -> bool:
Expand All @@ -160,7 +170,7 @@ async def unlock(self, key: Key, value: Value) -> bool:
...

@asynccontextmanager
async def lock(self, key: Key, expire: float, wait=True):
async def lock(self, key: Key, expire: float, wait: bool = True) -> AsyncGenerator[None, None]:
identifier = str(uuid.uuid4())
while True:
lock = await self.set_lock(key, identifier, expire=expire)
Expand Down Expand Up @@ -212,7 +222,7 @@ def is_enable(self, *cmds: Command) -> bool:
return not self.is_disable(*cmds)

@property
def is_full_disable(self):
def is_full_disable(self) -> bool:
return self._disable == ALL

def disable(self, *cmds: Command) -> None:
Expand Down
11 changes: 6 additions & 5 deletions cashews/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@
import time
from collections import OrderedDict
from contextlib import suppress
from typing import Any, AsyncIterator, Iterable, Mapping, TypeVar, overload
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Mapping, overload

from cashews._typing import Key, Value
from cashews.serialize import SerializerMixin
from cashews.utils import Bitarray, get_obj_size

from .interface import NOT_EXIST, UNLIMITED, Backend

if TYPE_CHECKING:
from cashews._typing import Default, Key, Value


__all__ = ["Memory"]

_missed = object()

Default = TypeVar("Default")


class _Memory(Backend):
"""
Expand Down Expand Up @@ -69,7 +70,7 @@ async def set(
self._set(key, value, expire)
return True

async def set_raw(self, key: Key, value: Value, **kwargs: Any):
async def set_raw(self, key: Key, value: Value, **kwargs: Any) -> None:
self.store[key] = value

async def get(self, key: Key, default: Value | None = None) -> Value:
Expand Down
22 changes: 13 additions & 9 deletions cashews/wrapper/backend_settings.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from typing import Any, Callable, Dict, Tuple, Type, Union
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Type, Union
from urllib.parse import parse_qsl, urlparse

from cashews.backends.interface import Backend
from cashews.backends.memory import Memory
from cashews.exceptions import BackendNotAvailableError

if TYPE_CHECKING:
BackendOrFabric = Union[Type[Backend], Callable[..., Backend]]

_NO_REDIS_ERROR = "Redis backend requires `redis` to be installed."
_CUSTOM_ERRORS = {
"redis": _NO_REDIS_ERROR,
"rediss": _NO_REDIS_ERROR,
"disk": "Disk backend requires `diskcache` to be installed.",
}
BackendOrFabric = Union[Type[Backend], Callable[..., Backend]]
_BACKENDS: Dict[str, Tuple[BackendOrFabric, bool]] = {}
_BACKENDS: dict[str, tuple[BackendOrFabric, bool]] = {}


def register_backend(alias: str, backend_class: BackendOrFabric, pass_uri: bool = False) -> None:
Expand All @@ -30,7 +34,7 @@ def register_backend(alias: str, backend_class: BackendOrFabric, pass_uri: bool
from cashews.backends.redis import Redis
from cashews.backends.redis.client_side import BcastClientSide

def _redis_fabric(**params: Any) -> Union[Redis, BcastClientSide]:
def _redis_fabric(**params) -> Redis | BcastClientSide:
if params.pop("client_side", None):
return BcastClientSide(**params)
return Redis(**params)
Expand All @@ -49,10 +53,10 @@ def _redis_fabric(**params: Any) -> Union[Redis, BcastClientSide]:
register_backend("disk", DiskCache)


def settings_url_parse(url: str) -> Tuple[BackendOrFabric, Dict[str, Any]]:
def settings_url_parse(url: str) -> tuple[BackendOrFabric, dict[str, Any]]:
parse_result = urlparse(url)
params: Dict[str, Any] = dict(parse_qsl(parse_result.query))
params = serialize_params(params)
params: dict[str, Any] = dict(parse_qsl(parse_result.query))
params = _serialize_params(params)

alias = parse_result.scheme
if alias == "":
Expand All @@ -67,15 +71,15 @@ def settings_url_parse(url: str) -> Tuple[BackendOrFabric, Dict[str, Any]]:
return backend_class, params


def serialize_params(params: Dict[str, str]) -> Dict[str, Union[str, int, bool, float]]:
def _serialize_params(params: dict[str, str]) -> dict[str, str | int | bool | float]:
new_params = {}
bool_keys = ("safe", "suppress", "enable", "disable", "client_side")
true_values = (
"1",
"true",
)
for key, value in params.items():
_value: Union[str, int, bool, float]
_value: str | int | bool | float
if key.lower() in bool_keys:
_value = value.lower() in true_values
elif value.isdigit():
Expand Down
55 changes: 30 additions & 25 deletions cashews/wrapper/commands.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, AsyncIterator, Iterable, Mapping

from cashews._typing import TTL, Key, Value
from cashews.backends.interface import Backend
Expand All @@ -8,14 +10,17 @@

from .wrapper import Wrapper

if TYPE_CHECKING:
from cashews._typing import Default


class CommandWrapper(Wrapper):
async def set(
self,
key: Key,
value: Value,
expire: Union[float, TTL, None] = None,
exist: Optional[bool] = None,
expire: TTL = None,
exist: bool | None = None,
) -> bool:
return await self._with_middlewares(Command.SET, key)(
key=key,
Expand All @@ -24,10 +29,10 @@ async def set(
exist=exist,
)

async def set_raw(self, key: Key, value: Value, **kwargs):
async def set_raw(self, key: Key, value: Value, **kwargs) -> None:
return await self._with_middlewares(Command.SET_RAW, key)(key=key, value=value, **kwargs)

async def get(self, key: Key, default: Optional[Value] = None) -> Value:
async def get(self, key: Key, default: Default | None = None) -> Value | Default | None:
return await self._with_middlewares(Command.GET, key)(key=key, default=default)

async def get_raw(self, key: Key) -> Value:
Expand All @@ -48,7 +53,7 @@ async def get_match(
self,
pattern: str,
batch_size: int = 100,
) -> AsyncIterator[Tuple[Key, Value]]:
) -> AsyncIterator[tuple[Key, Value]]:
backend, middlewares = self._get_backend_and_config(pattern)

async def call(pattern, batch_size):
Expand All @@ -59,19 +64,19 @@ async def call(pattern, batch_size):
async for key, value in (await call(pattern=pattern, batch_size=batch_size)):
yield key, value

async def get_many(self, *keys: Key, default: Optional[Value] = None) -> Tuple[Optional[Value], ...]:
backends: Dict[Backend, List[str]] = {}
async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
backends: dict[Backend, list[str]] = {}
for key in keys:
backend = self._get_backend(key)
backends.setdefault(backend, []).append(key)
result: Dict[Key, Value] = {}
result: dict[Key, Value] = {}
for _keys in backends.values():
_values = await self._with_middlewares(Command.GET_MANY, _keys[0])(*_keys, default=default)
result.update(dict(zip(_keys, _values)))
return tuple(result.get(key) for key in keys)

async def set_many(self, pairs: Mapping[Key, Value], expire: Union[float, TTL, None] = None):
backends: Dict[Backend, List[Key]] = {}
async def set_many(self, pairs: Mapping[Key, Value], expire: TTL = None):
backends: dict[Backend, list[Key]] = {}
for key in pairs:
backend = self._get_backend(key)
backends.setdefault(backend, []).append(key)
Expand All @@ -82,39 +87,39 @@ async def set_many(self, pairs: Mapping[Key, Value], expire: Union[float, TTL, N
expire=ttl_to_seconds(expire),
)

async def get_bits(self, key: Key, *indexes: int, size: int = 1) -> Tuple[int, ...]:
async def get_bits(self, key: Key, *indexes: int, size: int = 1) -> tuple[int, ...]:
return await self._with_middlewares(Command.GET_BITS, key)(key, *indexes, size=size)

async def incr_bits(self, key: Key, *indexes: int, size: int = 1, by: int = 1) -> Tuple[int, ...]:
async def incr_bits(self, key: Key, *indexes: int, size: int = 1, by: int = 1) -> tuple[int, ...]:
return await self._with_middlewares(Command.INCR_BITS, key)(key, *indexes, size=size, by=by)

async def slice_incr(
self,
key: Key,
start: Union[int, float],
end: Union[int, float],
start: int | float,
end: int | float,
maxvalue: int,
expire: Union[float, TTL, None] = None,
expire: TTL = None,
) -> int:
return await self._with_middlewares(Command.SLICE_INCR, key)(
key=key, start=start, end=end, maxvalue=maxvalue, expire=ttl_to_seconds(expire)
)

async def incr(self, key: Key, value: int = 1, expire: Optional[float] = None) -> int:
async def incr(self, key: Key, value: int = 1, expire: float | None = None) -> int:
return await self._with_middlewares(Command.INCR, key)(key=key, value=value, expire=expire)

async def delete(self, key: Key) -> bool:
return await self._with_middlewares(Command.DELETE, key)(key=key)

async def delete_many(self, *keys: Key) -> None:
backends: Dict[Backend, List[Key]] = {}
backends: dict[Backend, list[Key]] = {}
for key in keys:
backend = self._get_backend(key)
backends.setdefault(backend, []).append(key)
for _keys in backends.values():
await self._with_middlewares(Command.DELETE_MANY, _keys[0])(*_keys)

async def delete_match(self, pattern: str):
async def delete_match(self, pattern: str) -> None:
return await self._with_middlewares(Command.DELETE_MATCH, pattern)(pattern=pattern)

async def expire(self, key: Key, timeout: TTL):
Expand All @@ -135,7 +140,7 @@ async def unlock(self, key: Key, value: Value) -> bool:
async def get_size(self, key: Key) -> int:
return await self._with_middlewares(Command.GET_SIZE, key)(key=key)

async def ping(self, message: Optional[bytes] = None) -> bytes:
async def ping(self, message: bytes | None = None) -> bytes:
message = b"PING" if message is None else message
return await self._with_middlewares(Command.PING, message.decode())(message=message)

Expand All @@ -148,22 +153,22 @@ async def get_keys_count(self) -> int:
result += count
return result

async def clear(self):
async def clear(self) -> None:
for backend, _ in self._backends.values():
await self._with_middlewares_for_backend(Command.CLEAR, backend, self._default_middlewares)()

async def is_locked(
self,
key: Key,
wait: Union[float, None, TTL] = None,
step: Union[int, float] = 0.1,
wait: TTL = None,
step: int | float = 0.1,
) -> bool:
return await self._with_middlewares(Command.IS_LOCKED, key)(key=key, wait=ttl_to_seconds(wait), step=step)

async def set_add(self, key: Key, *values: str, expire: TTL = None):
async def set_add(self, key: Key, *values: str, expire: TTL = None) -> None:
return await self._with_middlewares(Command.SET_ADD, key)(key, *values, expire=ttl_to_seconds(expire))

async def set_remove(self, key: Key, *values: str):
async def set_remove(self, key: Key, *values: str) -> None:
return await self._with_middlewares(Command.SET_REMOVE, key)(key, *values)

async def set_pop(self, key: Key, count: int = 100) -> Iterable[str]:
Expand Down
2 changes: 1 addition & 1 deletion cashews/wrapper/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from cashews.cache_condition import get_cache_condition
from cashews.ttl import ttl_to_seconds

from .._typing import DecoratedFunc, Decorator
from .time_condition import create_time_condition
from .wrapper import Wrapper

if TYPE_CHECKING:
from cashews._typing import DecoratedFunc, Decorator
from cashews.decorators.bloom import IntOrPair


Expand Down
Loading

0 comments on commit 31349d8

Please sign in to comment.