Skip to content

Commit

Permalink
add more contrib modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Dima Kryukov committed Sep 17, 2023
1 parent 7e72221 commit 2b8405b
Show file tree
Hide file tree
Showing 11 changed files with 232 additions and 45 deletions.
4 changes: 4 additions & 0 deletions cashews/contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Here modules with auto setup
"""

try:
from . import _starlette
except ImportError:
Expand Down
136 changes: 136 additions & 0 deletions cashews/contrib/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import contextlib
from contextlib import nullcontext
from contextvars import ContextVar
from hashlib import blake2s
from typing import List, Optional

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp

from cashews import Cache, Command, cache, invalidate_further

_CACHE_MAX_AGE = ContextVar("cache_control_max_age")

_CACHE_CONTROL_HEADER = "Cache-Control"
_AGE_HEADER = "Age"
_ETAG_HEADER = "ETag"
_IF_NOT_MATCH_HEADER = "If-None-Match"
_CLEAR_CACHE_HEADER = "Clear-Site-Data"

_NO_CACHE = "no-cache" # disable GET
_NO_STORE = "no-store" # disable GET AND SET
_MAX_AGE = "max-age="
_ONLY_IF_CACHED = "only-if-cached"
_PRIVATE = "private"
_PUBLIC = "public"

_CLEAR_CACHE_HEADER_VALUE = "cache"


def cache_control_condition():
pass


def cache_control_ttl():
pass


class CacheRequestControlMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, cache_instance: Cache = cache, methods: List[str] = ("get",), private=True):
self._private = private
self._cache = cache_instance
self._methods = methods
super().__init__(app)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
context = nullcontext()
cache_control_value = request.headers.get(_CACHE_CONTROL_HEADER)
if request.method.lower() not in self._methods:
return await call_next(request)
to_disable = self._to_disable(cache_control_value)
if to_disable:
context = self._cache.disabling(*to_disable)
with context, self.max_age(cache_control_value), self._cache.detect as detector:
response = await call_next(request)
calls = detector.calls_list
if calls:
key, _ = calls[0]
expire = await self._cache.get_expire(key)
if expire:
response.headers[
_CACHE_CONTROL_HEADER
] = f"{_PRIVATE if self._private else _PUBLIC}, {_MAX_AGE}{expire}"
response.headers[_AGE_HEADER] = f"{expire}"
else:
response.headers[_CACHE_CONTROL_HEADER] = _NO_CACHE
return response

@contextlib.contextmanager
def max_age(self, cache_control_value: Optional[str]):
if not cache_control_value:
yield
return
_max_age = self._get_max_age(cache_control_value)
reset_token = None
if _max_age:
reset_token = _CACHE_MAX_AGE.set(_max_age)
try:
yield
finally:
if reset_token:
_CACHE_MAX_AGE.reset(reset_token)

@staticmethod
def _to_disable(cache_control_value: Optional[str]) -> tuple[Command]:
if cache_control_value == _NO_CACHE:
return (Command.GET,)
if cache_control_value == _NO_STORE:
return Command.GET, Command.SET
return tuple()

@staticmethod
def _get_max_age(cache_control_value: str) -> int:
if not cache_control_value.startswith(_MAX_AGE):
return 0
try:
return int(cache_control_value.replace(_MAX_AGE, ""))
except (ValueError, TypeError):
return 0


class CacheEtagMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, cache_instance: Cache = cache):
self._cache = cache_instance
super().__init__(app)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
etag = request.headers.get(_IF_NOT_MATCH_HEADER)
if etag and await self._cache.exists(etag):
return Response(status_code=304)

with self._cache.detect as detector:
response = await call_next(request)
calls = detector.calls_list
if not calls:
return response

key, _ = calls[0]
response.headers[_ETAG_HEADER] = await self._get_etag(key)
return response

async def _get_etag(self, key: str) -> str:
data = await self._cache.get_raw(key)
expire = await self._cache.get_expire(key)
etag = blake2s(data).hexdigest()
await self._cache.set(etag, True, expire=expire)
return etag


class CacheDeleteMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
if request.headers.get(_CLEAR_CACHE_HEADER) == _CLEAR_CACHE_HEADER_VALUE:
with invalidate_further():
return await call_next(request)
return await call_next(request)
28 changes: 28 additions & 0 deletions cashews/contrib/prometheus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Optional

from prometheus_client import Histogram

from cashews import cache
from cashews._typing import Middleware
from cashews.backends.interface import Backend
from cashews.commands import Command


def create_metrics_middleware(latency_metric: Optional[Histogram] = None, with_tag: bool = False) -> Middleware:
_DEFAULT_METRIC = Histogram(
"cashews_operations_latency_seconds",
"Latency of different operations with a cache",
labels=["operation", "backend_class"] if not with_tag else ["operation", "backend_class", "tag"],
)
_latency_metric = latency_metric or _DEFAULT_METRIC

async def metrics_middleware(call, cmd: Command, backend: Backend, *args, **kwargs):
with _latency_metric as metric:
metric.labels(operation=cmd.value, backend_class=backend.__name__)
if with_tag and "key" in kwargs:
tags = cache.get_key_tags(kwargs["key"])
if tags:
metric.labels(tag=tags[0])
return await call(*args, **kwargs)

return metrics_middleware
10 changes: 6 additions & 4 deletions cashews/decorators/cache/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,23 @@ class CacheDetect:
__slots__ = ("_value", "_unset_token", "_previous_level")

def __init__(self, previous_level=0, unset_token=None):
self._value = {}
self._value = []
self._unset_token = unset_token
self._previous_level = previous_level

def _set(self, key: Key, **kwargs: Any) -> None:
self._value.setdefault(key, []).append(kwargs)
self._value.append((key, [kwargs]))

@property
def calls(self):
return dict(self._value)

keys = calls # backward compatibility
@property
def calls_list(self):
return self._value.copy()

def clear(self):
self._value = {}
self._value = []


_level = ContextVar("level", default=0)
Expand Down
6 changes: 3 additions & 3 deletions cashews/ttl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ def ttl_to_seconds(
) -> Union[int, None, float]:
if ttl is None:
return None
_type = type(ttl)
_type = type(ttl) # isinstance is slow
if _type == str:
return _ttl_from_str(ttl)
if _type == int:
return ttl
if _type == timedelta:
return ttl.total_seconds()
if _type == str:
return _ttl_from_str(ttl)
if callable(ttl) and with_callable:
try:
ttl = ttl(*args, result=result, **kwargs)
Expand Down
8 changes: 5 additions & 3 deletions cashews/wrapper/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .wrapper import Wrapper


def _thunder_protection(func: Callable) -> Callable:
def _skip_thunder_protection(func: Callable) -> Callable:
return func


Expand All @@ -36,12 +36,13 @@ def _decorator(func: AsyncCallable_T) -> AsyncCallable_T:
decor_kwargs["condition"] = condition

decorator = decorator_fabric(self, **decor_kwargs)(func)
thunder_protection = _thunder_protection
thunder_protection = _skip_thunder_protection
if protected:
thunder_protection = decorators.thunder_protection(key=decor_kwargs.get("key"))

@wraps(func)
async def _call(*args, **kwargs):
self._check_setup()
if self.is_full_disable:
return await func(*args, **kwargs)
if lock:
Expand Down Expand Up @@ -69,12 +70,13 @@ def _decorator(func: AsyncCallable_T) -> AsyncCallable_T:

@wraps(func)
async def _call(*args, **kwargs):
self._check_setup()
if self.is_full_disable:
return await func(*args, **kwargs)
with decorators.context_cache_detect as detect:

def new_condition(result, _args, _kwargs, key):
if detect.keys:
if detect.calls:
return False
return _condition(result, _args, _kwargs, key=key) if _condition else result is not None

Expand Down
11 changes: 9 additions & 2 deletions cashews/wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def _get_backend_and_config(self, key: Key) -> Tuple[Backend, Tuple[Middleware,
for prefix in sorted(self._backends.keys(), reverse=True):
if key.startswith(prefix):
return self._backends[prefix]
if self.default_prefix not in self._backends:
raise NotConfiguredError("run `cache.setup(...)` before using cache")
self._check_setup()
raise NotConfiguredError("Backend for given key not configured")

def _get_backend(self, key: Key) -> Backend:
backend, _ = self._get_backend_and_config(key)
Expand Down Expand Up @@ -62,6 +62,13 @@ def setup(self, settings_url: str, middlewares: Tuple = (), prefix: str = defaul
self._add_backend(backend, middlewares, prefix)
return backend

def is_setup(self) -> bool:
return bool(self._backends)

def _check_setup(self):
if not self._backends:
raise NotConfiguredError("run `cache.setup(...)` before using cache")

def _add_backend(self, backend: Backend, middlewares=(), prefix: str = default_prefix):
self._backends[prefix] = (
backend,
Expand Down
29 changes: 7 additions & 22 deletions examples/fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@
from fastapi import FastAPI, Header, Query
from fastapi.responses import StreamingResponse

from cashews import Command, cache
from cashews import cache
from cashews.contrib.fastapi import CacheDeleteMiddleware, CacheEtagMiddleware, CacheRequestControlMiddleware

app = FastAPI()
cache.setup(os.environ.get("CACHE_URI", "redis://?client_side=True"))
app.add_middleware(CacheDeleteMiddleware)
app.add_middleware(CacheEtagMiddleware)
app.add_middleware(CacheRequestControlMiddleware)
cache.setup(os.environ.get("CACHE_URI", "redis://"))
KB = 1024


@app.get("/")
@cache.failover(ttl="1h")
@cache.slice_rate_limit(10, "1m")
@cache.slice_rate_limit(10, "3m")
@cache(ttl="10m", key="simple:{user_agent}", time_condition="1s")
async def simple(user_agent: str = Header("No")):
await asyncio.sleep(1.1)
Expand Down Expand Up @@ -56,25 +60,6 @@ async def add_process_time_header(request, call_next):
return response


@app.middleware("http")
async def add_from_cache_headers(request, call_next):
with cache.detect as detector:
response = await call_next(request)
if request.method.lower() != "get":
return response
if detector.calls:
response.headers["X-From-Cache-keys"] = ";".join(detector.calls.keys())
return response


@app.middleware("http")
async def disable_middleware(request, call_next):
if request.headers.get("X-No-Cache"):
with cache.disabling(Command.GET):
return await call_next(request)
return await call_next(request)


if __name__ == "__main__":
import uvicorn

Expand Down
16 changes: 14 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,25 @@ async def _backend(request, redis_dsn, backend_factory):
from cashews.backends.redis import Redis

backend = backend_factory(
Redis, redis_dsn, hash_key=None, max_connections=20, suppress=False, socket_timeout=10
Redis,
redis_dsn,
hash_key=None,
max_connections=20,
suppress=False,
socket_timeout=10,
wait_for_connection_timeout=1,
)
elif request.param == "redis_hash":
from cashews.backends.redis import Redis

backend = backend_factory(
Redis, redis_dsn, hash_key=uuid4().hex, max_connections=20, suppress=False, socket_timeout=10
Redis,
redis_dsn,
hash_key=uuid4().hex,
max_connections=20,
suppress=False,
socket_timeout=10,
wait_for_connection_timeout=1,
)
elif request.param == "redis_cs":
from cashews.backends.redis.client_side import BcastClientSide
Expand Down
Loading

0 comments on commit 2b8405b

Please sign in to comment.