Skip to content

Commit

Permalink
more tests for fastapi contrib
Browse files Browse the repository at this point in the history
  • Loading branch information
Dima Kryukov committed Dec 3, 2023
1 parent 409d60d commit d68e242
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 40 deletions.
4 changes: 3 additions & 1 deletion cashews/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import ContextManager

from .cache_condition import NOT_NONE, only_exceptions, with_exceptions
from .commands import Command
from .contrib import * # noqa
Expand All @@ -18,7 +20,7 @@
hit = cache.hit
transaction = cache.transaction
setup = cache.setup
cache_detect = cache.detect
cache_detect: ContextManager = cache.detect

circuit_breaker = cache.circuit_breaker
dynamic = cache.dynamic
Expand Down
47 changes: 31 additions & 16 deletions cashews/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from starlette.types import ASGIApp

from cashews import Cache, Command, cache, invalidate_further
from cashews._typing import TTL

_CACHE_MAX_AGE: ContextVar[int] = ContextVar("cache_control_max_age")

Expand All @@ -29,14 +30,14 @@
_PUBLIC = "public"

_CLEAR_CACHE_HEADER_VALUE = "cache"
__all__ = ["cache_control_ttl", "CacheRequestControlMiddleware", "CacheEtagMiddleware", "CacheDeleteMiddleware"]


def cache_control_condition():
pass
def cache_control_ttl(default: TTL):
def _ttl(*args, **kwargs):
return _CACHE_MAX_AGE.get(default)


def cache_control_ttl():
pass
return _ttl


class CacheRequestControlMiddleware(BaseHTTPMiddleware):
Expand All @@ -60,7 +61,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
if calls:
key, _ = calls[0]
expire = await self._cache.get_expire(key)
if expire:
if expire > 0:
response.headers[
_CACHE_CONTROL_HEADER
] = f"{_PRIVATE if self._private else _PUBLIC}, {_MAX_AGE}{expire}"
Expand All @@ -84,22 +85,28 @@ def max_age(self, cache_control_value: str | None):
if reset_token:
_CACHE_MAX_AGE.reset(reset_token)

@staticmethod
def _to_disable(cache_control_value: str | None) -> tuple[Command, ...]:
def _to_disable(self, cache_control_value: str | None) -> tuple[Command, ...]:
if cache_control_value == _NO_CACHE:
return (Command.GET,)
if cache_control_value == _NO_STORE:
return Command.GET, Command.SET
if cache_control_value and self._get_max_age(cache_control_value) == 0:
return Command.GET, Command.SET
return tuple()

@staticmethod
def _get_max_age(cache_control_value: str) -> int:
if not cache_control_value.startswith(_MAX_AGE):
return 0
try:
return int(cache_control_value.replace(_MAX_AGE, ""))
except (ValueError, TypeError):
return 0
def _get_max_age(cache_control_value: str) -> int | None:
if _MAX_AGE not in cache_control_value:
return None
for cc_value_item in cache_control_value.split(","):
cc_value_item = cc_value_item.strip()
try:
key, value = cc_value_item.split("=")
if key == _MAX_AGE[:-1]:
return int(value)
except (ValueError, TypeError):
continue
return None


class CacheEtagMiddleware(BaseHTTPMiddleware):
Expand All @@ -112,10 +119,18 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
if etag and await self._cache.exists(etag):
return Response(status_code=304)

with self._cache.detect as detector:
set_key = None

def set_callback(key, result):
nonlocal set_key
set_key = key

with self._cache.detect as detector, self._cache.callback(Command.SET, set_callback):
response = await call_next(request)
calls = detector.calls_list
if not calls:
if set_key:
response.headers[_ETAG_HEADER] = await self._get_etag(set_key)
return response

key, _ = calls[0]
Expand Down
4 changes: 3 additions & 1 deletion cashews/wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from cashews.backends.interface import _BackendInterface
from cashews.decorators import context_cache_detect

from .backend_settings import register_backend
from .backend_settings import register_backend # noqa
from .callback import CallbackWrapper
from .decorators import DecoratorsWrapper
from .disable_control import ControlWrapper
from .tags import CommandsTagsWrapper
Expand All @@ -17,6 +18,7 @@
class Cache(
TransactionWrapper,
ControlWrapper,
CallbackWrapper,
CommandsTagsWrapper,
DecoratorsWrapper,
_BackendInterface,
Expand Down
59 changes: 59 additions & 0 deletions cashews/wrapper/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import contextlib
import uuid
from typing import TYPE_CHECKING

from cashews._typing import AsyncCallable_T
from cashews.commands import PATTERN_CMDS, Command
from cashews.key import get_call_values

from .wrapper import Wrapper

if TYPE_CHECKING: # pragma: no cover
from cashews.backends.interface import Backend


class CallbackMiddleware:
def __init__(self):
self._callbacks = {}

async def __call__(self, call: AsyncCallable_T, cmd: Command, backend: "Backend", *args, **kwargs):
result = await call(*args, **kwargs)
as_key = "pattern" if cmd in PATTERN_CMDS else "key"
call_values = get_call_values(call, args, kwargs)
key = call_values.get(as_key)
for callback in self._callbacks.values():
callback(cmd, key=key, result=result, backend=backend)
return result

def add_callback(self, callback, name):
self._callbacks[name] = callback

def remove_callback(self, name):
del self._callbacks[name]

@contextlib.contextmanager
def callback(self, callback):
name = uuid.uuid4().hex
self.add_callback(callback, name)
try:
yield
finally:
self.remove_callback(name)


class CallbackWrapper(Wrapper):
def __init__(self, name: str = ""):
super().__init__(name)
self._callbacks = CallbackMiddleware()
self.add_middleware(self._callbacks)

@contextlib.contextmanager
def callback(self, cmd: Command, callback):
t_cmd = cmd

def _wrapped_callback(cmd, key, result, backend):
if cmd == t_cmd:
callback(key, result)

with self._callbacks.callback(_wrapped_callback):
yield
9 changes: 7 additions & 2 deletions examples/fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from fastapi.responses import StreamingResponse

from cashews import cache
from cashews.contrib.fastapi import CacheDeleteMiddleware, CacheEtagMiddleware, CacheRequestControlMiddleware
from cashews.contrib.fastapi import (
CacheDeleteMiddleware,
CacheEtagMiddleware,
CacheRequestControlMiddleware,
cache_control_ttl,
)

app = FastAPI()
app.add_middleware(CacheDeleteMiddleware)
Expand All @@ -21,7 +26,7 @@
@app.get("/")
@cache.failover(ttl="1h")
@cache.slice_rate_limit(10, "3m")
@cache(ttl="10m", key="simple:{user_agent}", time_condition="1s")
@cache(ttl=cache_control_ttl(default="4m"), key="simple:{user_agent}", time_condition="1s")
async def simple(user_agent: str = Header("No")):
await asyncio.sleep(1.1)
return "".join([random.choice(string.ascii_letters) for _ in range(10)])
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ lint =
types-redis
tests =
pytest
pytest-asyncio
pytest-asyncio==0.21.1
hypothesis

[flake8]
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def _backend(request, redis_dsn, backend_factory):
hash_key=None,
max_connections=20,
suppress=False,
socket_timeout=10,
socket_timeout=1,
wait_for_connection_timeout=1,
)
elif request.param == "redis_hash":
Expand Down
Loading

0 comments on commit d68e242

Please sign in to comment.