Skip to content

Commit 26ab964

Browse files
dvora-hchayim
andauthored
Fix bug: client side caching causes unexpected disconnections (async version) (#3165)
* fix disconnects * skip test in cluster * add test * save return value from handle_push_response (without it 'read_response' return the push message) * insert return response from cache to the try block to prevent connection leak * enable to get connection with data avaliable to read in csc mode and change can_read_destructive to not read data * fix check if socket is empty (at_eof() can return False but this doesn't mean there's definitely more data to read) --------- Co-authored-by: Chayim <chayim@users.noreply.github.com>
1 parent c573bc4 commit 26ab964

File tree

5 files changed

+77
-25
lines changed

5 files changed

+77
-25
lines changed

redis/_parsers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ async def can_read_destructive(self) -> bool:
182182
return True
183183
try:
184184
async with async_timeout(0):
185-
return await self._stream.read(1)
185+
return self._stream.at_eof()
186186
except TimeoutError:
187187
return False
188188

redis/_parsers/resp3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,9 @@ async def _read_response(
261261
)
262262
for _ in range(int(response))
263263
]
264-
await self.handle_push_response(response, disable_decoding, push_request)
264+
response = await self.handle_push_response(
265+
response, disable_decoding, push_request
266+
)
265267
else:
266268
raise InvalidResponse(f"Protocol Error: {raw!r}")
267269

redis/asyncio/client.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -629,25 +629,27 @@ async def execute_command(self, *args, **options):
629629
pool = self.connection_pool
630630
conn = self.connection or await pool.get_connection(command_name, **options)
631631
response_from_cache = await conn._get_from_local_cache(args)
632-
if response_from_cache is not None:
633-
return response_from_cache
634-
else:
635-
if self.single_connection_client:
636-
await self._single_conn_lock.acquire()
637-
try:
638-
response = await conn.retry.call_with_retry(
639-
lambda: self._send_command_parse_response(
640-
conn, command_name, *args, **options
641-
),
642-
lambda error: self._disconnect_raise(conn, error),
643-
)
644-
conn._add_to_local_cache(args, response, keys)
645-
return response
646-
finally:
647-
if self.single_connection_client:
648-
self._single_conn_lock.release()
649-
if not self.connection:
650-
await pool.release(conn)
632+
try:
633+
if response_from_cache is not None:
634+
return response_from_cache
635+
else:
636+
try:
637+
if self.single_connection_client:
638+
await self._single_conn_lock.acquire()
639+
response = await conn.retry.call_with_retry(
640+
lambda: self._send_command_parse_response(
641+
conn, command_name, *args, **options
642+
),
643+
lambda error: self._disconnect_raise(conn, error),
644+
)
645+
conn._add_to_local_cache(args, response, keys)
646+
return response
647+
finally:
648+
if self.single_connection_client:
649+
self._single_conn_lock.release()
650+
finally:
651+
if not self.connection:
652+
await pool.release(conn)
651653

652654
async def parse_response(
653655
self, connection: Connection, command_name: Union[str, bytes], **options

redis/asyncio/connection.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
685685

686686
def _socket_is_empty(self):
687687
"""Check if the socket is empty"""
688-
return not self._reader.at_eof()
688+
return len(self._reader._buffer) == 0
689689

690690
def _cache_invalidation_process(
691691
self, data: List[Union[str, Optional[List[str]]]]
@@ -1192,12 +1192,18 @@ def make_connection(self):
11921192
async def ensure_connection(self, connection: AbstractConnection):
11931193
"""Ensure that the connection object is connected and valid"""
11941194
await connection.connect()
1195-
# connections that the pool provides should be ready to send
1196-
# a command. if not, the connection was either returned to the
1195+
# if client caching is not enabled connections that the pool
1196+
# provides should be ready to send a command.
1197+
# if not, the connection was either returned to the
11971198
# pool before all data has been read or the socket has been
11981199
# closed. either way, reconnect and verify everything is good.
1200+
# (if caching enabled the connection will not always be ready
1201+
# to send a command because it may contain invalidation messages)
11991202
try:
1200-
if await connection.can_read_destructive():
1203+
if (
1204+
await connection.can_read_destructive()
1205+
and connection.client_cache is None
1206+
):
12011207
raise ConnectionError("Connection has data") from None
12021208
except (ConnectionError, OSError):
12031209
await connection.disconnect()

tests/test_asyncio/test_cache.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,48 @@ async def test_cache_return_copy(self, r):
142142
check = cache.get(("LRANGE", "mylist", 0, -1))
143143
assert check == [b"baz", b"bar", b"foo"]
144144

145+
@pytest.mark.onlynoncluster
146+
@pytest.mark.parametrize(
147+
"r",
148+
[{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
149+
indirect=True,
150+
)
151+
async def test_csc_not_cause_disconnects(self, r):
152+
r, cache = r
153+
id1 = await r.client_id()
154+
await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1})
155+
assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
156+
id2 = await r.client_id()
157+
158+
# client should get value from client cache
159+
assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
160+
assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [
161+
"1",
162+
"1",
163+
"1",
164+
"1",
165+
"1",
166+
]
167+
168+
await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2})
169+
id3 = await r.client_id()
170+
# client should get value from redis server post invalidate messages
171+
assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"]
172+
173+
await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3})
174+
# need to check that we get correct value 3 and not 2
175+
assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
176+
# client should get value from client cache
177+
assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
178+
179+
await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4})
180+
# need to check that we get correct value 4 and not 3
181+
assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
182+
# client should get value from client cache
183+
assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
184+
id4 = await r.client_id()
185+
assert id1 == id2 == id3 == id4
186+
145187

146188
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
147189
@pytest.mark.onlycluster

0 commit comments

Comments
 (0)