Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally disable disconnects in read_response #2695

Merged
merged 6 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Revert #2104, #2673, add `disconnect_on_error` option to `read_response()` (issues #2506, #2624)
* Add `address_remap` parameter to `RedisCluster`
* Fix incorrect usage of once flag in async Sentinel
* asyncio: Fix memory leak caused by hiredis (#2693)
Expand Down
93 changes: 27 additions & 66 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,23 +500,6 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, args[0], *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect(nowait=True)
raise
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await self.connection_pool.release(conn)

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
Expand All @@ -527,10 +510,18 @@ async def execute_command(self, *args, **options):

if self.single_connection_client:
await self._single_conn_lock.acquire()

return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down Expand Up @@ -774,18 +765,10 @@ async def _disconnect_raise_connect(self, conn, error):
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()

if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()

async def _try_execute(self, conn, command, *arg, **kwargs):
try:
return await command(*arg, **kwargs)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def _execute(self, conn, command, *args, **kwargs):
"""
Connect manually upon disconnection. If the Redis server is down,
Expand All @@ -794,11 +777,9 @@ async def _execute(self, conn, command, *args, **kwargs):
called by the # connection to resubscribe us to any channels and
patterns we were previously listening to
"""
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: self._try_execute(conn, command, *args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)

async def parse_response(self, block: bool = True, timeout: float = 0):
Expand All @@ -816,7 +797,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
await conn.connect()

read_timeout = None if block else timeout
response = await self._execute(conn, conn.read_response, timeout=read_timeout)
response = await self._execute(
conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False
)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down Expand Up @@ -1200,18 +1183,6 @@ async def _disconnect_reset_raise(self, conn, error):
await self.reset()
raise

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, args[0], *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on a
Expand All @@ -1227,8 +1198,12 @@ async def immediate_execute_command(self, *args, **options):
command_name, self.shard_hint
)
self.connection = conn
return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)

return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)

def pipeline_execute_command(self, *args, **options):
Expand Down Expand Up @@ -1396,19 +1371,6 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
await self.reset()
raise

async def _try_execute(self, conn, execute, stack, raise_on_error):
try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
finally:
await self.reset()

async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
Expand All @@ -1430,11 +1392,10 @@ async def execute(self, raise_on_error: bool = True):
conn = cast(Connection, conn)

try:
return await asyncio.shield(
self._try_execute(conn, execute, stack, raise_on_error)
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
except RuntimeError:
await self.reset()
finally:
await self.reset()

Expand Down
33 changes: 9 additions & 24 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,33 +1016,12 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
await connection.send_packed_command(connection.pack_command(*args), False)

# Read response
return await asyncio.shield(
self._parse_and_release(connection, args[0], **kwargs)
)

async def _parse_and_release(self, connection, *args, **kwargs):
try:
return await self.parse_response(connection, *args, **kwargs)
except asyncio.CancelledError:
# should not be possible
await connection.disconnect(nowait=True)
raise
return await self.parse_response(connection, args[0], **kwargs)
finally:
# Release connection
self._free.append(connection)

async def _try_parse_response(self, cmd, connection, ret):
try:
cmd.result = await asyncio.shield(
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
)
except asyncio.CancelledError:
await connection.disconnect(nowait=True)
raise
except Exception as e:
cmd.result = e
ret = True
return ret

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
Expand All @@ -1055,7 +1034,13 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Read responses
ret = False
for cmd in commands:
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))
try:
cmd.result = await self.parse_response(
connection, cmd.args[0], **cmd.kwargs
)
except Exception as e:
cmd.result = e
ret = True

