diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 9b5d0d8eb9..8e75e3e07f 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -284,6 +284,9 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None: async def connect(self): """Connects to the Redis server if not already connected""" + await self.connect_check_health(check_health=True) + + async def connect_check_health(self, check_health: bool = True): if self.is_connected: return try: @@ -302,7 +305,7 @@ async def connect(self): try: if not self.redis_connect_func: # Use the default on_connect function - await self.on_connect() + await self.on_connect_check_health(check_health=check_health) else: # Use the passed function redis_connect_func ( @@ -341,6 +344,9 @@ def get_protocol(self): async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" + await self.on_connect_check_health(check_health=True) + + async def on_connect_check_health(self, check_health: bool = True) -> None: self._parser.on_connect(self) parser = self._parser @@ -398,7 +404,7 @@ async def on_connect(self) -> None: # update cluster exception classes self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) - await self.send_command("HELLO", self.protocol) + await self.send_command("HELLO", self.protocol, check_health=check_health) response = await self.read_response() # if response.get(b"proto") != self.protocol and response.get( # "proto" @@ -407,18 +413,35 @@ async def on_connect(self) -> None: # if a client_name is given, set it if self.client_name: - await self.send_command("CLIENT", "SETNAME", self.client_name) + await self.send_command( + "CLIENT", + "SETNAME", + self.client_name, + check_health=check_health, + ) if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Error setting client name") # set the library name and version, pipeline for lower startup latency if self.lib_name: - await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + await self.send_command( + "CLIENT", + "SETINFO", + "LIB-NAME", + self.lib_name, + check_health=check_health, + ) if self.lib_version: - await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + await self.send_command( + "CLIENT", + "SETINFO", + "LIB-VER", + self.lib_version, + check_health=check_health, + ) # if a database is specified, switch to it. Also pipeline this if self.db: - await self.send_command("SELECT", self.db) + await self.send_command("SELECT", self.db, check_health=check_health) # read responses from pipeline for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): @@ -480,8 +503,8 @@ async def send_packed_command( self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True ) -> None: if not self.is_connected: - await self.connect() - elif check_health: + await self.connect_check_health(check_health=False) + if check_health: await self.check_health() try: diff --git a/redis/connection.py b/redis/connection.py index a298542c03..1e7d6ba2ec 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -372,6 +372,9 @@ def set_parser(self, parser_class): def connect(self): "Connects to the Redis server if not already connected" + self.connect_check_health(check_health=True) + + def connect_check_health(self, check_health: bool = True): if self._sock: return try: @@ -387,7 +390,7 @@ def connect(self): try: if self.redis_connect_func is None: # Use the default on_connect function - self.on_connect() + self.on_connect_check_health(check_health=check_health) else: # Use the passed function redis_connect_func self.redis_connect_func(self) @@ -417,6 +420,9 @@ def _error_message(self, exception): return format_error_message(self._host_error(), exception) def on_connect(self): + self.on_connect_check_health(check_health=True) + + def on_connect_check_health(self, check_health: bool = True): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) parser = self._parser @@ -475,7 +481,7 @@ def on_connect(self): # update cluster exception classes self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) - self.send_command("HELLO", self.protocol) + self.send_command("HELLO", self.protocol, check_health=check_health) self.handshake_metadata = self.read_response() if ( self.handshake_metadata.get(b"proto") != self.protocol @@ -485,28 +491,45 @@ def on_connect(self): # if a client_name is given, set it if self.client_name: - self.send_command("CLIENT", "SETNAME", self.client_name) + self.send_command( + "CLIENT", + "SETNAME", + self.client_name, + check_health=check_health, + ) if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Error setting client name") try: # set the library name and version if self.lib_name: - self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + self.send_command( + "CLIENT", + "SETINFO", + "LIB-NAME", + self.lib_name, + check_health=check_health, + ) self.read_response() except ResponseError: pass try: if self.lib_version: - self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + self.send_command( + "CLIENT", + "SETINFO", + "LIB-VER", + self.lib_version, + check_health=check_health, + ) self.read_response() except ResponseError: pass # if a database is specified, switch to it if self.db: - self.send_command("SELECT", self.db) + self.send_command("SELECT", self.db, check_health=check_health) if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Invalid Database") @@ -548,7 +571,7 @@ def check_health(self): def send_packed_command(self, command, check_health=True): """Send an already packed command to the Redis server""" if not self._sock: - self.connect() + self.connect_check_health(check_health=False) # guard against health check recursion if check_health: self.check_health()