Skip to content

Fixed infinitely recursive health checks #3557

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:

async def connect(self):
"""Connects to the Redis server if not already connected"""
await self.connect_check_health(check_health=True)

async def connect_check_health(self, check_health: bool = True):
if self.is_connected:
return
try:
Expand All @@ -302,7 +305,7 @@ async def connect(self):
try:
if not self.redis_connect_func:
# Use the default on_connect function
await self.on_connect()
await self.on_connect_check_health(check_health=check_health)
else:
# Use the passed function redis_connect_func
(
Expand Down Expand Up @@ -341,6 +344,9 @@ def get_protocol(self):

async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
await self.on_connect_check_health(check_health=True)

async def on_connect_check_health(self, check_health: bool = True) -> None:
self._parser.on_connect(self)
parser = self._parser

Expand Down Expand Up @@ -398,7 +404,7 @@ async def on_connect(self) -> None:
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
await self.send_command("HELLO", self.protocol)
await self.send_command("HELLO", self.protocol, check_health=check_health)
response = await self.read_response()
# if response.get(b"proto") != self.protocol and response.get(
# "proto"
Expand All @@ -407,18 +413,35 @@ async def on_connect(self) -> None:

# if a client_name is given, set it
if self.client_name:
await self.send_command("CLIENT", "SETNAME", self.client_name)
await self.send_command(
"CLIENT",
"SETNAME",
self.client_name,
check_health=check_health,
)
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Error setting client name")

# set the library name and version, pipeline for lower startup latency
if self.lib_name:
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-NAME",
self.lib_name,
check_health=check_health,
)
if self.lib_version:
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-VER",
self.lib_version,
check_health=check_health,
)
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db)
await self.send_command("SELECT", self.db, check_health=check_health)

# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
Expand Down Expand Up @@ -480,8 +503,8 @@ async def send_packed_command(
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
) -> None:
if not self.is_connected:
await self.connect()
elif check_health:
await self.connect_check_health(check_health=False)
if check_health:
await self.check_health()

try:
Expand Down
37 changes: 30 additions & 7 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ def set_parser(self, parser_class):

def connect(self):
"Connects to the Redis server if not already connected"
self.connect_check_health(check_health=True)

def connect_check_health(self, check_health: bool = True):
if self._sock:
return
try:
Expand All @@ -387,7 +390,7 @@ def connect(self):
try:
if self.redis_connect_func is None:
# Use the default on_connect function
self.on_connect()
self.on_connect_check_health(check_health=check_health)
else:
# Use the passed function redis_connect_func
self.redis_connect_func(self)
Expand Down Expand Up @@ -417,6 +420,9 @@ def _error_message(self, exception):
return format_error_message(self._host_error(), exception)

def on_connect(self):
self.on_connect_check_health(check_health=True)

def on_connect_check_health(self, check_health: bool = True):
"Initialize the connection, authenticate and select a database"
self._parser.on_connect(self)
parser = self._parser
Expand Down Expand Up @@ -475,7 +481,7 @@ def on_connect(self):
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
self.send_command("HELLO", self.protocol)
self.send_command("HELLO", self.protocol, check_health=check_health)
self.handshake_metadata = self.read_response()
if (
self.handshake_metadata.get(b"proto") != self.protocol
Expand All @@ -485,28 +491,45 @@ def on_connect(self):

# if a client_name is given, set it
if self.client_name:
self.send_command("CLIENT", "SETNAME", self.client_name)
self.send_command(
"CLIENT",
"SETNAME",
self.client_name,
check_health=check_health,
)
if str_if_bytes(self.read_response()) != "OK":
raise ConnectionError("Error setting client name")

try:
# set the library name and version
if self.lib_name:
self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
self.send_command(
"CLIENT",
"SETINFO",
"LIB-NAME",
self.lib_name,
check_health=check_health,
)
self.read_response()
except ResponseError:
pass

try:
if self.lib_version:
self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
self.send_command(
"CLIENT",
"SETINFO",
"LIB-VER",
self.lib_version,
check_health=check_health,
)
self.read_response()
except ResponseError:
pass

# if a database is specified, switch to it
if self.db:
self.send_command("SELECT", self.db)
self.send_command("SELECT", self.db, check_health=check_health)
if str_if_bytes(self.read_response()) != "OK":
raise ConnectionError("Invalid Database")

Expand Down Expand Up @@ -548,7 +571,7 @@ def check_health(self):
def send_packed_command(self, command, check_health=True):
"""Send an already packed command to the Redis server"""
if not self._sock:
self.connect()
self.connect_check_health(check_health=False)
# guard against health check recursion
if check_health:
self.check_health()
Expand Down
Loading