# Release connection
self._free.append(connection)
Expand Down
28 changes: 18 additions & 10 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,11 @@ async def send_packed_command(
raise ConnectionError(
f"Error {err_no} while writing to socket. {errmsg}."
) from e
except Exception:
except BaseException:
# BaseExceptions can be raised when a socket send operation is not
# finished, e.g. due to a timeout. Ideally, a caller could then re-try
# to send un-sent data. However, the send_packed_command() API
# does not support it so there is no point in keeping the connection open.
await self.disconnect(nowait=True)
raise

Expand All @@ -828,6 +832,8 @@ async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
*,
disconnect_on_error: bool = True,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
Expand All @@ -843,22 +849,24 @@ async def read_response(
)
except asyncio.TimeoutError:
if timeout is not None:
# user requested timeout, return None
# user requested timeout, return None. Operation can be retried
return None
# it was a self.socket_timeout error.
await self.disconnect(nowait=True)
if disconnect_on_error:
await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
await self.disconnect(nowait=True)
if disconnect_on_error:
await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except asyncio.CancelledError:
# need this check for 3.7, where CancelledError
# is subclass of Exception, not BaseException
raise
except Exception:
await self.disconnect(nowait=True)
except BaseException:
# Also by default close in case of BaseException. A lot of code
# relies on this behaviour when doing Command/Response pairs.
# See #1128.
if disconnect_on_error:
await self.disconnect(nowait=True)
raise

if self.health_check_interval:
Expand Down
2 changes: 1 addition & 1 deletion redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ def try_read():
return None
else:
conn.connect()
return conn.read_response()
return conn.read_response(disconnect_on_error=False)

response = self._execute(conn, try_read)

Expand Down
24 changes: 18 additions & 6 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,11 @@ def send_packed_command(self, command, check_health=True):
errno = e.args[0]
errmsg = e.args[1]
raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
except Exception:
except BaseException:
# BaseExceptions can be raised when a socket send operation is not
# finished, e.g. due to a timeout. Ideally, a caller could then re-try
# to send un-sent data. However, the send_packed_command() API
# does not support it so there is no point in keeping the connection open.
self.disconnect()
raise

Expand All @@ -859,23 +863,31 @@ def can_read(self, timeout=0):
self.disconnect()
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

def read_response(self, disable_decoding=False):
def read_response(
self, disable_decoding=False, *, disconnect_on_error: bool = True
):
"""Read the response from a previously sent command"""

host_error = self._host_error()

try:
response = self._parser.read_response(disable_decoding=disable_decoding)
except socket.timeout:
self.disconnect()
if disconnect_on_error:
self.disconnect()
raise TimeoutError(f"Timeout reading from {host_error}")
except OSError as e:
self.disconnect()
if disconnect_on_error:
self.disconnect()
raise ConnectionError(
f"Error while reading from {host_error}" f" : {e.args}"
)
except Exception:
self.disconnect()
except BaseException:
# Also by default close in case of BaseException. A lot of code
# relies on this behaviour when doing Command/Response pairs.
# See #1128.
if disconnect_on_error:
self.disconnect()
raise

if self.health_check_interval:
Expand Down
38 changes: 38 additions & 0 deletions tests/test_asyncio/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
Tests async overrides of commands from their mixins
"""
import asyncio
import binascii
import datetime
import re
import sys
from string import ascii_letters

import pytest
Expand All @@ -18,6 +20,11 @@
skip_unless_arch_bits,
)

if sys.version_info >= (3, 11, 3):
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout

REDIS_6_VERSION = "5.9.0"


Expand Down Expand Up @@ -3008,6 +3015,37 @@ async def test_module_list(self, r: redis.Redis):
for x in await r.module_list():
assert isinstance(x, dict)

@pytest.mark.onlynoncluster
async def test_interrupted_command(self, r: redis.Redis):
"""
Regression test for issue #1128: An Un-handled BaseException
will leave the socket with un-read response to a previous
command.
"""
ready = asyncio.Event()

async def helper():
with pytest.raises(asyncio.CancelledError):
# blocking pop
ready.set()
await r.brpop(["nonexist"])
# If the following is not done, further Timout operations will fail,
# because the timeout won't catch its Cancelled Error if the task
# has a pending cancel. Python documentation probably should reflect this.
if sys.version_info >= (3, 11):
asyncio.current_task().uncancel()
# if all is well, we can continue. The following should not hang.
await r.set("status", "down")

task = asyncio.create_task(helper())
await ready.wait()
await asyncio.sleep(0.01)
# the task is now sleeping, lets send it an exception
task.cancel()
# If all is well, the task should finish right away, otherwise fail with Timeout
async with async_timeout(0.1):
await task


@pytest.mark.onlynoncluster
class TestBinarySave:
Expand Down
Loading