From 2b8405bf7adaafa5af5f1f81b887044f3096efa5 Mon Sep 17 00:00:00 2001
From: Dima Kryukov
Date: Sun, 17 Sep 2023 22:20:57 +0300
Subject: [PATCH] add more contrib modules

---
 cashews/contrib/__init__.py          |   4 +
 cashews/contrib/fastapi.py           | 136 +++++++++++++++++++++++++++
 cashews/contrib/prometheus.py        |  28 ++++++
 cashews/decorators/cache/defaults.py |  10 +-
 cashews/ttl.py                       |   6 +-
 cashews/wrapper/decorators.py        |   8 +-
 cashews/wrapper/wrapper.py           |  11 ++-
 examples/fastapi_app.py              |  29 ++----
 tests/conftest.py                    |  16 +++-
 tests/test_cache.py                  |  18 ++--
 tests/test_wrapper.py                |  11 +++
 11 files changed, 232 insertions(+), 45 deletions(-)
 create mode 100644 cashews/contrib/fastapi.py
 create mode 100644 cashews/contrib/prometheus.py

diff --git a/cashews/contrib/__init__.py b/cashews/contrib/__init__.py
index b958f44..a3a0192 100644
--- a/cashews/contrib/__init__.py
+++ b/cashews/contrib/__init__.py
@@ -1,3 +1,7 @@
+"""
+Here modules with auto setup
+"""
+
 try:
     from . import _starlette
 except ImportError:
diff --git a/cashews/contrib/fastapi.py b/cashews/contrib/fastapi.py
new file mode 100644
index 0000000..6f7df68
--- /dev/null
+++ b/cashews/contrib/fastapi.py
@@ -0,0 +1,136 @@
+import contextlib
+from contextlib import nullcontext
+from contextvars import ContextVar
+from hashlib import blake2s
+from typing import List, Optional
+
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.types import ASGIApp
+
+from cashews import Cache, Command, cache, invalidate_further
+
+_CACHE_MAX_AGE = ContextVar("cache_control_max_age")
+
+_CACHE_CONTROL_HEADER = "Cache-Control"
+_AGE_HEADER = "Age"
+_ETAG_HEADER = "ETag"
+_IF_NOT_MATCH_HEADER = "If-None-Match"
+_CLEAR_CACHE_HEADER = "Clear-Site-Data"
+
+_NO_CACHE = "no-cache"  # disable GET
+_NO_STORE = "no-store"  # disable GET AND SET
+_MAX_AGE = "max-age="
+_ONLY_IF_CACHED = "only-if-cached"
+_PRIVATE = "private"
+_PUBLIC = "public"
+
+_CLEAR_CACHE_HEADER_VALUE = "cache"
+
+
+def cache_control_condition():
+    pass
+
+
+def cache_control_ttl():
+    pass
+
+
+class CacheRequestControlMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app: ASGIApp, cache_instance: Cache = cache, methods: List[str] = ("get",), private=True):
+        self._private = private
+        self._cache = cache_instance
+        self._methods = methods
+        super().__init__(app)
+
+    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
+        context = nullcontext()
+        cache_control_value = request.headers.get(_CACHE_CONTROL_HEADER)
+        if request.method.lower() not in self._methods:
+            return await call_next(request)
+        to_disable = self._to_disable(cache_control_value)
+        if to_disable:
+            context = self._cache.disabling(*to_disable)
+        with context, self.max_age(cache_control_value), self._cache.detect as detector:
+            response = await call_next(request)
+            calls = detector.calls_list
+            if calls:
+                key, _ = calls[0]
+                expire = await self._cache.get_expire(key)
+                if expire:
+                    response.headers[
+                        _CACHE_CONTROL_HEADER
+                    ] = f"{_PRIVATE if self._private else _PUBLIC}, {_MAX_AGE}{expire}"
+                    response.headers[_AGE_HEADER] = f"{expire}"
+            else:
+                response.headers[_CACHE_CONTROL_HEADER] = _NO_CACHE
+            return response
+
+    @contextlib.contextmanager
+    def max_age(self, cache_control_value: Optional[str]):
+        if not cache_control_value:
+            yield
+            return
+        _max_age = self._get_max_age(cache_control_value)
+        reset_token = None
+        if _max_age:
+            reset_token = _CACHE_MAX_AGE.set(_max_age)
+        try:
+            yield
+        finally:
+            if reset_token:
+                _CACHE_MAX_AGE.reset(reset_token)
+
+    @staticmethod
+    def _to_disable(cache_control_value: Optional[str]) -> tuple[Command]:
+        if cache_control_value == _NO_CACHE:
+            return (Command.GET,)
+        if cache_control_value == _NO_STORE:
+            return Command.GET, Command.SET
+        return tuple()
+
+    @staticmethod
+    def _get_max_age(cache_control_value: str) -> int:
+        if not cache_control_value.startswith(_MAX_AGE):
+            return 0
+        try:
+            return int(cache_control_value.replace(_MAX_AGE, ""))
+        except (ValueError, TypeError):
+            return 0
+
+
+class CacheEtagMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app: ASGIApp, cache_instance: Cache = cache):
+        self._cache = cache_instance
+        super().__init__(app)
+
+    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        etag = request.headers.get(_IF_NOT_MATCH_HEADER)
+        if etag and await self._cache.exists(etag):
+            return Response(status_code=304)
+
+        with self._cache.detect as detector:
+            response = await call_next(request)
+            calls = detector.calls_list
+            if not calls:
+                return response
+
+        key, _ = calls[0]
+        response.headers[_ETAG_HEADER] = await self._get_etag(key)
+        return response
+
+    async def _get_etag(self, key: str) -> str:
+        data = await self._cache.get_raw(key)
+        expire = await self._cache.get_expire(key)
+        etag = blake2s(data).hexdigest()
+        await self._cache.set(etag, True, expire=expire)
+        return etag
+
+
+class CacheDeleteMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        if request.headers.get(_CLEAR_CACHE_HEADER) == _CLEAR_CACHE_HEADER_VALUE:
+            with invalidate_further():
+                return await call_next(request)
+        return await call_next(request)
diff --git a/cashews/contrib/prometheus.py b/cashews/contrib/prometheus.py
new file mode 100644
index 0000000..6970537
--- /dev/null
+++ b/cashews/contrib/prometheus.py
@@ -0,0 +1,28 @@
+from typing import Optional
+
+from prometheus_client import Histogram
+
+from cashews import cache
+from cashews._typing import Middleware
+from cashews.backends.interface import Backend
+from cashews.commands import Command
+
+
+def create_metrics_middleware(latency_metric: Optional[Histogram] = None, with_tag: bool = False) -> Middleware:
+    _DEFAULT_METRIC = Histogram(
+        "cashews_operations_latency_seconds",
+        "Latency of different operations with a cache",
+        labelnames=["operation", "backend_class", "tag"] if with_tag else ["operation", "backend_class"],
+    )
+    _latency_metric = latency_metric or _DEFAULT_METRIC
+
+    async def metrics_middleware(call, cmd: Command, backend: Backend, *args, **kwargs):
+        labels = {"operation": cmd.value, "backend_class": backend.__class__.__name__}
+        if with_tag:
+            # keep the label set fixed: fall back to an empty tag when the key has none
+            tags = cache.get_key_tags(kwargs["key"]) if "key" in kwargs else ()
+            labels["tag"] = tags[0] if tags else ""
+        with _latency_metric.labels(**labels).time():
+            return await call(*args, **kwargs)
+
+    return metrics_middleware
diff --git a/cashews/decorators/cache/defaults.py b/cashews/decorators/cache/defaults.py
index d4d2169..a93f329 100644
--- a/cashews/decorators/cache/defaults.py
+++ b/cashews/decorators/cache/defaults.py
@@ -11,21 +11,23 @@ class CacheDetect:
     __slots__ = ("_value", "_unset_token", "_previous_level")
 
     def __init__(self, previous_level=0, unset_token=None):
-        self._value = {}
+        self._value = []
         self._unset_token = unset_token
         self._previous_level = previous_level
 
     def _set(self, key: Key, **kwargs: Any) -> None:
-        self._value.setdefault(key, []).append(kwargs)
+        self._value.append((key, [kwargs]))
 
     @property
     def calls(self):
         return dict(self._value)
 
-    keys = calls  # backward compatibility
+    @property
+    def calls_list(self):
+        return self._value.copy()
 
     def clear(self):
-        self._value = {}
+        self._value = []
 
 
 _level = ContextVar("level", default=0)
diff --git a/cashews/ttl.py b/cashews/ttl.py
index 8807647..d2a8928 100644
--- a/cashews/ttl.py
+++ b/cashews/ttl.py
@@ -9,13 +9,13 @@ def ttl_to_seconds(
 ) -> Union[int, None, float]:
     if ttl is None:
         return None
-    _type = type(ttl)
+    _type = type(ttl)  # isinstance is slow
+    if _type == str:
+        return _ttl_from_str(ttl)
     if _type == int:
         return ttl
     if _type == timedelta:
         return ttl.total_seconds()
-    if _type == str:
-        return _ttl_from_str(ttl)
     if callable(ttl) and with_callable:
         try:
             ttl = ttl(*args, result=result, **kwargs)
diff --git a/cashews/wrapper/decorators.py b/cashews/wrapper/decorators.py
index 0a9a767..4a85cc1 100644
--- a/cashews/wrapper/decorators.py
+++ b/cashews/wrapper/decorators.py
@@ -10,7 +10,7 @@
 from .wrapper import Wrapper
 
 
-def _thunder_protection(func: Callable) -> Callable:
+def _skip_thunder_protection(func: Callable) -> Callable:
     return func
 
 
@@ -36,12 +36,13 @@ def _decorator(func: AsyncCallable_T) -> AsyncCallable_T:
             decor_kwargs["condition"] = condition
 
             decorator = decorator_fabric(self, **decor_kwargs)(func)
-            thunder_protection = _thunder_protection
+            thunder_protection = _skip_thunder_protection
             if protected:
                 thunder_protection = decorators.thunder_protection(key=decor_kwargs.get("key"))
 
             @wraps(func)
             async def _call(*args, **kwargs):
+                self._check_setup()
                 if self.is_full_disable:
                     return await func(*args, **kwargs)
                 if lock:
@@ -69,12 +70,13 @@ def _decorator(func: AsyncCallable_T) -> AsyncCallable_T:
 
             @wraps(func)
             async def _call(*args, **kwargs):
+                self._check_setup()
                 if self.is_full_disable:
                     return await func(*args, **kwargs)
                 with decorators.context_cache_detect as detect:
 
                     def new_condition(result, _args, _kwargs, key):
-                        if detect.keys:
+                        if detect.calls:
                             return False
                         return _condition(result, _args, _kwargs, key=key) if _condition else result is not None
 
diff --git a/cashews/wrapper/wrapper.py b/cashews/wrapper/wrapper.py
index 6e361ad..ed238bf 100644
--- a/cashews/wrapper/wrapper.py
+++ b/cashews/wrapper/wrapper.py
@@ -30,8 +30,8 @@ def _get_backend_and_config(self, key: Key) -> Tuple[Backend, Tuple[Middleware,
         for prefix in sorted(self._backends.keys(), reverse=True):
             if key.startswith(prefix):
                 return self._backends[prefix]
-        if self.default_prefix not in self._backends:
-            raise NotConfiguredError("run `cache.setup(...)` before using cache")
+        self._check_setup()
+        raise NotConfiguredError("Backend for given key not configured")
 
     def _get_backend(self, key: Key) -> Backend:
         backend, _ = self._get_backend_and_config(key)
@@ -62,6 +62,13 @@ def setup(self, settings_url: str, middlewares: Tuple = (), prefix: str = defaul
         self._add_backend(backend, middlewares, prefix)
         return backend
 
+    def is_setup(self) -> bool:
+        return bool(self._backends)
+
+    def _check_setup(self):
+        if not self._backends:
+            raise NotConfiguredError("run `cache.setup(...)` before using cache")
+
     def _add_backend(self, backend: Backend, middlewares=(), prefix: str = default_prefix):
         self._backends[prefix] = (
             backend,
diff --git a/examples/fastapi_app.py b/examples/fastapi_app.py
index 68d6b10..5e36a2c 100644
--- a/examples/fastapi_app.py
+++ b/examples/fastapi_app.py
@@ -7,16 +7,20 @@
 from fastapi import FastAPI, Header, Query
 from fastapi.responses import StreamingResponse
 
-from cashews import Command, cache
+from cashews import cache
+from cashews.contrib.fastapi import CacheDeleteMiddleware, CacheEtagMiddleware, CacheRequestControlMiddleware
 
 app = FastAPI()
-cache.setup(os.environ.get("CACHE_URI", "redis://?client_side=True"))
+app.add_middleware(CacheDeleteMiddleware)
+app.add_middleware(CacheEtagMiddleware)
+app.add_middleware(CacheRequestControlMiddleware)
+cache.setup(os.environ.get("CACHE_URI", "redis://"))
 KB = 1024
 
 
 @app.get("/")
 @cache.failover(ttl="1h")
-@cache.slice_rate_limit(10, "1m")
+@cache.slice_rate_limit(10, "3m")
 @cache(ttl="10m", key="simple:{user_agent}", time_condition="1s")
 async def simple(user_agent: str = Header("No")):
     await asyncio.sleep(1.1)
@@ -56,25 +60,6 @@ async def add_process_time_header(request, call_next):
     return response
 
 
-@app.middleware("http")
-async def add_from_cache_headers(request, call_next):
-    with cache.detect as detector:
-        response = await call_next(request)
-        if request.method.lower() != "get":
-            return response
-        if detector.calls:
-            response.headers["X-From-Cache-keys"] = ";".join(detector.calls.keys())
-        return response
-
-
-@app.middleware("http")
-async def disable_middleware(request, call_next):
-    if request.headers.get("X-No-Cache"):
-        with cache.disabling(Command.GET):
-            return await call_next(request)
-    return await call_next(request)
-
-
 if __name__ == "__main__":
     import uvicorn
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 219d56b..20f80d7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -74,13 +74,25 @@ async def _backend(request, redis_dsn, backend_factory):
         from cashews.backends.redis import Redis
 
         backend = backend_factory(
-            Redis, redis_dsn, hash_key=None, max_connections=20, suppress=False, socket_timeout=10
+            Redis,
+            redis_dsn,
+            hash_key=None,
+            max_connections=20,
+            suppress=False,
+            socket_timeout=10,
+            wait_for_connection_timeout=1,
         )
     elif request.param == "redis_hash":
         from cashews.backends.redis import Redis
 
         backend = backend_factory(
-            Redis, redis_dsn, hash_key=uuid4().hex, max_connections=20, suppress=False, socket_timeout=10
+            Redis,
+            redis_dsn,
+            hash_key=uuid4().hex,
+            max_connections=20,
+            suppress=False,
+            socket_timeout=10,
+            wait_for_connection_timeout=1,
         )
     elif request.param == "redis_cs":
         from cashews.backends.redis.client_side import BcastClientSide
diff --git a/tests/test_cache.py b/tests/test_cache.py
index e36f920..e439d08 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -344,17 +344,17 @@ async def func(resp=b"ok"):
 
     with decorators.context_cache_detect as detector:
         assert await func() == b"ok"
-        assert detector.keys == {}
+        assert detector.calls == {}
         await asyncio.sleep(0)
 
         assert await func(b"notok") == b"ok"
-        assert list(detector.keys.keys()) == [
+        assert list(detector.calls.keys()) == [
             "key",
         ]
         await asyncio.sleep(EXPIRE * 1.1)
 
         assert await func(b"notok") == b"notok"
-        assert len(detector.keys) == 1
+        assert len(detector.calls) == 1
 
     assert decorators.context_cache_detect._levels == {}
 
@@ -373,14 +373,14 @@ async def func():
 
     with decorators.context_cache_detect as detector:
         await func()
-        assert detector.keys == {}
+        assert detector.calls == {}
         await asyncio.sleep(0)
 
         await func()
 
-        assert len(detector.keys) == 2
-        assert "key1" in detector.keys
-        assert "key2" in detector.keys
+        assert len(detector.calls) == 2
+        assert "key1" in detector.calls
+        assert "key2" in detector.calls
 
     assert decorators.context_cache_detect._levels == {}
 
@@ -396,7 +396,7 @@ async def func2():
     async def func(*funcs):
        with decorators.context_cache_detect as detector:
            await asyncio.gather(*funcs)
-            return len(detector.keys)
+            return len(detector.calls)
 
     await cache.set("key1", "test")
     await cache.set("key2", "test")
@@ -409,6 +409,6 @@ async def func(*funcs):
     assert await asyncio.create_task(func(func2())) == 1
     assert await asyncio.create_task(func(func2(), func2())) == 1
 
-    assert len(detector.keys) == 2
+    assert len(detector.calls) == 2
 
     assert decorators.context_cache_detect._levels == {}
diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py
index 958dca5..3d0a45d 100644
--- a/tests/test_wrapper.py
+++ b/tests/test_wrapper.py
@@ -279,3 +279,14 @@ async def test_no_setup():
 
     with pytest.raises(NotConfiguredError):
         await cache.get("test")
+
+
+async def test_no_setup_decor():
+    cache = Cache()
+
+    @cache(ttl=0.1, key="key")
+    async def func():
+        return None
+
+    with pytest.raises(NotConfiguredError):
+        await func()
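
A minimal usage sketch of how an application could wire the new contrib pieces together, assuming a FastAPI app and a plain "redis://" DSN like the one in examples/fastapi_app.py; the Prometheus middleware is passed through the existing `middlewares` tuple of `cache.setup()` shown in the wrapper.py hunk above. Names and values below are illustrative, not part of this patch:

    from fastapi import FastAPI

    from cashews import cache
    from cashews.contrib.fastapi import (
        CacheDeleteMiddleware,
        CacheEtagMiddleware,
        CacheRequestControlMiddleware,
    )
    from cashews.contrib.prometheus import create_metrics_middleware

    app = FastAPI()
    # HTTP-level integration: the three middlewares added by this patch
    app.add_middleware(CacheDeleteMiddleware)
    app.add_middleware(CacheEtagMiddleware)
    app.add_middleware(CacheRequestControlMiddleware)

    # backend-level integration: create_metrics_middleware() returns a cashews
    # Middleware, so it can be passed to cache.setup() alongside the backend DSN
    cache.setup("redis://", middlewares=(create_metrics_middleware(),))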