From c573bc4ab61d0d57726f872fdfca31962d44b534 Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:31:59 +0200 Subject: [PATCH] Fix bug: client side caching causes unexpected disconnections (#3160) * fix disconnects * skip test in cluster --------- Co-authored-by: Chayim --- redis/_parsers/resp3.py | 4 +++- redis/client.py | 14 +++++++------- redis/commands/core.py | 2 +- redis/connection.py | 17 +++++++--------- tests/test_cache.py | 43 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 19 deletions(-) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 13aa1ffccb..88c8d5e52b 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -117,7 +117,9 @@ def _read_response(self, disable_decoding=False, push_request=False): ) for _ in range(int(response)) ] - self.handle_push_response(response, disable_decoding, push_request) + response = self.handle_push_response( + response, disable_decoding, push_request + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") diff --git a/redis/client.py b/redis/client.py index 85ed7380a8..79f52cc989 100755 --- a/redis/client.py +++ b/redis/client.py @@ -563,10 +563,10 @@ def execute_command(self, *args, **options): pool = self.connection_pool conn = self.connection or pool.get_connection(command_name, **options) response_from_cache = conn._get_from_local_cache(args) - if response_from_cache is not None: - return response_from_cache - else: - try: + try: + if response_from_cache is not None: + return response_from_cache + else: response = conn.retry.call_with_retry( lambda: self._send_command_parse_response( conn, command_name, *args, **options @@ -575,9 +575,9 @@ def execute_command(self, *args, **options): ) conn._add_to_local_cache(args, response, keys) return response - finally: - if not self.connection: - pool.release(conn) + finally: + if not self.connection: + pool.release(conn) def parse_response(self, connection, command_name, **options): """Parses a response from the Redis server""" diff --git a/redis/commands/core.py b/redis/commands/core.py index 6d81d76035..464e8d8c85 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -2011,7 +2011,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: options = {} if not args: options[EMPTY_RESPONSE] = [] - options["keys"] = keys + options["keys"] = args return self.execute_command("MGET", *args, **options) def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: diff --git a/redis/connection.py b/redis/connection.py index 617d04af5c..b89ce0e94b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,6 +1,5 @@ import copy import os -import select import socket import ssl import sys @@ -609,11 +608,6 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def _socket_is_empty(self): - """Check if the socket is empty""" - r, _, _ = select.select([self._sock], [], [], 0) - return not bool(r) - def _cache_invalidation_process( self, data: List[Union[str, Optional[List[str]]]] ) -> None: @@ -639,7 +633,7 @@ def _get_from_local_cache(self, command: str): or command[0] not in self.cache_whitelist ): return None - while not self._socket_is_empty(): + while self.can_read(): self.read_response(push_request=True) return self.client_cache.get(command) @@ -1187,12 +1181,15 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection": try: # ensure this connection is connected to Redis connection.connect() - # connections that the pool provides should be ready to send - # a command. if not, the connection was either returned to the + # if client caching is not enabled connections that the pool + # provides should be ready to send a command. + # if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. + # (if caching enabled the connection will not always be ready + # to send a command because it may contain invalidation messages) try: - if connection.can_read(): + if connection.can_read() and connection.client_cache is None: raise ConnectionError("Connection has data") except (ConnectionError, OSError): connection.disconnect() diff --git a/tests/test_cache.py b/tests/test_cache.py index 4eb5160ecc..dd33afd23e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -146,6 +146,49 @@ def test_cache_return_copy(self, r): check = cache.get(("LRANGE", "mylist", 0, -1)) assert check == [b"baz", b"bar", b"foo"] + @pytest.mark.onlynoncluster + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_csc_not_cause_disconnects(self, r): + r, cache = r + id1 = r.client_id() + r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1}) + assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] + id2 = r.client_id() + + # client should get value from client cache + assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] + assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [ + "1", + "1", + "1", + "1", + "1", + "1", + ] + + r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2}) + id3 = r.client_id() + # client should get value from redis server post invalidate messages + assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"] + + r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3}) + # need to check that we get correct value 3 and not 2 + assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"] + # client should get value from client cache + assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"] + + r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4}) + # need to check that we get correct value 4 and not 3 + assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] + # client should get value from client cache + assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] + id4 = r.client_id() + assert id1 == id2 == id3 == id4 + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster