Skip to content

Commit

Permalink
Client side caching invalidations (standalone) (#3089)
Browse files Browse the repository at this point in the history
* cache invalidations

* isort

* deamon thread

* remove threads

* delete comment

* tests

* skip if hiredis available

* async

* review comments

* docstring

* decode test

* fix test

* fix decode response test
  • Loading branch information
dvora-h authored and vladvildanov committed Sep 27, 2024
1 parent d3b854d commit a9a9f70
Show file tree
Hide file tree
Showing 10 changed files with 457 additions and 57 deletions.
2 changes: 0 additions & 2 deletions redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from redis import asyncio # noqa
from redis.backoff import default_backoff
from redis.cache import _LocalChace
from redis.client import Redis, StrictRedis
from redis.cluster import RedisCluster
from redis.connection import (
Expand Down Expand Up @@ -62,7 +61,6 @@ def int_or_str(value):
VERSION = tuple([99, 99, 99])

__all__ = [
"_LocalChace",
"AuthenticationError",
"AuthenticationWrongNumberOfArgsError",
"BlockingConnectionPool",
Expand Down
66 changes: 44 additions & 22 deletions redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
from .base import _AsyncRESPBase, _RESPBase
from .socket import SERVER_CLOSED_CONNECTION_ERROR

_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]


class _RESP3Parser(_RESPBase):
"""RESP3 protocol implementation"""

def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.push_handler_func = self.handle_push_response
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidations_push_handler_func = None

def handle_push_response(self, response):
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.info("Push response: " + str(response))
return response
Expand Down Expand Up @@ -114,30 +117,40 @@ def _read_response(self, disable_decoding=False, push_request=False):
)
for _ in range(int(response))
]
res = self.push_handler_func(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
self.handle_push_response(response, disable_decoding, push_request)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func
def handle_push_response(self, response, disable_decoding, push_request):
if response[0] in _INVALIDATION_MESSAGE:
res = self.invalidation_push_handler_func(response)
else:
res = self.pubsub_push_handler_func(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res

def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidations_push_handler_func):
self.invalidation_push_handler_func = invalidations_push_handler_func


class _AsyncRESP3Parser(_AsyncRESPBase):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.push_handler_func = self.handle_push_response
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidations_push_handler_func = None

def handle_push_response(self, response):
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.info("Push response: " + str(response))
return response
Expand Down Expand Up @@ -246,19 +259,28 @@ async def _read_response(
)
for _ in range(int(response))
]
res = self.push_handler_func(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
await self.handle_push_response(response, disable_decoding, push_request)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func
async def handle_push_response(self, response, disable_decoding, push_request):
if response[0] in _INVALIDATION_MESSAGE:
res = self.invalidation_push_handler_func(response)
else:
res = self.pubsub_push_handler_func(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res

def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidations_push_handler_func):
self.invalidation_push_handler_func = invalidations_push_handler_func
131 changes: 113 additions & 18 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.cache import (
DEFAULT_BLACKLIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
_LocalCache,
)
from redis.client import (
EMPTY_RESPONSE,
NEVER_DECODE,
Expand All @@ -60,7 +66,7 @@
TimeoutError,
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT
from redis.utils import (
HIREDIS_AVAILABLE,
_set_info_logger,
Expand Down Expand Up @@ -231,6 +237,13 @@ def __init__(
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
cache_enable: bool = False,
client_cache: Optional[_LocalCache] = None,
cache_max_size: int = 100,
cache_ttl: int = 0,
cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
):
"""
Initialize a new Redis client.
Expand Down Expand Up @@ -336,6 +349,16 @@ def __init__(
# on a set of redis commands
self._single_conn_lock = asyncio.Lock()

self.client_cache = client_cache
if cache_enable:
self.client_cache = _LocalCache(
cache_max_size, cache_ttl, cache_eviction_policy
)
if self.client_cache is not None:
self.cache_blacklist = cache_blacklist
self.cache_whitelist = cache_whitelist
self.client_cache_initialized = False

def __repr__(self):
return f"{self.__class__.__name__}<{self.connection_pool!r}>"

Expand All @@ -347,6 +370,10 @@ async def initialize(self: _RedisT) -> _RedisT:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection("_")
if self.client_cache is not None:
self.connection._parser.set_invalidation_push_handler(
self._cache_invalidation_process
)
return self

def set_response_callback(self, command: str, callback: ResponseCallbackT):
Expand Down Expand Up @@ -565,6 +592,8 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
close_connection_pool is None and self.auto_close_connection_pool
):
await self.connection_pool.disconnect()
if self.client_cache:
self.client_cache.flush()

@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
Expand Down Expand Up @@ -593,29 +622,95 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
) -> None:
"""
Invalidate (delete) all redis commands associated with a specific key.
`data` is a list of strings, where the first string is the invalidation message
and the second string is the list of keys to invalidate.
(if the list of keys is None, then all keys are invalidated)
"""
if data[1] is not None:
for key in data[1]:
self.client_cache.invalidate(str_if_bytes(key))
else:
self.client_cache.flush()

async def _get_from_local_cache(self, command: str):
"""
If the command is in the local cache, return the response
"""
if (
self.client_cache is None
or command[0] in self.cache_blacklist
or command[0] not in self.cache_whitelist
):
return None
while not self.connection._is_socket_empty():
await self.connection.read_response(push_request=True)
return self.client_cache.get(command)

def _add_to_local_cache(
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
):
"""
Add the command and response to the local cache if the command
is allowed to be cached
"""
if (
self.client_cache is not None
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
):
self.client_cache.set(command, response, keys)

def delete_from_local_cache(self, command: str):
"""
Delete the command from the local cache
"""
try:
self.client_cache.delete(command)
except AttributeError:
pass

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
options.pop("keys", None) # the keys are used only for client side caching
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)
keys = options.pop("keys", None) # keys are used only for client side caching
response_from_cache = await self._get_from_local_cache(args)
if response_from_cache is not None:
return response_from_cache
else:
pool = self.connection_pool
conn = self.connection or await pool.get_connection(command_name, **options)

if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)
await self._single_conn_lock.acquire()
try:
if self.client_cache is not None and not self.client_cache_initialized:
await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, "CLIENT", *("CLIENT", "TRACKING", "ON")
),
lambda error: self._disconnect_raise(conn, error),
)
self.client_cache_initialized = True
response = await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
self._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down Expand Up @@ -863,7 +958,7 @@ async def connect(self):
else:
await self.connection.connect()
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

async def _disconnect_raise_connect(self, conn, error):
"""
Expand Down
4 changes: 4 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,10 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
output.append(SYM_EMPTY.join(pieces))
return output

def _is_socket_empty(self):
"""Check if the socket is empty"""
return not self._reader.at_eof()


class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"
Expand Down
18 changes: 10 additions & 8 deletions redis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class EvictionPolicy(Enum):
RANDOM = "random"


class _LocalChace:
class _LocalCache:
"""
A caching mechanism for storing redis commands and their responses.
Expand Down Expand Up @@ -220,6 +220,7 @@ def get(self, command: str) -> ResponseT:
if command in self.cache:
if self._is_expired(command):
self.delete(command)
return
self._update_access(command)
return self.cache[command]["response"]

Expand Down Expand Up @@ -266,28 +267,28 @@ def _update_access(self, command: str):
Args:
command (str): The redis command.
"""
if self.eviction_policy == EvictionPolicy.LRU:
if self.eviction_policy == EvictionPolicy.LRU.value:
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.LFU:
elif self.eviction_policy == EvictionPolicy.LFU.value:
self.cache[command]["access_count"] = (
self.cache.get(command, {}).get("access_count", 0) + 1
)
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.RANDOM:
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
pass # Random eviction doesn't require updates

def _evict(self):
"""Evict a redis command from the cache based on the eviction policy."""
if self._is_expired(self.commands_ttl_list[0]):
self.delete(self.commands_ttl_list[0])
elif self.eviction_policy == EvictionPolicy.LRU:
elif self.eviction_policy == EvictionPolicy.LRU.value:
self.cache.popitem(last=False)
elif self.eviction_policy == EvictionPolicy.LFU:
elif self.eviction_policy == EvictionPolicy.LFU.value:
min_access_command = min(
self.cache, key=lambda k: self.cache[k].get("access_count", 0)
)
self.cache.pop(min_access_command)
elif self.eviction_policy == EvictionPolicy.RANDOM:
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
random_command = random.choice(list(self.cache.keys()))
self.cache.pop(random_command)

Expand Down Expand Up @@ -322,5 +323,6 @@ def invalidate(self, key: KeyT):
"""
if key not in self.key_commands_map:
return
for command in self.key_commands_map[key]:
commands = list(self.key_commands_map[key])
for command in commands:
self.delete(command)
Loading

0 comments on commit a9a9f70

Please sign in to comment.