From 59e22cc9c49740bcc71901400d7ad082ea7aadc7 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 2 Oct 2025 21:01:51 +0300 Subject: [PATCH 1/8] Extracting maintenance push notifications logic in separate abstract classes. Usingle multiple inheritance for Connection and ConnectionPool classes to expose the maintenance push notifications for the existing connections. --- redis/connection.py | 1679 +++++++++++--------- redis/maint_notifications.py | 10 +- tests/test_maint_notifications.py | 8 +- tests/test_maint_notifications_handling.py | 77 +- 4 files changed, 1028 insertions(+), 746 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 837fccd40e..26b696116e 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -5,7 +5,7 @@ import threading import time import weakref -from abc import abstractmethod +from abc import ABC, abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue from typing import ( @@ -178,10 +178,6 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): pass - @abstractmethod - def set_maint_notifications_pool_handler(self, maint_notifications_pool_handler): - pass - @abstractmethod def get_protocol(self): pass @@ -245,82 +241,373 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass - @property - @abstractmethod - def maintenance_state(self) -> MaintenanceState: + +class MaintNotificationsAbstractConnection: + """ + Abstract class for handling maintenance notifications logic. + This class is expected to be used as base class together with ConnectionInterface. + + This class is intended to be used with multiple inheritance! + + All logic related to maintenance notifications is encapsulated in this class. + """ + + def __init__( + self, + maint_notifications_config: Optional[MaintNotificationsConfig], + maint_notifications_pool_handler: Optional[ + MaintNotificationsPoolHandler + ] = None, + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, + maintenance_notification_hash: Optional[int] = None, + orig_host_address: Optional[str] = None, + orig_socket_timeout: Optional[float] = None, + orig_socket_connect_timeout: Optional[float] = None, + parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None, + ): """ - Returns the current maintenance state of the connection. + Initialize the maintenance notifications for the connection. + + Args: + maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications. + maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications. + maintenance_state (MaintenanceState): The current maintenance state of the connection. + maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection. + orig_host_address (Optional[str]): The original host address of the connection. + orig_socket_timeout (Optional[float]): The original socket timeout of the connection. + orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection. + parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications. + If not provided, the parser from the connection is used. + This is useful when the parser is created after this object. """ + self.maint_notifications_config = maint_notifications_config + self.maint_notifications_pool_handler = maint_notifications_pool_handler + self.maintenance_state = maintenance_state + self.maintenance_notification_hash = maintenance_notification_hash + self._configure_maintenance_notifications( + self.maint_notifications_pool_handler, + orig_host_address, + orig_socket_timeout, + orig_socket_connect_timeout, + parser, + ) + self._should_reconnect = False + + @abstractmethod + def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]: pass - @maintenance_state.setter @abstractmethod - def maintenance_state(self, state: "MaintenanceState"): - """ - Sets the current maintenance state of the connection. - """ + def _get_socket(self) -> Optional[socket.socket]: pass @abstractmethod - def getpeername(self): + def get_protocol(self) -> Union[int, str]: """ - Returns the peer name of the connection. + Returns: + The RESP protocol version, or ``None`` if the protocol is not specified, + in which case the server default will be used. """ pass + @property @abstractmethod - def mark_for_reconnect(self): - """ - Mark the connection to be reconnected on the next command. - This is useful when a connection is moved to a different node. - """ + def host(self) -> str: pass + @host.setter @abstractmethod - def should_reconnect(self): - """ - Returns True if the connection should be reconnected. - """ + def host(self, value: str): + pass + + @property + @abstractmethod + def socket_timeout(self) -> Optional[Union[float, int]]: pass + @socket_timeout.setter @abstractmethod - def get_resolved_ip(self): - """ - Get resolved ip address for the connection. - """ + def socket_timeout(self, value: Optional[Union[float, int]]): pass + @property @abstractmethod - def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): - """ - Update the timeout for the current socket. - """ + def socket_connect_timeout(self) -> Optional[Union[float, int]]: pass + @socket_connect_timeout.setter @abstractmethod + def socket_connect_timeout(self, value: Optional[Union[float, int]]): + pass + + @abstractmethod + def send_command(self, *args, **kwargs): + pass + + @abstractmethod + def read_response( + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, + ): + pass + + @abstractmethod + def disconnect(self, *args): + pass + + def _configure_maintenance_notifications( + self, + maint_notifications_pool_handler=None, + orig_host_address=None, + orig_socket_timeout=None, + orig_socket_connect_timeout=None, + parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None, + ): + """ + Enable maintenance notifications by setting up + handlers and storing original connection parameters. + + Should be used ONLY parsers that support push notifications. + """ + if ( + not self.maint_notifications_config + or not self.maint_notifications_config.enabled + ): + self._maint_notifications_connection_handler = None + return + + if not parser: + raise RedisError( + "To configure maintenance notifications, a parser must be provided." + ) + + # Set up pool handler if available + if maint_notifications_pool_handler: + parser.set_node_moving_push_handler( + maint_notifications_pool_handler.handle_notification + ) + + # Set up connection handler + self._maint_notifications_connection_handler = ( + MaintNotificationsConnectionHandler(self, self.maint_notifications_config) + ) + parser.set_maintenance_push_handler( + self._maint_notifications_connection_handler.handle_notification + ) + + # Store original connection parameters + self.orig_host_address = orig_host_address if orig_host_address else self.host + self.orig_socket_timeout = ( + orig_socket_timeout if orig_socket_timeout else self.socket_timeout + ) + self.orig_socket_connect_timeout = ( + orig_socket_connect_timeout + if orig_socket_connect_timeout + else self.socket_connect_timeout + ) + + def set_maint_notifications_pool_handler( + self, maint_notifications_pool_handler: MaintNotificationsPoolHandler + ): + maint_notifications_pool_handler.set_connection(self) + self._get_parser().set_node_moving_push_handler( + maint_notifications_pool_handler.handle_notification + ) + + # Update maintenance notification connection handler if it doesn't exist + if not self._maint_notifications_connection_handler: + self._maint_notifications_connection_handler = ( + MaintNotificationsConnectionHandler( + self, maint_notifications_pool_handler.config + ) + ) + self._get_parser().set_maintenance_push_handler( + self._maint_notifications_connection_handler.handle_notification + ) + else: + self._maint_notifications_connection_handler.config = ( + maint_notifications_pool_handler.config + ) + + def activate_maint_notifications_handling_if_enabled(self, check_health=True): + # Send maintenance notifications handshake if RESP3 is active + # and maintenance notifications are enabled + # and we have a host to determine the endpoint type from + # When the maint_notifications_config enabled mode is "auto", + # we just log a warning if the handshake fails + # When the mode is enabled=True, we raise an exception in case of failure + if ( + self.get_protocol() not in [2, "2"] + and self.maint_notifications_config + and self.maint_notifications_config.enabled + and self._maint_notifications_connection_handler + and hasattr(self, "host") + ): + self._enable_maintenance_notifications( + maint_notifications_config=self.maint_notifications_config, + check_health=check_health, + ) + + def _enable_maintenance_notifications( + self, maint_notifications_config: MaintNotificationsConfig, check_health=True + ): + try: + host = getattr(self, "host", None) + if host is None: + raise ValueError( + "Cannot enable maintenance notifications for connection" + " object that doesn't have a host attribute." + ) + else: + endpoint_type = maint_notifications_config.get_endpoint_type(host, self) + self.send_command( + "CLIENT", + "MAINT_NOTIFICATIONS", + "ON", + "moving-endpoint-type", + endpoint_type.value, + check_health=check_health, + ) + response = self.read_response() + if not response or str_if_bytes(response) != "OK": + raise ResponseError( + "The server doesn't support maintenance notifications" + ) + except Exception as e: + if ( + isinstance(e, ResponseError) + and maint_notifications_config.enabled == "auto" + ): + # Log warning but don't fail the connection + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Failed to enable maintenance notifications: {e}") + else: + raise + + def get_resolved_ip(self) -> Optional[str]: + """ + Extract the resolved IP address from an + established connection or resolve it from the host. + + First tries to get the actual IP from the socket (most accurate), + then falls back to DNS resolution if needed. + + Args: + connection: The connection object to extract the IP from + + Returns: + str: The resolved IP address, or None if it cannot be determined + """ + + # Method 1: Try to get the actual IP from the established socket connection + # This is most accurate as it shows the exact IP being used + try: + conn_socket = self._get_socket() + if conn_socket is not None: + peer_addr = conn_socket.getpeername() + if peer_addr and len(peer_addr) >= 1: + # For TCP sockets, peer_addr is typically (host, port) tuple + # Return just the host part + return peer_addr[0] + except (AttributeError, OSError): + # Socket might not be connected or getpeername() might fail + pass + + # Method 2: Fallback to DNS resolution of the host + # This is less accurate but works when socket is not available + try: + host = getattr(self, "host", "localhost") + port = getattr(self, "port", 6379) + if host: + # Use getaddrinfo to resolve the hostname to IP + # This mimics what the connection would do during _connect() + addr_info = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) + if addr_info: + # Return the IP from the first result + # addr_info[0] is (family, socktype, proto, canonname, sockaddr) + # sockaddr[0] is the IP address + return str(addr_info[0][4][0]) + except (AttributeError, OSError, socket.gaierror): + # DNS resolution might fail + pass + + return None + + @property + def maintenance_state(self) -> MaintenanceState: + return self._maintenance_state + + @maintenance_state.setter + def maintenance_state(self, state: "MaintenanceState"): + self._maintenance_state = state + + def getpeername(self): + """ + Returns the peer name of the connection. + """ + conn_socket = self._get_socket() + if conn_socket: + return conn_socket.getpeername()[0] + return None + + def mark_for_reconnect(self): + self._should_reconnect = True + + def should_reconnect(self): + return self._should_reconnect + + def reset_should_reconnect(self): + self._should_reconnect = False + + def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): + conn_socket = self._get_socket() + if conn_socket: + timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout + conn_socket.settimeout(timeout) + self.update_parser_timeout(timeout) + + def update_parser_timeout(self, timeout: Optional[float] = None): + parser = self._get_parser() + if parser and parser._buffer: + if isinstance(parser, _RESP3Parser) and timeout: + parser._buffer.socket_timeout = timeout + elif isinstance(parser, _HiredisParser): + parser._socket_timeout = timeout + def set_tmp_settings( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: Optional[Union[str, object]] = SENTINEL, tmp_relaxed_timeout: Optional[float] = None, ): """ - Updates temporary host address and timeout settings for the connection. + The value of SENTINEL is used to indicate that the property should not be updated. """ - pass + if tmp_host_address and tmp_host_address != SENTINEL: + self.host = str(tmp_host_address) + if tmp_relaxed_timeout != -1: + self.socket_timeout = tmp_relaxed_timeout + self.socket_connect_timeout = tmp_relaxed_timeout - @abstractmethod def reset_tmp_settings( self, reset_host_address: bool = False, reset_relaxed_timeout: bool = False, ): - """ - Resets temporary host address and timeout settings for the connection. - """ - pass + if reset_host_address: + self.host = self.orig_host_address + if reset_relaxed_timeout: + self.socket_timeout = self.orig_socket_timeout + self.socket_connect_timeout = self.orig_socket_connect_timeout -class AbstractConnection(ConnectionInterface): +class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface): "Manages communication to and from a Redis server" def __init__( @@ -347,10 +634,10 @@ def __init__( protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, event_dispatcher: Optional[EventDispatcher] = None, + maint_notifications_config: Optional[MaintNotificationsConfig] = None, maint_notifications_pool_handler: Optional[ MaintNotificationsPoolHandler ] = None, - maint_notifications_config: Optional[MaintNotificationsConfig] = None, maintenance_state: "MaintenanceState" = MaintenanceState.NONE, maintenance_notification_hash: Optional[int] = None, orig_host_address: Optional[str] = None, @@ -383,10 +670,10 @@ def __init__( self.credential_provider = credential_provider self.password = password self.username = username - self.socket_timeout = socket_timeout + self._socket_timeout = socket_timeout if socket_connect_timeout is None: socket_connect_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout + self._socket_connect_timeout = socket_connect_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: retry_on_errors_list = [] @@ -432,21 +719,20 @@ def __init__( parser_class = _RESP3Parser self.set_parser(parser_class) - self.maint_notifications_config = maint_notifications_config + self._command_packer = self._construct_command_packer(command_packer) - # Set up maintenance notifications if enabled - self._configure_maintenance_notifications( + # Set up maintenance notifications + MaintNotificationsAbstractConnection.__init__( + self, + maint_notifications_config, maint_notifications_pool_handler, - orig_host_address, - orig_socket_timeout, - orig_socket_connect_timeout, - ) - - self._should_reconnect = False - self.maintenance_state = maintenance_state - self.maintenance_notification_hash = maintenance_notification_hash - - self._command_packer = self._construct_command_packer(command_packer) + maintenance_state, + maintenance_notification_hash, + orig_host_address, + orig_socket_timeout, + orig_socket_connect_timeout, + self._parser, + ) def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -501,68 +787,8 @@ def set_parser(self, parser_class): """ self._parser = parser_class(socket_read_size=self._socket_read_size) - def _configure_maintenance_notifications( - self, - maint_notifications_pool_handler=None, - orig_host_address=None, - orig_socket_timeout=None, - orig_socket_connect_timeout=None, - ): - """Enable maintenance notifications by setting up handlers and storing original connection parameters.""" - if ( - not self.maint_notifications_config - or not self.maint_notifications_config.enabled - ): - self._maint_notifications_connection_handler = None - return - - # Set up pool handler if available - if maint_notifications_pool_handler: - self._parser.set_node_moving_push_handler( - maint_notifications_pool_handler.handle_notification - ) - - # Set up connection handler - self._maint_notifications_connection_handler = ( - MaintNotificationsConnectionHandler(self, self.maint_notifications_config) - ) - self._parser.set_maintenance_push_handler( - self._maint_notifications_connection_handler.handle_notification - ) - - # Store original connection parameters - self.orig_host_address = orig_host_address if orig_host_address else self.host - self.orig_socket_timeout = ( - orig_socket_timeout if orig_socket_timeout else self.socket_timeout - ) - self.orig_socket_connect_timeout = ( - orig_socket_connect_timeout - if orig_socket_connect_timeout - else self.socket_connect_timeout - ) - - def set_maint_notifications_pool_handler( - self, maint_notifications_pool_handler: MaintNotificationsPoolHandler - ): - maint_notifications_pool_handler.set_connection(self) - self._parser.set_node_moving_push_handler( - maint_notifications_pool_handler.handle_notification - ) - - # Update maintenance notification connection handler if it doesn't exist - if not self._maint_notifications_connection_handler: - self._maint_notifications_connection_handler = ( - MaintNotificationsConnectionHandler( - self, maint_notifications_pool_handler.config - ) - ) - self._parser.set_maintenance_push_handler( - self._maint_notifications_connection_handler.handle_notification - ) - else: - self._maint_notifications_connection_handler.config = ( - maint_notifications_pool_handler.config - ) + def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]: + return self._parser def connect(self): "Connects to the Redis server if not already connected" @@ -688,48 +914,10 @@ def on_connect_check_health(self, check_health: bool = True): ): raise ConnectionError("Invalid RESP version") - # Send maintenance notifications handshake if RESP3 is active - # and maintenance notifications are enabled - # and we have a host to determine the endpoint type from - # When the maint_notifications_config enabled mode is "auto", - # we just log a warning if the handshake fails - # When the mode is enabled=True, we raise an exception in case of failure - if ( - self.protocol not in [2, "2"] - and self.maint_notifications_config - and self.maint_notifications_config.enabled - and self._maint_notifications_connection_handler - and hasattr(self, "host") - ): - try: - endpoint_type = self.maint_notifications_config.get_endpoint_type( - self.host, self - ) - self.send_command( - "CLIENT", - "MAINT_NOTIFICATIONS", - "ON", - "moving-endpoint-type", - endpoint_type.value, - check_health=check_health, - ) - response = self.read_response() - if str_if_bytes(response) != "OK": - raise ResponseError( - "The server doesn't support maintenance notifications" - ) - except Exception as e: - if ( - isinstance(e, ResponseError) - and self.maint_notifications_config.enabled == "auto" - ): - # Log warning but don't fail the connection - import logging - - logger = logging.getLogger(__name__) - logger.warning(f"Failed to enable maintenance notifications: {e}") - else: - raise + # Activate maintenance notifications for this connection + # if enabled in the configuration + # This is a no-op if maintenance notifications are not enabled + self.activate_maint_notifications_handling_if_enabled(check_health=check_health) # if a client_name is given, set it if self.client_name: @@ -782,7 +970,7 @@ def disconnect(self, *args): conn_sock = self._sock self._sock = None # reset the reconnect flag - self._should_reconnect = False + self.reset_should_reconnect() if conn_sock is None: return @@ -967,109 +1155,24 @@ def re_auth(self): self.read_response() self._re_auth_token = None - def get_resolved_ip(self) -> Optional[str]: - """ - Extract the resolved IP address from an - established connection or resolve it from the host. - - First tries to get the actual IP from the socket (most accurate), - then falls back to DNS resolution if needed. - - Args: - connection: The connection object to extract the IP from - - Returns: - str: The resolved IP address, or None if it cannot be determined - """ - - # Method 1: Try to get the actual IP from the established socket connection - # This is most accurate as it shows the exact IP being used - try: - if self._sock is not None: - peer_addr = self._sock.getpeername() - if peer_addr and len(peer_addr) >= 1: - # For TCP sockets, peer_addr is typically (host, port) tuple - # Return just the host part - return peer_addr[0] - except (AttributeError, OSError): - # Socket might not be connected or getpeername() might fail - pass - - # Method 2: Fallback to DNS resolution of the host - # This is less accurate but works when socket is not available - try: - host = getattr(self, "host", "localhost") - port = getattr(self, "port", 6379) - if host: - # Use getaddrinfo to resolve the hostname to IP - # This mimics what the connection would do during _connect() - addr_info = socket.getaddrinfo( - host, port, socket.AF_UNSPEC, socket.SOCK_STREAM - ) - if addr_info: - # Return the IP from the first result - # addr_info[0] is (family, socktype, proto, canonname, sockaddr) - # sockaddr[0] is the IP address - return addr_info[0][4][0] - except (AttributeError, OSError, socket.gaierror): - # DNS resolution might fail - pass - - return None + def _get_socket(self) -> Optional[socket.socket]: + return self._sock @property - def maintenance_state(self) -> MaintenanceState: - return self._maintenance_state - - @maintenance_state.setter - def maintenance_state(self, state: "MaintenanceState"): - self._maintenance_state = state + def socket_timeout(self) -> Optional[Union[float, int]]: + return self._socket_timeout - def getpeername(self): - if not self._sock: - return None - return self._sock.getpeername()[0] - - def mark_for_reconnect(self): - self._should_reconnect = True - - def should_reconnect(self): - return self._should_reconnect + @socket_timeout.setter + def socket_timeout(self, value: Optional[Union[float, int]]): + self._socket_timeout = value - def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): - if self._sock: - timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout - self._sock.settimeout(timeout) - self.update_parser_buffer_timeout(timeout) - - def update_parser_buffer_timeout(self, timeout: Optional[float] = None): - if self._parser and self._parser._buffer: - self._parser._buffer.socket_timeout = timeout - - def set_tmp_settings( - self, - tmp_host_address: Optional[Union[str, object]] = SENTINEL, - tmp_relaxed_timeout: Optional[float] = None, - ): - """ - The value of SENTINEL is used to indicate that the property should not be updated. - """ - if tmp_host_address is not SENTINEL: - self.host = tmp_host_address - if tmp_relaxed_timeout != -1: - self.socket_timeout = tmp_relaxed_timeout - self.socket_connect_timeout = tmp_relaxed_timeout + @property + def socket_connect_timeout(self) -> Optional[Union[float, int]]: + return self._socket_connect_timeout - def reset_tmp_settings( - self, - reset_host_address: bool = False, - reset_relaxed_timeout: bool = False, - ): - if reset_host_address: - self.host = self.orig_host_address - if reset_relaxed_timeout: - self.socket_timeout = self.orig_socket_timeout - self.socket_connect_timeout = self.orig_socket_connect_timeout + @socket_connect_timeout.setter + def socket_connect_timeout(self, value: Optional[Union[float, int]]): + self._socket_connect_timeout = value class Connection(AbstractConnection): @@ -1084,7 +1187,7 @@ def __init__( socket_type=0, **kwargs, ): - self.host = host + self._host = host self.port = int(port) self.socket_keepalive = socket_keepalive self.socket_keepalive_options = socket_keepalive_options or {} @@ -1146,8 +1249,16 @@ def _connect(self): def _host_error(self): return f"{self.host}:{self.port}" + @property + def host(self) -> str: + return self._host + + @host.setter + def host(self, value: str): + self._host = value + -class CacheProxyConnection(ConnectionInterface): +class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface): DUMMY_CACHE_VALUE = b"foo" MIN_ALLOWED_VERSION = "7.4.0" DEFAULT_SERVER_NAME = "redis" @@ -1171,6 +1282,19 @@ def __init__( self._current_options = None self.register_connect_callback(self._enable_tracking_callback) + if isinstance(self._conn, MaintNotificationsAbstractConnection): + MaintNotificationsAbstractConnection.__init__( + self, + self._conn.maint_notifications_config, + self._conn.maint_notifications_pool_handler, + self._conn.maintenance_state, + self._conn.maintenance_notification_hash, + self._conn.host, + self._conn.socket_timeout, + self._conn.socket_connect_timeout, + self._conn._get_parser(), + ) + def repr_pieces(self): return self._conn.repr_pieces() @@ -1183,6 +1307,15 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): self._conn.set_parser(parser_class) + def set_maint_notifications_pool_handler(self, maint_notifications_pool_handler): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.set_maint_notifications_pool_handler( + maint_notifications_pool_handler + ) + + def get_protocol(self): + return self._conn.get_protocol() + def connect(self): self._conn.connect() @@ -1328,6 +1461,134 @@ def pack_commands(self, commands): def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: return self._conn.handshake_metadata + def set_re_auth_token(self, token: TokenInterface): + self._conn.set_re_auth_token(token) + + def re_auth(self): + self._conn.re_auth() + + @property + def host(self) -> str: + return self._conn.host + + @host.setter + def host(self, value: str): + self._conn.host = value + + @property + def socket_timeout(self) -> Optional[Union[float, int]]: + return self._conn.socket_timeout + + @socket_timeout.setter + def socket_timeout(self, value: Optional[Union[float, int]]): + self._conn.socket_timeout = value + + @property + def socket_connect_timeout(self) -> Optional[Union[float, int]]: + return self._conn.socket_connect_timeout + + @socket_connect_timeout.setter + def socket_connect_timeout(self, value: Optional[Union[float, int]]): + self._conn.socket_connect_timeout = value + + def _get_socket(self) -> Optional[socket.socket]: + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn._get_socket() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + @property + def maintenance_state(self) -> MaintenanceState: + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn.maintenance_state + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + @maintenance_state.setter + def maintenance_state(self, state: MaintenanceState): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.maintenance_state = state + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def getpeername(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn.getpeername() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def mark_for_reconnect(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.mark_for_reconnect() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def should_reconnect(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn.should_reconnect() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def reset_should_reconnect(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.reset_should_reconnect() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def get_resolved_ip(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn.get_resolved_ip() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.update_current_socket_timeout(relaxed_timeout) + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def set_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relaxed_timeout: Optional[float] = None, + ): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout) + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relaxed_timeout: bool = False, + ): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.reset_tmp_settings(reset_host_address, reset_relaxed_timeout) + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + def _connect(self): self._conn._connect() @@ -1351,15 +1612,6 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]] else: self._cache.delete_by_redis_keys(data[1]) - def get_protocol(self): - return self._conn.get_protocol() - - def set_re_auth_token(self, token: TokenInterface): - self._conn.set_re_auth_token(token) - - def re_auth(self): - self._conn.re_auth() - class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). @@ -1605,83 +1857,425 @@ def parse_ssl_verify_flags(value): return verify_flags -URL_QUERY_ARGUMENT_PARSERS = { - "db": int, - "socket_timeout": float, - "socket_connect_timeout": float, - "socket_keepalive": to_bool, - "retry_on_timeout": to_bool, - "retry_on_error": list, - "max_connections": int, - "health_check_interval": int, - "ssl_check_hostname": to_bool, - "ssl_include_verify_flags": parse_ssl_verify_flags, - "ssl_exclude_verify_flags": parse_ssl_verify_flags, - "timeout": float, -} +URL_QUERY_ARGUMENT_PARSERS = { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "retry_on_error": list, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, + "ssl_include_verify_flags": parse_ssl_verify_flags, + "ssl_exclude_verify_flags": parse_ssl_verify_flags, + "timeout": float, +} + + +def parse_url(url): + if not ( + url.startswith("redis://") + or url.startswith("rediss://") + or url.startswith("unix://") + ): + raise ValueError( + "Redis URL must specify one of the following " + "schemes (redis://, rediss://, unix://)" + ) + + url = urlparse(url) + kwargs = {} + + for name, value in parse_qs(url.query).items(): + if value and len(value) > 0: + value = unquote(value[0]) + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + kwargs[name] = parser(value) + except (TypeError, ValueError): + raise ValueError(f"Invalid value for '{name}' in connection URL.") + else: + kwargs[name] = value + + if url.username: + kwargs["username"] = unquote(url.username) + if url.password: + kwargs["password"] = unquote(url.password) + + # We only support redis://, rediss:// and unix:// schemes. + if url.scheme == "unix": + if url.path: + kwargs["path"] = unquote(url.path) + kwargs["connection_class"] = UnixDomainSocketConnection + + else: # implied: url.scheme in ("redis", "rediss"): + if url.hostname: + kwargs["host"] = unquote(url.hostname) + if url.port: + kwargs["port"] = int(url.port) + + # If there's a path argument, use it as the db argument if a + # querystring value wasn't specified + if url.path and "db" not in kwargs: + try: + kwargs["db"] = int(unquote(url.path).replace("/", "")) + except (AttributeError, ValueError): + pass + + if url.scheme == "rediss": + kwargs["connection_class"] = SSLConnection + + return kwargs + + +_CP = TypeVar("_CP", bound="ConnectionPool") + + +class ConnectionPoolInterface(ABC): + @abstractmethod + def get_protocol(self): + pass + + @abstractmethod + def reset(self): + pass + + @abstractmethod + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.3.0", + ) + def get_connection( + self, command_name: Optional[str], *keys, **options + ) -> ConnectionInterface: + pass + + @abstractmethod + def get_encoder(self): + pass + + @abstractmethod + def release(self, connection: ConnectionInterface): + pass + + @abstractmethod + def disconnect(self, inuse_connections: bool = True): + pass + + @abstractmethod + def close(self): + pass + + @abstractmethod + def set_retry(self, retry: Retry): + pass + + @abstractmethod + def re_auth_callback(self, token: TokenInterface): + pass + + +class MaintNotificationsConnectionPoolBase: + """ + Mixin class for handling maintenance notifications logic. + This class is mixed into the ConnectionPool classes. + + This class is not intended to be used directly! + + All logic related to maintenance notifications and + connection pool handling is encapsulated in this class. + """ + + def __init__(self, **kwargs): + # Initialize maintenance notifications if enabled + if kwargs.get("maint_notifications_pool_handler") or kwargs.get( + "maint_notifications_config" + ): + if kwargs.get("protocol") not in [3, "3"]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + + config = kwargs.get("maint_notifications_config", None) + handler = kwargs.get("maint_notifications_pool_handler", None) + + config = config or (handler.config if handler else None) + + if config and config.enabled: + self._update_connection_kwargs_for_maint_notifications() + + @property + @abstractmethod + def connection_kwargs(self) -> Dict[str, Any]: + pass + + @connection_kwargs.setter + @abstractmethod + def connection_kwargs(self, value: Dict[str, Any]): + pass + + @abstractmethod + def _get_pool_lock(self) -> threading.RLock: + pass + + @abstractmethod + def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]: + pass + + @abstractmethod + def _get_in_use_connections( + self, + ) -> Iterable["MaintNotificationsAbstractConnection"]: + pass + + def maint_notifications_pool_handler_enabled(self): + """ + Returns: + True if the maintenance notifications pool handler is enabled, False otherwise. + """ + maint_notifications_config = self.connection_kwargs.get( + "maint_notifications_config", None + ) + + return maint_notifications_config and maint_notifications_config.enabled + + def set_maint_notifications_pool_handler( + self, maint_notifications_pool_handler: MaintNotificationsPoolHandler + ): + self.connection_kwargs.update( + { + "maint_notifications_pool_handler": maint_notifications_pool_handler, + "maint_notifications_config": maint_notifications_pool_handler.config, + } + ) + self._update_connection_kwargs_for_maint_notifications() + + self._update_maint_notifications_configs_for_connections( + maint_notifications_pool_handler + ) + + def _update_maint_notifications_configs_for_connections( + self, maint_notifications_pool_handler + ): + """Update the maintenance notifications config for all connections in the pool.""" + with self._get_pool_lock(): + for conn in self._get_free_connections(): + conn.set_maint_notifications_pool_handler( + maint_notifications_pool_handler + ) + conn.maint_notifications_config = ( + maint_notifications_pool_handler.config + ) + for conn in self._get_in_use_connections(): + conn.set_maint_notifications_pool_handler( + maint_notifications_pool_handler + ) + conn.maint_notifications_config = ( + maint_notifications_pool_handler.config + ) + + def _update_connection_kwargs_for_maint_notifications(self): + """Store original connection parameters for maintenance notifications.""" + if self.connection_kwargs.get("orig_host_address", None) is None: + # If orig_host_address is None it means we haven't + # configured the original values yet + self.connection_kwargs.update( + { + "orig_host_address": self.connection_kwargs.get("host"), + "orig_socket_timeout": self.connection_kwargs.get( + "socket_timeout", None + ), + "orig_socket_connect_timeout": self.connection_kwargs.get( + "socket_connect_timeout", None + ), + } + ) + + def _should_update_connection( + self, + conn: "MaintNotificationsAbstractConnection", + matching_pattern: Literal[ + "connected_address", "configured_address", "notification_hash" + ] = "connected_address", + matching_address: Optional[str] = None, + matching_notification_hash: Optional[int] = None, + ) -> bool: + """ + Check if the connection should be updated based on the matching criteria. + """ + if matching_pattern == "connected_address": + if matching_address and conn.getpeername() != matching_address: + return False + elif matching_pattern == "configured_address": + if matching_address and conn.host != matching_address: + return False + elif matching_pattern == "notification_hash": + if ( + matching_notification_hash + and conn.maintenance_notification_hash != matching_notification_hash + ): + return False + return True + + def update_connection_settings( + self, + conn: "MaintNotificationsAbstractConnection", + state: Optional["MaintenanceState"] = None, + maintenance_notification_hash: Optional[int] = None, + host_address: Optional[str] = None, + relaxed_timeout: Optional[float] = None, + update_notification_hash: bool = False, + reset_host_address: bool = False, + reset_relaxed_timeout: bool = False, + ): + """ + Update the settings for a single connection. + """ + if state: + conn.maintenance_state = state + + if update_notification_hash: + # update the notification hash only if requested + conn.maintenance_notification_hash = maintenance_notification_hash + + if host_address is not None: + conn.set_tmp_settings(tmp_host_address=host_address) + + if relaxed_timeout is not None: + conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout) + if reset_relaxed_timeout or reset_host_address: + conn.reset_tmp_settings( + reset_host_address=reset_host_address, + reset_relaxed_timeout=reset_relaxed_timeout, + ) -def parse_url(url): - if not ( - url.startswith("redis://") - or url.startswith("rediss://") - or url.startswith("unix://") - ): - raise ValueError( - "Redis URL must specify one of the following " - "schemes (redis://, rediss://, unix://)" - ) + conn.update_current_socket_timeout(relaxed_timeout) - url = urlparse(url) - kwargs = {} + def update_connections_settings( + self, + state: Optional["MaintenanceState"] = None, + maintenance_notification_hash: Optional[int] = None, + host_address: Optional[str] = None, + relaxed_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + matching_notification_hash: Optional[int] = None, + matching_pattern: Literal[ + "connected_address", "configured_address", "notification_hash" + ] = "connected_address", + update_notification_hash: bool = False, + reset_host_address: bool = False, + reset_relaxed_timeout: bool = False, + include_free_connections: bool = True, + ): + """ + Update the settings for all matching connections in the pool. - for name, value in parse_qs(url.query).items(): - if value and len(value) > 0: - value = unquote(value[0]) - parser = URL_QUERY_ARGUMENT_PARSERS.get(name) - if parser: - try: - kwargs[name] = parser(value) - except (TypeError, ValueError): - raise ValueError(f"Invalid value for '{name}' in connection URL.") - else: - kwargs[name] = value + This method does not create new connections. + This method does not affect the connection kwargs. - if url.username: - kwargs["username"] = unquote(url.username) - if url.password: - kwargs["password"] = unquote(url.password) + :param state: The maintenance state to set for the connection. + :param maintenance_notification_hash: The hash of the maintenance notification + to set for the connection. + :param host_address: The host address to set for the connection. + :param relaxed_timeout: The relaxed timeout to set for the connection. + :param matching_address: The address to match for the connection. + :param matching_notification_hash: The notification hash to match for the connection. + :param matching_pattern: The pattern to match for the connection. + :param update_notification_hash: Whether to update the notification hash for the connection. + :param reset_host_address: Whether to reset the host address to the original address. + :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout. + :param include_free_connections: Whether to include free/available connections. + """ + with self._get_pool_lock(): + for conn in self._get_in_use_connections(): + if self._should_update_connection( + conn, + matching_pattern, + matching_address, + matching_notification_hash, + ): + self.update_connection_settings( + conn, + state=state, + maintenance_notification_hash=maintenance_notification_hash, + host_address=host_address, + relaxed_timeout=relaxed_timeout, + update_notification_hash=update_notification_hash, + reset_host_address=reset_host_address, + reset_relaxed_timeout=reset_relaxed_timeout, + ) - # We only support redis://, rediss:// and unix:// schemes. - if url.scheme == "unix": - if url.path: - kwargs["path"] = unquote(url.path) - kwargs["connection_class"] = UnixDomainSocketConnection + if include_free_connections: + for conn in self._get_free_connections(): + if self._should_update_connection( + conn, + matching_pattern, + matching_address, + matching_notification_hash, + ): + self.update_connection_settings( + conn, + state=state, + maintenance_notification_hash=maintenance_notification_hash, + host_address=host_address, + relaxed_timeout=relaxed_timeout, + update_notification_hash=update_notification_hash, + reset_host_address=reset_host_address, + reset_relaxed_timeout=reset_relaxed_timeout, + ) - else: # implied: url.scheme in ("redis", "rediss"): - if url.hostname: - kwargs["host"] = unquote(url.hostname) - if url.port: - kwargs["port"] = int(url.port) + def update_connection_kwargs( + self, + **kwargs, + ): + """ + Update the connection kwargs for all future connections. - # If there's a path argument, use it as the db argument if a - # querystring value wasn't specified - if url.path and "db" not in kwargs: - try: - kwargs["db"] = int(unquote(url.path).replace("/", "")) - except (AttributeError, ValueError): - pass + This method updates the connection kwargs for all future connections created by the pool. + Existing connections are not affected. + """ + self.connection_kwargs.update(kwargs) - if url.scheme == "rediss": - kwargs["connection_class"] = SSLConnection + def update_active_connections_for_reconnect( + self, + moving_address_src: Optional[str] = None, + ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. - return kwargs + :param moving_address_src: The address of the node that is being moved. + """ + with self._get_pool_lock(): + for conn in self._get_in_use_connections(): + if self._should_update_connection( + conn, "connected_address", moving_address_src + ): + conn.mark_for_reconnect() + def disconnect_free_connections( + self, + moving_address_src: Optional[str] = None, + ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. -_CP = TypeVar("_CP", bound="ConnectionPool") + :param moving_address_src: The address of the node that is being moved. + """ + with self._get_pool_lock(): + for conn in self._get_free_connections(): + if self._should_update_connection( + conn, "connected_address", moving_address_src + ): + conn.disconnect() -class ConnectionPool: +class ConnectionPool(MaintNotificationsConnectionPoolBase, ConnectionPoolInterface): """ Create a connection pool. ``If max_connections`` is set, then this object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's @@ -1757,16 +2351,16 @@ def __init__( raise ValueError('"max_connections" must be a positive integer') self.connection_class = connection_class - self.connection_kwargs = connection_kwargs + self._connection_kwargs = connection_kwargs self.max_connections = max_connections self.cache = None self._cache_factory = cache_factory if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): - if self.connection_kwargs.get("protocol") not in [3, "3"]: + if self._connection_kwargs.get("protocol") not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") - cache = self.connection_kwargs.get("cache") + cache = self._connection_kwargs.get("cache") if cache is not None: if not isinstance(cache, CacheInterface): @@ -1778,29 +2372,13 @@ def __init__( self.cache = self._cache_factory.get_cache() else: self.cache = CacheFactory( - self.connection_kwargs.get("cache_config") + self._connection_kwargs.get("cache_config") ).get_cache() connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) - if self.connection_kwargs.get( - "maint_notifications_pool_handler" - ) or self.connection_kwargs.get("maint_notifications_config"): - if self.connection_kwargs.get("protocol") not in [3, "3"]: - raise RedisError( - "Push handlers on connection are only supported with RESP version 3" - ) - config = self.connection_kwargs.get("maint_notifications_config", None) or ( - self.connection_kwargs.get("maint_notifications_pool_handler").config - if self.connection_kwargs.get("maint_notifications_pool_handler") - else None - ) - - if config and config.enabled: - self._update_connection_kwargs_for_maint_notifications() - - self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) + self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None) if self._event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -1816,6 +2394,8 @@ def __init__( self._fork_lock = threading.RLock() self._lock = threading.RLock() + MaintNotificationsConnectionPoolBase.__init__(self, **connection_kwargs) + self.reset() def __repr__(self) -> str: @@ -1826,76 +2406,21 @@ def __repr__(self) -> str: f"({conn_kwargs})>)>" ) - def get_protocol(self): - """ - Returns: - The RESP protocol version, or ``None`` if the protocol is not specified, - in which case the server default will be used. - """ - return self.connection_kwargs.get("protocol", None) + @property + def connection_kwargs(self) -> Dict[str, Any]: + return self._connection_kwargs - def maint_notifications_pool_handler_enabled(self): + @connection_kwargs.setter + def connection_kwargs(self, value: Dict[str, Any]): + self._connection_kwargs = value + + def get_protocol(self): """ Returns: - True if the maintenance notifications pool handler is enabled, False otherwise. - """ - maint_notifications_config = self.connection_kwargs.get( - "maint_notifications_config", None - ) - - return maint_notifications_config and maint_notifications_config.enabled - - def set_maint_notifications_pool_handler( - self, maint_notifications_pool_handler: MaintNotificationsPoolHandler - ): - self.connection_kwargs.update( - { - "maint_notifications_pool_handler": maint_notifications_pool_handler, - "maint_notifications_config": maint_notifications_pool_handler.config, - } - ) - self._update_connection_kwargs_for_maint_notifications() - - self._update_maint_notifications_configs_for_connections( - maint_notifications_pool_handler - ) - - def _update_maint_notifications_configs_for_connections( - self, maint_notifications_pool_handler - ): - """Update the maintenance notifications config for all connections in the pool.""" - with self._lock: - for conn in self._available_connections: - conn.set_maint_notifications_pool_handler( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) - for conn in self._in_use_connections: - conn.set_maint_notifications_pool_handler( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) - - def _update_connection_kwargs_for_maint_notifications(self): - """Store original connection parameters for maintenance notifications.""" - if self.connection_kwargs.get("orig_host_address", None) is None: - # If orig_host_address is None it means we haven't - # configured the original values yet - self.connection_kwargs.update( - { - "orig_host_address": self.connection_kwargs.get("host"), - "orig_socket_timeout": self.connection_kwargs.get( - "socket_timeout", None - ), - "orig_socket_connect_timeout": self.connection_kwargs.get( - "socket_connect_timeout", None - ), - } - ) + The RESP protocol version, or ``None`` if the protocol is not specified, + in which case the server default will be used. + """ + return self.connection_kwargs.get("protocol", None) def reset(self) -> None: self._created_connections = 0 @@ -2059,7 +2584,7 @@ def disconnect(self, inuse_connections: bool = True) -> None: Disconnects connections in the pool If ``inuse_connections`` is True, disconnect connections that are - current in use, potentially by other threads. Otherwise only disconnect + currently in use, potentially by other threads. Otherwise only disconnect connections that are idle in the pool. """ self._checkpid() @@ -2100,185 +2625,16 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) - def _should_update_connection( - self, - conn: "Connection", - matching_pattern: Literal[ - "connected_address", "configured_address", "notification_hash" - ] = "connected_address", - matching_address: Optional[str] = None, - matching_notification_hash: Optional[int] = None, - ) -> bool: - """ - Check if the connection should be updated based on the matching criteria. - """ - if matching_pattern == "connected_address": - if matching_address and conn.getpeername() != matching_address: - return False - elif matching_pattern == "configured_address": - if matching_address and conn.host != matching_address: - return False - elif matching_pattern == "notification_hash": - if ( - matching_notification_hash - and conn.maintenance_notification_hash != matching_notification_hash - ): - return False - return True - - def update_connection_settings( - self, - conn: "Connection", - state: Optional["MaintenanceState"] = None, - maintenance_notification_hash: Optional[int] = None, - host_address: Optional[str] = None, - relaxed_timeout: Optional[float] = None, - update_notification_hash: bool = False, - reset_host_address: bool = False, - reset_relaxed_timeout: bool = False, - ): - """ - Update the settings for a single connection. - """ - if state: - conn.maintenance_state = state - - if update_notification_hash: - # update the notification hash only if requested - conn.maintenance_notification_hash = maintenance_notification_hash - - if host_address is not None: - conn.set_tmp_settings(tmp_host_address=host_address) - - if relaxed_timeout is not None: - conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout) - - if reset_relaxed_timeout or reset_host_address: - conn.reset_tmp_settings( - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - - conn.update_current_socket_timeout(relaxed_timeout) - - def update_connections_settings( - self, - state: Optional["MaintenanceState"] = None, - maintenance_notification_hash: Optional[int] = None, - host_address: Optional[str] = None, - relaxed_timeout: Optional[float] = None, - matching_address: Optional[str] = None, - matching_notification_hash: Optional[int] = None, - matching_pattern: Literal[ - "connected_address", "configured_address", "notification_hash" - ] = "connected_address", - update_notification_hash: bool = False, - reset_host_address: bool = False, - reset_relaxed_timeout: bool = False, - include_free_connections: bool = True, - ): - """ - Update the settings for all matching connections in the pool. - - This method does not create new connections. - This method does not affect the connection kwargs. - - :param state: The maintenance state to set for the connection. - :param maintenance_notification_hash: The hash of the maintenance notification - to set for the connection. - :param host_address: The host address to set for the connection. - :param relaxed_timeout: The relaxed timeout to set for the connection. - :param matching_address: The address to match for the connection. - :param matching_notification_hash: The notification hash to match for the connection. - :param matching_pattern: The pattern to match for the connection. - :param update_notification_hash: Whether to update the notification hash for the connection. - :param reset_host_address: Whether to reset the host address to the original address. - :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout. - :param include_free_connections: Whether to include free/available connections. - """ - with self._lock: - for conn in self._in_use_connections: - if self._should_update_connection( - conn, - matching_pattern, - matching_address, - matching_notification_hash, - ): - self.update_connection_settings( - conn, - state=state, - maintenance_notification_hash=maintenance_notification_hash, - host_address=host_address, - relaxed_timeout=relaxed_timeout, - update_notification_hash=update_notification_hash, - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - - if include_free_connections: - for conn in self._available_connections: - if self._should_update_connection( - conn, - matching_pattern, - matching_address, - matching_notification_hash, - ): - self.update_connection_settings( - conn, - state=state, - maintenance_notification_hash=maintenance_notification_hash, - host_address=host_address, - relaxed_timeout=relaxed_timeout, - update_notification_hash=update_notification_hash, - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - - def update_connection_kwargs( - self, - **kwargs, - ): - """ - Update the connection kwargs for all future connections. - - This method updates the connection kwargs for all future connections created by the pool. - Existing connections are not affected. - """ - self.connection_kwargs.update(kwargs) - - def update_active_connections_for_reconnect( - self, - moving_address_src: Optional[str] = None, - ): - """ - Mark all active connections for reconnect. - This is used when a cluster node is migrated to a different address. + def _get_pool_lock(self): + return self._lock - :param moving_address_src: The address of the node that is being moved. - """ + def _get_free_connections(self): with self._lock: - for conn in self._in_use_connections: - if self._should_update_connection( - conn, "connected_address", moving_address_src - ): - conn.mark_for_reconnect() - - def disconnect_free_connections( - self, - moving_address_src: Optional[str] = None, - ): - """ - Disconnect all free/available connections. - This is used when a cluster node is migrated to a different address. + return self._available_connections - :param moving_address_src: The address of the node that is being moved. - """ + def _get_in_use_connections(self): with self._lock: - for conn in self._available_connections: - if self._should_update_connection( - conn, "connected_address", moving_address_src - ): - conn.disconnect() + return self._in_use_connections async def _mock(self, error: RedisError): """ @@ -2391,7 +2747,7 @@ def make_connection(self): ) else: connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) + self._connections.append(connection) return connection finally: if self._locked: @@ -2520,124 +2876,19 @@ def disconnect(self): pass self._locked = False - def update_connections_settings( - self, - state: Optional["MaintenanceState"] = None, - maintenance_notification_hash: Optional[int] = None, - relaxed_timeout: Optional[float] = None, - host_address: Optional[str] = None, - matching_address: Optional[str] = None, - matching_notification_hash: Optional[int] = None, - matching_pattern: Literal[ - "connected_address", "configured_address", "notification_hash" - ] = "connected_address", - update_notification_hash: bool = False, - reset_host_address: bool = False, - reset_relaxed_timeout: bool = False, - include_free_connections: bool = True, - ): - """ - Override base class method to work with BlockingConnectionPool's structure. - """ + def _get_free_connections(self): with self._lock: - if include_free_connections: - for conn in tuple(self._connections): - if self._should_update_connection( - conn, - matching_pattern, - matching_address, - matching_notification_hash, - ): - self.update_connection_settings( - conn, - state=state, - maintenance_notification_hash=maintenance_notification_hash, - host_address=host_address, - relaxed_timeout=relaxed_timeout, - update_notification_hash=update_notification_hash, - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - else: - connections_in_queue = {conn for conn in self.pool.queue if conn} - for conn in self._connections: - if conn not in connections_in_queue: - if self._should_update_connection( - conn, - matching_pattern, - matching_address, - matching_notification_hash, - ): - self.update_connection_settings( - conn, - state=state, - maintenance_notification_hash=maintenance_notification_hash, - host_address=host_address, - relaxed_timeout=relaxed_timeout, - update_notification_hash=update_notification_hash, - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - - def update_active_connections_for_reconnect( - self, - moving_address_src: Optional[str] = None, - ): - """ - Mark all active connections for reconnect. - This is used when a cluster node is migrated to a different address. + return {conn for conn in self.pool.queue if conn} - :param moving_address_src: The address of the node that is being moved. - """ + def _get_in_use_connections(self): with self._lock: + # free connections connections_in_queue = {conn for conn in self.pool.queue if conn} - for conn in self._connections: - if conn not in connections_in_queue: - if self._should_update_connection( - conn, - matching_pattern="connected_address", - matching_address=moving_address_src, - ): - conn.mark_for_reconnect() - - def disconnect_free_connections( - self, - moving_address_src: Optional[str] = None, - ): - """ - Disconnect all free/available connections. - This is used when a cluster node is migrated to a different address. - - :param moving_address_src: The address of the node that is being moved. - """ - with self._lock: - existing_connections = self.pool.queue - - for conn in existing_connections: - if conn: - if self._should_update_connection( - conn, "connected_address", moving_address_src - ): - conn.disconnect() - - def _update_maint_notifications_config_for_connections( - self, maint_notifications_config - ): - for conn in tuple(self._connections): - conn.maint_notifications_config = maint_notifications_config - - def _update_maint_notifications_configs_for_connections( - self, maint_notifications_pool_handler - ): - """Update the maintenance notifications config for all connections in the pool.""" - with self._lock: - for conn in tuple(self._connections): - conn.set_maint_notifications_pool_handler( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) + # in self._connections we keep all created connections + # so the ones that are not in the queue are the in use ones + return { + conn for conn in self._connections if conn not in connections_in_queue + } def set_in_maintenance(self, in_maintenance: bool): """ diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index 37e4f93a3f..0188775935 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -33,8 +33,8 @@ def __str__(self): if TYPE_CHECKING: from redis.connection import ( BlockingConnectionPool, - ConnectionInterface, ConnectionPool, + MaintNotificationsAbstractConnection, ) @@ -501,7 +501,7 @@ def is_relaxed_timeouts_enabled(self) -> bool: return self.relaxed_timeout != -1 def get_endpoint_type( - self, host: str, connection: "ConnectionInterface" + self, host: str, connection: "MaintNotificationsAbstractConnection" ) -> EndpointType: """ Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command. @@ -567,7 +567,7 @@ def __init__( self._lock = threading.RLock() self.connection = None - def set_connection(self, connection: "ConnectionInterface"): + def set_connection(self, connection: "MaintNotificationsAbstractConnection"): self.connection = connection def remove_expired_notifications(self): @@ -751,7 +751,9 @@ class MaintNotificationsConnectionHandler: } def __init__( - self, connection: "ConnectionInterface", config: MaintNotificationsConfig + self, + connection: "MaintNotificationsAbstractConnection", + config: MaintNotificationsConfig, ) -> None: self.connection = connection self.config = config diff --git a/tests/test_maint_notifications.py b/tests/test_maint_notifications.py index 08ac15368f..424d99c854 100644 --- a/tests/test_maint_notifications.py +++ b/tests/test_maint_notifications.py @@ -1,8 +1,9 @@ +from socket import socket import threading from unittest.mock import Mock, call, patch, MagicMock import pytest -from redis.connection import ConnectionInterface +from redis.connection import ConnectionInterface, MaintNotificationsAbstractConnection from redis.maint_notifications import ( MaintenanceNotification, @@ -758,13 +759,16 @@ def __init__(self, resolved_ip): def getpeername(self): return (self.resolved_ip, 6379) - class MockConnection(ConnectionInterface): + class MockConnection(MaintNotificationsAbstractConnection, ConnectionInterface): def __init__(self, host, resolved_ip=None, is_ssl=False): self.host = host self.port = 6379 self._sock = MockSocket(resolved_ip) if resolved_ip else None self.__class__.__name__ = "SSLConnection" if is_ssl else "Connection" + def _get_socket(self): + return self._sock + def get_resolved_ip(self): # Call the actual method from AbstractConnection from redis.connection import AbstractConnection diff --git a/tests/test_maint_notifications_handling.py b/tests/test_maint_notifications_handling.py index 54b6e2dff7..964646538b 100644 --- a/tests/test_maint_notifications_handling.py +++ b/tests/test_maint_notifications_handling.py @@ -7,8 +7,11 @@ from time import sleep from redis import Redis +from redis.cache import CacheConfig from redis.connection import ( AbstractConnection, + CacheProxyConnection, + Connection, ConnectionPool, BlockingConnectionPool, MaintenanceState, @@ -68,7 +71,7 @@ def validate_in_use_connections_state( # and timeout is updated for connection in in_use_connections: if expected_should_reconnect != "any": - assert connection._should_reconnect == expected_should_reconnect + assert connection.should_reconnect() == expected_should_reconnect assert connection.host == expected_host_address assert connection.socket_timeout == expected_socket_timeout assert connection.socket_connect_timeout == expected_socket_connect_timeout @@ -78,13 +81,12 @@ def validate_in_use_connections_state( connection.orig_socket_connect_timeout == expected_orig_socket_connect_timeout ) - if connection._sock is not None: - assert connection._sock.gettimeout() == expected_current_socket_timeout - assert connection._sock.connected is True + conn_socket = connection._get_socket() + if conn_socket is not None: + assert conn_socket.gettimeout() == expected_current_socket_timeout + assert conn_socket.connected is True if expected_current_peername != "any": - assert ( - connection._sock.getpeername()[0] == expected_current_peername - ) + assert conn_socket.getpeername()[0] == expected_current_peername assert connection.maintenance_state == expected_state @staticmethod @@ -112,7 +114,7 @@ def validate_free_connections_state( connected_count = 0 for connection in free_connections: - assert connection._should_reconnect is False + assert connection.should_reconnect() is False assert connection.host == expected_host_address assert connection.socket_timeout == expected_socket_timeout assert connection.socket_connect_timeout == expected_socket_connect_timeout @@ -126,10 +128,11 @@ def validate_free_connections_state( if expected_state == MaintenanceState.NONE: assert connection.maintenance_notification_hash is None - if connection._sock is not None: - assert connection._sock.connected is True + conn_socket = connection._get_socket() + if conn_socket is not None: + assert conn_socket.connected is True if connected_to_tmp_address and tmp_address != "any": - assert connection._sock.getpeername()[0] == tmp_address + assert conn_socket.getpeername()[0] == tmp_address connected_count += 1 assert connected_count == should_be_connected_count @@ -201,7 +204,7 @@ def send(self, data): # Analyze the command and prepare appropriate response if b"HELLO" in data: - response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" + response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.4.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" self.pending_responses.append(response) elif b"MAINT_NOTIFICATIONS" in data and b"internal-ip" in data: # Simulate error response - activate it only for internal-ip tests @@ -392,6 +395,8 @@ def teardown_method(self): def _get_client( self, pool_class, + connection_class=Connection, + enable_cache=False, max_connections=10, maint_notifications_config=None, setup_pool_handler=False, @@ -413,13 +418,18 @@ def _get_client( if maint_notifications_config is not None else self.config ) + pool_kwargs = {} + if enable_cache: + pool_kwargs = {"cache_config": CacheConfig()} test_pool = pool_class( + connection_class=connection_class, host=DEFAULT_ADDRESS.split(":")[0], port=int(DEFAULT_ADDRESS.split(":")[1]), max_connections=max_connections, protocol=3, # Required for maintenance notifications maint_notifications_config=config, + **pool_kwargs, ) test_redis_client = Redis(connection_pool=test_pool) @@ -576,7 +586,7 @@ def test_client_initialization(self): assert pool_handler.config == self.config conn = test_redis_client.connection_pool.get_connection() - assert conn._should_reconnect is False + assert conn.should_reconnect() is False assert conn.orig_host_address == "localhost" assert conn.orig_socket_timeout is None @@ -1268,6 +1278,7 @@ def test_moving_none_notifications_handling_integration(self, pool_class): ) # Wait for half of MOVING timeout to expire and the proactive reconnect to run sleep(MOVING_TIMEOUT / 2 + 0.2) + Helpers.validate_in_use_connections_state( in_use_connections, expected_should_reconnect=True, @@ -1399,13 +1410,15 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): assert new_connection.host == AFTER_MOVING_ADDRESS.split(":")[0] assert new_connection.socket_timeout is self.config.relaxed_timeout # New connections should be connected to the temporary address - assert new_connection._sock is not None - assert new_connection._sock.connected is True + assert new_connection._get_socket() is not None + assert new_connection._get_socket().connected is True assert ( - new_connection._sock.getpeername()[0] + new_connection._get_socket().getpeername()[0] == AFTER_MOVING_ADDRESS.split(":")[0] ) - assert new_connection._sock.gettimeout() == self.config.relaxed_timeout + assert ( + new_connection._get_socket().gettimeout() == self.config.relaxed_timeout + ) finally: if hasattr(test_redis_client.connection_pool, "disconnect"): @@ -1465,10 +1478,10 @@ def test_create_new_conn_after_moving_expires(self, pool_class): assert new_connection.orig_host_address == DEFAULT_ADDRESS.split(":")[0] assert new_connection.orig_socket_timeout is None # New connections should be connected to the original address - assert new_connection._sock is not None - assert new_connection._sock.connected is True + assert new_connection._get_socket() is not None + assert new_connection._get_socket().connected is True # Socket timeout should be None (original timeout) - assert new_connection._sock.gettimeout() is None + assert new_connection._get_socket().gettimeout() is None finally: if hasattr(test_redis_client.connection_pool, "disconnect"): @@ -1575,8 +1588,8 @@ def test_receive_migrated_after_moving(self, pool_class): # Note: New connections may not inherit the exact relaxed timeout value # but they should have the temporary host address # New connections should be connected - if connection._sock is not None: - assert connection._sock.connected is True + if connection._get_socket() is not None: + assert connection._get_socket().connected is True # Release the new connections for connection in new_connections: @@ -1708,7 +1721,6 @@ def test_overlapping_moving_notifications(self, pool_class): expected_current_socket_timeout=self.config.relaxed_timeout, expected_current_peername=orig_after_moving.split(":")[0], ) - # print(test_redis_client.connection_pool._available_connections) Helpers.validate_free_connections_state( test_redis_client.connection_pool, should_be_connected_count=1, @@ -1790,8 +1802,18 @@ def worker(idx): if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() - @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) - def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): + @pytest.mark.parametrize( + "pool_class,enable_cache", + [ + (ConnectionPool, False), + (ConnectionPool, True), + (BlockingConnectionPool, False), + (BlockingConnectionPool, True), + ], + ) + def test_moving_migrating_migrated_moved_state_transitions( + self, pool_class, enable_cache + ): """ Test moving configs are not lost if the per connection notifications get picked up after moving is handled. Sequence of notifications: MOVING, MIGRATING, MIGRATED, FAILING_OVER, FAILED_OVER, MOVED. @@ -1800,7 +1822,10 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): """ # Setup test_redis_client = self._get_client( - pool_class, max_connections=5, setup_pool_handler=True + pool_class, + max_connections=5, + setup_pool_handler=True, + enable_cache=enable_cache, ) pool = test_redis_client.connection_pool pool_handler = pool.connection_kwargs["maint_notifications_pool_handler"] From 3e79f84fa971a2d5226ef4764740fb32d1575623 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 3 Oct 2025 16:20:23 +0300 Subject: [PATCH 2/8] Fixing credential tests and linters errors --- tests/test_credentials.py | 8 +++++++- tests/test_maint_notifications.py | 1 - tests/test_maint_notifications_handling.py | 1 - 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 8b8e0cfc2c..c5892ca984 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -11,7 +11,10 @@ from redis import AuthenticationError, DataError, Redis, ResponseError from redis.auth.err import RequestTokenErr from redis.backoff import NoBackoff -from redis.connection import ConnectionInterface, ConnectionPool +from redis.connection import ( + ConnectionInterface, + ConnectionPool, +) from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ConnectionError, RedisError from redis.retry import Retry @@ -428,6 +431,8 @@ def re_auth_callback(token): def test_re_auth_pub_sub_in_resp3(self, credential_provider): mock_pubsub_connection = Mock(spec=ConnectionInterface) mock_pubsub_connection.get_protocol.return_value = 3 + + mock_pubsub_connection.should_reconnect = Mock(return_value=False) mock_pubsub_connection.credential_provider = credential_provider mock_pubsub_connection.retry = Retry(NoBackoff(), 3) mock_another_connection = Mock(spec=ConnectionInterface) @@ -488,6 +493,7 @@ def re_auth_callback(token): def test_do_not_re_auth_pub_sub_in_resp2(self, credential_provider): mock_pubsub_connection = Mock(spec=ConnectionInterface) mock_pubsub_connection.get_protocol.return_value = 2 + mock_pubsub_connection.should_reconnect = Mock(return_value=False) mock_pubsub_connection.credential_provider = credential_provider mock_pubsub_connection.retry = Retry(NoBackoff(), 3) mock_another_connection = Mock(spec=ConnectionInterface) diff --git a/tests/test_maint_notifications.py b/tests/test_maint_notifications.py index 424d99c854..85aa671390 100644 --- a/tests/test_maint_notifications.py +++ b/tests/test_maint_notifications.py @@ -1,4 +1,3 @@ -from socket import socket import threading from unittest.mock import Mock, call, patch, MagicMock import pytest diff --git a/tests/test_maint_notifications_handling.py b/tests/test_maint_notifications_handling.py index 964646538b..2935743d7c 100644 --- a/tests/test_maint_notifications_handling.py +++ b/tests/test_maint_notifications_handling.py @@ -10,7 +10,6 @@ from redis.cache import CacheConfig from redis.connection import ( AbstractConnection, - CacheProxyConnection, Connection, ConnectionPool, BlockingConnectionPool, From e348b4a9fd2f2c84d07a8685f0b7568ecebbdf87 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 6 Oct 2025 15:14:08 +0300 Subject: [PATCH 3/8] Adding better handling of configs and fixing tests --- redis/client.py | 50 ++--- redis/cluster.py | 6 + redis/connection.py | 210 +++++++++++++++------ redis/maint_notifications.py | 15 +- tests/test_credentials.py | 1 - tests/test_maint_notifications_handling.py | 105 +++-------- 6 files changed, 224 insertions(+), 163 deletions(-) diff --git a/redis/client.py b/redis/client.py index b8d2e8af5d..5f42a74c73 100755 --- a/redis/client.py +++ b/redis/client.py @@ -278,6 +278,17 @@ def __init__( single_connection_client: if `True`, connection pool is not used. In that case `Redis` instance use is not thread safe. + decode_responses: + if `True`, the response will be decoded to utf-8. + Argument is ignored when connection_pool is provided. + maint_notifications_config: + configuration the pool to support maintenance notifications - see + `redis.maint_notifications.MaintNotificationsConfig` for details. + Only supported with RESP3 + If not provided and protocol is RESP3, the maintenance notifications + will be enabled by default (logic is included in the connection pool + initialization). + Argument is ignored when connection_pool is provided. """ if event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -354,6 +365,22 @@ def __init__( "cache_config": cache_config, } ) + maint_notifications_enabled = ( + maint_notifications_config and maint_notifications_config.enabled + ) + if maint_notifications_enabled and protocol not in [ + 3, + "3", + ]: + raise RedisError( + "Maintenance notifications handlers on connection are only supported with RESP version 3" + ) + if maint_notifications_config: + kwargs.update( + { + "maint_notifications_config": maint_notifications_config, + } + ) connection_pool = ConnectionPool(**kwargs) self._event_dispatcher.dispatch( AfterPooledConnectionsInstantiationEvent( @@ -377,23 +404,6 @@ def __init__( ]: raise RedisError("Client caching is only supported with RESP version 3") - if maint_notifications_config and self.connection_pool.get_protocol() not in [ - 3, - "3", - ]: - raise RedisError( - "Push handlers on connection are only supported with RESP version 3" - ) - if maint_notifications_config and maint_notifications_config.enabled: - self.maint_notifications_pool_handler = MaintNotificationsPoolHandler( - self.connection_pool, maint_notifications_config - ) - self.connection_pool.set_maint_notifications_pool_handler( - self.maint_notifications_pool_handler - ) - else: - self.maint_notifications_pool_handler = None - self.single_connection_lock = threading.RLock() self.connection = None self._single_connection_client = single_connection_client @@ -591,15 +601,9 @@ def monitor(self): return Monitor(self.connection_pool) def client(self): - maint_notifications_config = ( - None - if self.maint_notifications_pool_handler is None - else self.maint_notifications_pool_handler.config - ) return self.__class__( connection_pool=self.connection_pool, single_connection_client=True, - maint_notifications_config=maint_notifications_config, ) def __enter__(self): diff --git a/redis/cluster.py b/redis/cluster.py index 1d4a3e0d0c..c238c171be 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -50,6 +50,7 @@ WatchError, ) from redis.lock import Lock +from redis.maint_notifications import MaintNotificationsConfig from redis.retry import Retry from redis.utils import ( deprecated_args, @@ -1663,6 +1664,11 @@ def create_redis_node(self, host, port, **kwargs): backoff=NoBackoff(), retries=0, supported_errors=(ConnectionError,) ) + protocol = kwargs.get("protocol", None) + if protocol in [3, "3"]: + kwargs.update( + {"maint_notifications_config": MaintNotificationsConfig(enabled=False)} + ) if self.from_url: # Create a redis node with a costumed connection pool kwargs.update({"host": host}) diff --git a/redis/connection.py b/redis/connection.py index 26b696116e..db313c58e4 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -281,11 +281,10 @@ def __init__( This is useful when the parser is created after this object. """ self.maint_notifications_config = maint_notifications_config - self.maint_notifications_pool_handler = maint_notifications_pool_handler self.maintenance_state = maintenance_state self.maintenance_notification_hash = maintenance_notification_hash self._configure_maintenance_notifications( - self.maint_notifications_pool_handler, + maint_notifications_pool_handler, orig_host_address, orig_socket_timeout, orig_socket_connect_timeout, @@ -360,7 +359,9 @@ def disconnect(self, *args): def _configure_maintenance_notifications( self, - maint_notifications_pool_handler=None, + maint_notifications_pool_handler: Optional[ + MaintNotificationsPoolHandler + ] = None, orig_host_address=None, orig_socket_timeout=None, orig_socket_connect_timeout=None, @@ -370,30 +371,45 @@ def _configure_maintenance_notifications( Enable maintenance notifications by setting up handlers and storing original connection parameters. - Should be used ONLY parsers that support push notifications. + Should be used ONLY with parsers that support push notifications. """ if ( not self.maint_notifications_config or not self.maint_notifications_config.enabled ): + self._maint_notifications_pool_handler = None self._maint_notifications_connection_handler = None return if not parser: raise RedisError( - "To configure maintenance notifications, a parser must be provided." + "To configure maintenance notifications, a parser must be provided!" ) - # Set up pool handler if available if maint_notifications_pool_handler: - parser.set_node_moving_push_handler( - maint_notifications_pool_handler.handle_notification + # Extract a reference to a new pool handler that copies all properties + # of the original one and has a different connection reference + # This is needed because when we attach the handler to the parser + # we need to make sure that the handler has a reference to the + # connection that the parser is attached to. + self._maint_notifications_pool_handler = ( + maint_notifications_pool_handler.get_handler_for_connection() ) + self._maint_notifications_pool_handler.set_connection(self) + else: + self._maint_notifications_pool_handler = None - # Set up connection handler self._maint_notifications_connection_handler = ( MaintNotificationsConnectionHandler(self, self.maint_notifications_config) ) + + # Set up pool handler if available + if self._maint_notifications_pool_handler: + parser.set_node_moving_push_handler( + self._maint_notifications_pool_handler.handle_notification + ) + + # Set up connection handler parser.set_maintenance_push_handler( self._maint_notifications_connection_handler.handle_notification ) @@ -409,14 +425,24 @@ def _configure_maintenance_notifications( else self.socket_connect_timeout ) - def set_maint_notifications_pool_handler( + def set_maint_notifications_pool_handler_for_connection( self, maint_notifications_pool_handler: MaintNotificationsPoolHandler ): - maint_notifications_pool_handler.set_connection(self) + # Deep copy the pool handler to avoid sharing the same pool handler + # between multiple connections, because otherwise each connection will override + # the connection reference and the pool handler will only hold a reference + # to the last connection that was set. + maint_notifications_pool_handler_copy = ( + maint_notifications_pool_handler.get_handler_for_connection() + ) + + maint_notifications_pool_handler_copy.set_connection(self) self._get_parser().set_node_moving_push_handler( - maint_notifications_pool_handler.handle_notification + maint_notifications_pool_handler_copy.handle_notification ) + self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy + # Update maintenance notification connection handler if it doesn't exist if not self._maint_notifications_connection_handler: self._maint_notifications_connection_handler = ( @@ -1286,7 +1312,7 @@ def __init__( MaintNotificationsAbstractConnection.__init__( self, self._conn.maint_notifications_config, - self._conn.maint_notifications_pool_handler, + self._conn._maint_notifications_pool_handler, self._conn.maintenance_state, self._conn.maintenance_notification_hash, self._conn.host, @@ -1307,9 +1333,11 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): self._conn.set_parser(parser_class) - def set_maint_notifications_pool_handler(self, maint_notifications_pool_handler): + def set_maint_notifications_pool_handler_for_connection( + self, maint_notifications_pool_handler + ): if isinstance(self._conn, MaintNotificationsAbstractConnection): - self._conn.set_maint_notifications_pool_handler( + self._conn.set_maint_notifications_pool_handler_for_connection( maint_notifications_pool_handler ) @@ -1978,9 +2006,9 @@ def re_auth_callback(self, token: TokenInterface): pass -class MaintNotificationsConnectionPoolBase: +class MaintNotificationsAbstractConnectionPool: """ - Mixin class for handling maintenance notifications logic. + Abstract class for handling maintenance notifications logic. This class is mixed into the ConnectionPool classes. This class is not intended to be used directly! @@ -1989,23 +2017,31 @@ class MaintNotificationsConnectionPoolBase: connection pool handling is encapsulated in this class. """ - def __init__(self, **kwargs): - # Initialize maintenance notifications if enabled - if kwargs.get("maint_notifications_pool_handler") or kwargs.get( - "maint_notifications_config" - ): - if kwargs.get("protocol") not in [3, "3"]: + def __init__( + self, + maint_notifications_config: Optional[MaintNotificationsConfig] = None, + **kwargs, + ): + # Initialize maintenance notifications + is_protocol_supported = kwargs.get("protocol") in [3, "3"] + if maint_notifications_config is None and is_protocol_supported: + maint_notifications_config = MaintNotificationsConfig() + + if maint_notifications_config and maint_notifications_config.enabled: + if not is_protocol_supported: raise RedisError( - "Push handlers on connection are only supported with RESP version 3" + "Maintenance notifications handlers on connection are only supported with RESP version 3" ) - config = kwargs.get("maint_notifications_config", None) - handler = kwargs.get("maint_notifications_pool_handler", None) - - config = config or (handler.config if handler else None) + self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( + self, maint_notifications_config + ) - if config and config.enabled: - self._update_connection_kwargs_for_maint_notifications() + self._update_connection_kwargs_for_maint_notifications( + self._maint_notifications_pool_handler + ) + else: + self._maint_notifications_pool_handler = None @property @abstractmethod @@ -2031,68 +2067,107 @@ def _get_in_use_connections( ) -> Iterable["MaintNotificationsAbstractConnection"]: pass - def maint_notifications_pool_handler_enabled(self): + def maint_notifications_enabled(self): """ Returns: - True if the maintenance notifications pool handler is enabled, False otherwise. + True if the maintenance notifications are enabled, False otherwise. + The maintenance notifications config is stored in the pool handler. + If the pool handler is not set, the maintenance notifications are not enabled. """ - maint_notifications_config = self.connection_kwargs.get( - "maint_notifications_config", None + maint_notifications_config = ( + self._maint_notifications_pool_handler.config + if self._maint_notifications_pool_handler + else None ) return maint_notifications_config and maint_notifications_config.enabled - def set_maint_notifications_pool_handler( + def update_maint_notifications_config( + self, maint_notifications_config: MaintNotificationsConfig + ): + """ + Updates the maintenance notifications configuration. + This method should be called only if the pool was created + without enabling the maintenance notifications and + in a later point in time maintenance notifications + are requested to be enabled. + """ + if ( + self.maint_notifications_enabled() + and not maint_notifications_config.enabled + ): + raise ValueError( + "Cannot disable maintenance notifications after enabling them" + ) + # first update pool settings + if not self._maint_notifications_pool_handler: + self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( + self, maint_notifications_config + ) + else: + self._maint_notifications_pool_handler.config = maint_notifications_config + + # then update connection kwargs and existing connections + self._update_connection_kwargs_for_maint_notifications( + self._maint_notifications_pool_handler + ) + self._update_maint_notifications_configs_for_connections( + self._maint_notifications_pool_handler + ) + + def _update_connection_kwargs_for_maint_notifications( self, maint_notifications_pool_handler: MaintNotificationsPoolHandler ): + """ + Update the connection kwargs for all future connections. + """ + if not self.maint_notifications_enabled(): + return + self.connection_kwargs.update( { "maint_notifications_pool_handler": maint_notifications_pool_handler, "maint_notifications_config": maint_notifications_pool_handler.config, } ) - self._update_connection_kwargs_for_maint_notifications() - self._update_maint_notifications_configs_for_connections( - maint_notifications_pool_handler - ) + # Store original connection parameters for maintenance notifications. + if self.connection_kwargs.get("orig_host_address", None) is None: + # If orig_host_address is None it means we haven't + # configured the original values yet + self.connection_kwargs.update( + { + "orig_host_address": self.connection_kwargs.get("host"), + "orig_socket_timeout": self.connection_kwargs.get( + "socket_timeout", None + ), + "orig_socket_connect_timeout": self.connection_kwargs.get( + "socket_connect_timeout", None + ), + } + ) def _update_maint_notifications_configs_for_connections( - self, maint_notifications_pool_handler + self, maint_notifications_pool_handler: MaintNotificationsPoolHandler ): """Update the maintenance notifications config for all connections in the pool.""" with self._get_pool_lock(): for conn in self._get_free_connections(): - conn.set_maint_notifications_pool_handler( + conn.set_maint_notifications_pool_handler_for_connection( maint_notifications_pool_handler ) conn.maint_notifications_config = ( maint_notifications_pool_handler.config ) + conn.disconnect() for conn in self._get_in_use_connections(): - conn.set_maint_notifications_pool_handler( + conn.set_maint_notifications_pool_handler_for_connection( maint_notifications_pool_handler ) conn.maint_notifications_config = ( maint_notifications_pool_handler.config ) - - def _update_connection_kwargs_for_maint_notifications(self): - """Store original connection parameters for maintenance notifications.""" - if self.connection_kwargs.get("orig_host_address", None) is None: - # If orig_host_address is None it means we haven't - # configured the original values yet - self.connection_kwargs.update( - { - "orig_host_address": self.connection_kwargs.get("host"), - "orig_socket_timeout": self.connection_kwargs.get( - "socket_timeout", None - ), - "orig_socket_connect_timeout": self.connection_kwargs.get( - "socket_connect_timeout", None - ), - } - ) + conn.mark_for_reconnect() def _should_update_connection( self, @@ -2275,7 +2350,7 @@ def disconnect_free_connections( conn.disconnect() -class ConnectionPool(MaintNotificationsConnectionPoolBase, ConnectionPoolInterface): +class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface): """ Create a connection pool. ``If max_connections`` is set, then this object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's @@ -2286,6 +2361,12 @@ class ConnectionPool(MaintNotificationsConnectionPoolBase, ConnectionPoolInterfa unix sockets. :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. + If ``maint_notifications_config`` is provided, the connection pool will support + maintenance notifications. + Maintenance notifications are supported only with RESP3. + If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3, + the maintenance notifications will be enabled by default. + Any additional keyword arguments are passed to the constructor of ``connection_class``. """ @@ -2344,6 +2425,7 @@ def __init__( connection_class=Connection, max_connections: Optional[int] = None, cache_factory: Optional[CacheFactoryInterface] = None, + maint_notifications_config: Optional[MaintNotificationsConfig] = None, **connection_kwargs, ): max_connections = max_connections or 2**31 @@ -2394,7 +2476,11 @@ def __init__( self._fork_lock = threading.RLock() self._lock = threading.RLock() - MaintNotificationsConnectionPoolBase.__init__(self, **connection_kwargs) + MaintNotificationsAbstractConnectionPool.__init__( + self, + maint_notifications_config=maint_notifications_config, + **connection_kwargs, + ) self.reset() @@ -2512,7 +2598,7 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": if ( connection.can_read() and self.cache is None - and not self.maint_notifications_pool_handler_enabled() + and not self.maint_notifications_enabled() ): raise ConnectionError("Connection has data") except (ConnectionError, TimeoutError, OSError): diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index 0188775935..1ada097a99 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -32,9 +32,8 @@ def __str__(self): if TYPE_CHECKING: from redis.connection import ( - BlockingConnectionPool, - ConnectionPool, MaintNotificationsAbstractConnection, + MaintNotificationsAbstractConnectionPool, ) @@ -558,7 +557,7 @@ def get_endpoint_type( class MaintNotificationsPoolHandler: def __init__( self, - pool: Union["ConnectionPool", "BlockingConnectionPool"], + pool: "MaintNotificationsAbstractConnectionPool", config: MaintNotificationsConfig, ) -> None: self.pool = pool @@ -570,6 +569,16 @@ def __init__( def set_connection(self, connection: "MaintNotificationsAbstractConnection"): self.connection = connection + def get_handler_for_connection(self): + # Deep all data that should be shared between connections + # but each connection should have its own pool handler + # since each connection can be in a different state + copy = MaintNotificationsPoolHandler(self.pool, self.config) + copy._processed_notifications = self._processed_notifications + copy._lock = self._lock + copy.connection = None + return copy + def remove_expired_notifications(self): with self._lock: for notification in tuple(self._processed_notifications): diff --git a/tests/test_credentials.py b/tests/test_credentials.py index c5892ca984..29140b9e05 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -431,7 +431,6 @@ def re_auth_callback(token): def test_re_auth_pub_sub_in_resp3(self, credential_provider): mock_pubsub_connection = Mock(spec=ConnectionInterface) mock_pubsub_connection.get_protocol.return_value = 3 - mock_pubsub_connection.should_reconnect = Mock(return_value=False) mock_pubsub_connection.credential_provider = credential_provider mock_pubsub_connection.retry = Retry(NoBackoff(), 3) diff --git a/tests/test_maint_notifications_handling.py b/tests/test_maint_notifications_handling.py index 2935743d7c..15885462a8 100644 --- a/tests/test_maint_notifications_handling.py +++ b/tests/test_maint_notifications_handling.py @@ -398,7 +398,6 @@ def _get_client( enable_cache=False, max_connections=10, maint_notifications_config=None, - setup_pool_handler=False, ): """Helper method to create a pool and Redis client with maintenance notifications configuration. @@ -432,15 +431,6 @@ def _get_client( ) test_redis_client = Redis(connection_pool=test_pool) - # Set up pool handler for moving notifications if requested - if setup_pool_handler: - pool_handler = MaintNotificationsPoolHandler( - test_redis_client.connection_pool, config - ) - test_redis_client.connection_pool.set_maint_notifications_pool_handler( - pool_handler - ) - return test_redis_client @@ -499,6 +489,9 @@ def test_handshake_failure_when_enabled(self): ) try: with pytest.raises(ResponseError): + # handshake should fail + # socket mock will return error when enabling maint notifications + # for internal-ip test_redis_client.set("hello", "world") finally: @@ -515,7 +508,13 @@ def _validate_connection_handlers(self, conn, pool_handler, config): assert parser_handler is not None assert hasattr(parser_handler, "__self__") assert hasattr(parser_handler, "__func__") - assert parser_handler.__self__ is pool_handler + assert parser_handler.__self__.connection is conn + assert parser_handler.__self__.pool is pool_handler.pool + assert parser_handler.__self__._lock is pool_handler._lock + assert ( + parser_handler.__self__._processed_notifications + is pool_handler._processed_notifications + ) assert parser_handler.__func__ is pool_handler.handle_notification.__func__ # Test that the maintenance handler function is correctly set @@ -585,36 +584,12 @@ def test_client_initialization(self): assert pool_handler.config == self.config conn = test_redis_client.connection_pool.get_connection() + assert conn.should_reconnect() is False assert conn.orig_host_address == "localhost" assert conn.orig_socket_timeout is None - # Test that the node moving handler function is correctly set by - # comparing the underlying function and instance - parser_handler = conn._parser.node_moving_push_handler_func - assert parser_handler is not None - assert hasattr(parser_handler, "__self__") - assert hasattr(parser_handler, "__func__") - assert parser_handler.__self__ is pool_handler - assert parser_handler.__func__ is pool_handler.handle_notification.__func__ - - # Test that the maintenance handler function is correctly set - maintenance_handler = conn._parser.maintenance_push_handler_func - assert maintenance_handler is not None - assert hasattr(maintenance_handler, "__self__") - assert hasattr(maintenance_handler, "__func__") - # The maintenance handler should be bound to the connection's - # maintenance notification connection handler - assert ( - maintenance_handler.__self__ is conn._maint_notifications_connection_handler - ) - assert ( - maintenance_handler.__func__ - is conn._maint_notifications_connection_handler.handle_notification.__func__ - ) - - # Validate that the connection's maintenance handler has the same config object - assert conn._maint_notifications_connection_handler.config is self.config + self._validate_connection_handlers(conn, pool_handler, self.config) def test_maint_handler_init_for_existing_connections(self): """Test that maintenance notification handlers are properly set on existing and new connections @@ -639,13 +614,13 @@ def test_maint_handler_init_for_existing_connections(self): enabled_config = MaintNotificationsConfig( enabled=True, proactive_reconnect=True, relaxed_timeout=30 ) - pool_handler = MaintNotificationsPoolHandler( - test_redis_client.connection_pool, enabled_config - ) - test_redis_client.connection_pool.set_maint_notifications_pool_handler( - pool_handler + test_redis_client.connection_pool.update_maint_notifications_config( + enabled_config ) + pool_handler = ( + test_redis_client.connection_pool._maint_notifications_pool_handler + ) # Validate the existing connection after enabling maintenance notifications # Both existing and new connections should now have full handler setup self._validate_connection_handlers(existing_conn, pool_handler, enabled_config) @@ -653,6 +628,7 @@ def test_maint_handler_init_for_existing_connections(self): # Create a new connection and validate it has full handlers new_conn = test_redis_client.connection_pool.get_connection() self._validate_connection_handlers(new_conn, pool_handler, enabled_config) + self._validate_connection_handlers(existing_conn, pool_handler, enabled_config) # Clean up connections test_redis_client.connection_pool.release(existing_conn) @@ -674,11 +650,11 @@ def test_connection_pool_creation_with_maintenance_notifications(self, pool_clas == self.config ) # Pool should have maintenance notifications enabled - assert test_pool.maint_notifications_pool_handler_enabled() is True + assert test_pool.maint_notifications_enabled() is True # Create and set a pool handler - pool_handler = MaintNotificationsPoolHandler(test_pool, self.config) - test_pool.set_maint_notifications_pool_handler(pool_handler) + test_pool.update_maint_notifications_config(self.config) + pool_handler = test_pool._maint_notifications_pool_handler # Validate that the handler is properly set on the pool assert ( @@ -1065,9 +1041,7 @@ def test_moving_related_notifications_handling_integration(self, pool_class): 5. Tests both ConnectionPool and BlockingConnectionPool implementations """ # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1199,9 +1173,7 @@ def test_moving_none_notifications_handling_integration(self, pool_class): 5. Tests both ConnectionPool and BlockingConnectionPool implementations """ # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1277,7 +1249,6 @@ def test_moving_none_notifications_handling_integration(self, pool_class): ) # Wait for half of MOVING timeout to expire and the proactive reconnect to run sleep(MOVING_TIMEOUT / 2 + 0.2) - Helpers.validate_in_use_connections_state( in_use_connections, expected_should_reconnect=True, @@ -1349,9 +1320,7 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): 3. Pool configuration is properly applied to newly created connections """ # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1434,9 +1403,7 @@ def test_create_new_conn_after_moving_expires(self, pool_class): 3. New connections don't inherit temporary settings """ # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1501,9 +1468,7 @@ def test_receive_migrated_after_moving(self, pool_class): it should not decrease timeouts (future refactoring consideration). """ # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1609,9 +1574,7 @@ def test_overlapping_moving_notifications(self, pool_class): Ensures that the second MOVING notification updates the pool and connections as expected, and that expiry/cleanup works. """ global AFTER_MOVING_ADDRESS - test_redis_client = self._get_client( - pool_class, max_connections=5, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=5) try: # Create and release some connections in_use_connections = [] @@ -1762,9 +1725,7 @@ def test_thread_safety_concurrent_notification_handling(self, pool_class): """ import threading - test_redis_client = self._get_client( - pool_class, max_connections=5, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=5) results = [] errors = [] @@ -1823,18 +1784,16 @@ def test_moving_migrating_migrated_moved_state_transitions( test_redis_client = self._get_client( pool_class, max_connections=5, - setup_pool_handler=True, enable_cache=enable_cache, ) pool = test_redis_client.connection_pool - pool_handler = pool.connection_kwargs["maint_notifications_pool_handler"] # Create and release some connections in_use_connections = [] for _ in range(3): in_use_connections.append(pool.get_connection()) - pool_handler.set_connection(in_use_connections[0]) + pool_handler = in_use_connections[0]._maint_notifications_pool_handler while len(in_use_connections) > 0: pool.release(in_use_connections.pop()) @@ -2043,10 +2002,8 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class): protocol=3, # Required for maintenance notifications maint_notifications_config=self.config, ) - pool.set_maint_notifications_pool_handler( - MaintNotificationsPoolHandler(pool, self.config) - ) - pool_handler = pool.connection_kwargs["maint_notifications_pool_handler"] + + pool_handler = pool._maint_notifications_pool_handler # Create and release some connections key1 = "1.2.3.4" From e34064400f1902107d4f5a7e5fb0e9577c1e566f Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 6 Oct 2025 16:25:38 +0300 Subject: [PATCH 4/8] Fixing linter- unused import, and a docstring --- redis/client.py | 1 - redis/maint_notifications.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/redis/client.py b/redis/client.py index 5f42a74c73..51b699721e 100755 --- a/redis/client.py +++ b/redis/client.py @@ -58,7 +58,6 @@ from redis.lock import Lock from redis.maint_notifications import ( MaintNotificationsConfig, - MaintNotificationsPoolHandler, ) from redis.retry import Retry from redis.utils import ( diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index 1ada097a99..5b8b08c1be 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -570,7 +570,7 @@ def set_connection(self, connection: "MaintNotificationsAbstractConnection"): self.connection = connection def get_handler_for_connection(self): - # Deep all data that should be shared between connections + # Copy all data that should be shared between connections # but each connection should have its own pool handler # since each connection can be in a different state copy = MaintNotificationsPoolHandler(self.pool, self.config) From 21afd2c79635d2dbaf1ee564a67cb1ef4e512d0e Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 7 Oct 2025 11:46:34 +0300 Subject: [PATCH 5/8] Fixing check for correct parser when protocol is 3 --- redis/connection.py | 14 +++++++++- tests/test_maint_notifications_handling.py | 32 ++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index db313c58e4..31ce30fd8d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -386,6 +386,13 @@ def _configure_maintenance_notifications( "To configure maintenance notifications, a parser must be provided!" ) + if not isinstance(parser, _HiredisParser) and not isinstance( + parser, _RESP3Parser + ): + raise RedisError( + "Maintenance notifications are only supported with hiredis and RESP3 parsers!" + ) + if maint_notifications_pool_handler: # Extract a reference to a new pool handler that copies all properties # of the original one and has a different connection reference @@ -741,7 +748,12 @@ def __init__( raise ConnectionError("protocol must be either 2 or 3") # p = DEFAULT_RESP_VERSION self.protocol = p - if self.protocol == 3 and parser_class == DefaultParser: + if self.protocol == 3 and parser_class == _RESP2Parser: + # If the protocol is 3 but the parser is RESP2, change it to RESP3 + # This is needed because the parser might be set before the protocol + # or might be provided as a kwarg to the constructor + # We need to react on discrepancy only for RESP2 and RESP3 + # as hiredis supports both parser_class = _RESP3Parser self.set_parser(parser_class) diff --git a/tests/test_maint_notifications_handling.py b/tests/test_maint_notifications_handling.py index 15885462a8..556b63d7e1 100644 --- a/tests/test_maint_notifications_handling.py +++ b/tests/test_maint_notifications_handling.py @@ -304,6 +304,38 @@ def recv(self, bufsize): raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable") + def recv_into(self, buffer, nbytes=0): + """ + Receive data from Redis and write it into the provided buffer. + Returns the number of bytes written. + + This method is used by the hiredis parser for efficient data reading. + """ + if self.closed: + raise ConnectionError("Socket is closed") + + # Use pending responses that were prepared when commands were sent + if self.pending_responses: + response = self.pending_responses.pop(0) + if b"MOVING" in response: + self.moving_sent = True + + # Determine how many bytes to write + if nbytes == 0: + nbytes = len(buffer) + + # Write data into the buffer (up to nbytes or response length) + bytes_to_write = min(len(response), nbytes, len(buffer)) + buffer[:bytes_to_write] = response[:bytes_to_write] + + return bytes_to_write + else: + # No data available - this should block or raise an exception + # For can_read checks, we should indicate no data is available + import errno + + raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable") + def fileno(self): """Return a fake file descriptor for select/poll operations.""" return 1 # Fake file descriptor From 690f5d83ae95f32f769c7ec03d3ff227c97b8fbe Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 8 Oct 2025 14:53:25 +0300 Subject: [PATCH 6/8] Moving _should_reconnect related code back to generic ConnectionInterface as it is also used by AA --- redis/connection.py | 65 +++++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 31ce30fd8d..bc13316109 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -241,6 +241,18 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass + @abstractmethod + def mark_for_reconnect(self): + pass + + @abstractmethod + def should_reconnect(self): + pass + + @abstractmethod + def reset_should_reconnect(self): + pass + class MaintNotificationsAbstractConnection: """ @@ -290,7 +302,6 @@ def __init__( orig_socket_connect_timeout, parser, ) - self._should_reconnect = False @abstractmethod def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]: @@ -590,15 +601,6 @@ def getpeername(self): return conn_socket.getpeername()[0] return None - def mark_for_reconnect(self): - self._should_reconnect = True - - def should_reconnect(self): - return self._should_reconnect - - def reset_should_reconnect(self): - self._should_reconnect = False - def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): conn_socket = self._get_socket() if conn_socket: @@ -758,6 +760,7 @@ def __init__( self.set_parser(parser_class) self._command_packer = self._construct_command_packer(command_packer) + self._should_reconnect = False # Set up maintenance notifications MaintNotificationsAbstractConnection.__init__( @@ -1023,6 +1026,15 @@ def disconnect(self, *args): except OSError: pass + def mark_for_reconnect(self): + self._should_reconnect = True + + def should_reconnect(self): + return self._should_reconnect + + def reset_should_reconnect(self): + self._should_reconnect = False + def _send_ping(self): """Send PING, expect PONG in return""" self.send_command("PING", check_health=False) @@ -1507,6 +1519,15 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): self._conn.re_auth() + def mark_for_reconnect(self): + self._conn.mark_for_reconnect() + + def should_reconnect(self): + return self._conn.should_reconnect() + + def reset_should_reconnect(self): + self._conn.reset_should_reconnect() + @property def host(self) -> str: return self._conn.host @@ -1565,30 +1586,6 @@ def getpeername(self): "Maintenance notifications are not supported by this connection type" ) - def mark_for_reconnect(self): - if isinstance(self._conn, MaintNotificationsAbstractConnection): - self._conn.mark_for_reconnect() - else: - raise NotImplementedError( - "Maintenance notifications are not supported by this connection type" - ) - - def should_reconnect(self): - if isinstance(self._conn, MaintNotificationsAbstractConnection): - return self._conn.should_reconnect() - else: - raise NotImplementedError( - "Maintenance notifications are not supported by this connection type" - ) - - def reset_should_reconnect(self): - if isinstance(self._conn, MaintNotificationsAbstractConnection): - self._conn.reset_should_reconnect() - else: - raise NotImplementedError( - "Maintenance notifications are not supported by this connection type" - ) - def get_resolved_ip(self): if isinstance(self._conn, MaintNotificationsAbstractConnection): return self._conn.get_resolved_ip() From c75397e1f8985236fc4aaf78df7b28fa914eb07f Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 8 Oct 2025 16:29:47 +0300 Subject: [PATCH 7/8] Fixing e2e tests redis client setup --- tests/test_scenario/conftest.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 409f3088ca..a7bdb61b07 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -233,9 +233,5 @@ def _get_client_maint_notifications( ) logging.info("Redis client created with maintenance notifications enabled") logging.info(f"Client uses Protocol: {client.connection_pool.get_protocol()}") - maintenance_handler_exists = client.maint_notifications_pool_handler is not None - logging.info( - f"Maintenance notifications pool handler: {maintenance_handler_exists}" - ) return client From 04ca1d69fee302675dd3e4601e566472c8d5e956 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 8 Oct 2025 18:18:54 +0300 Subject: [PATCH 8/8] Applying review comments --- redis/connection.py | 82 ++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index bc13316109..35e2bdf9ce 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -243,14 +243,24 @@ def re_auth(self): @abstractmethod def mark_for_reconnect(self): + """ + Mark the connection to be reconnected on the next command. + This is useful when a connection is moved to a different node. + """ pass @abstractmethod def should_reconnect(self): + """ + Returns True if the connection should be reconnected. + """ pass @abstractmethod def reset_should_reconnect(self): + """ + Reset the internal flag to False. + """ pass @@ -1560,71 +1570,61 @@ def _get_socket(self) -> Optional[socket.socket]: "Maintenance notifications are not supported by this connection type" ) - @property - def maintenance_state(self) -> MaintenanceState: - if isinstance(self._conn, MaintNotificationsAbstractConnection): - return self._conn.maintenance_state - else: + def _get_maint_notifications_connection_instance( + self, connection + ) -> MaintNotificationsAbstractConnection: + """ + Validate that connection instance supports maintenance notifications. + With this helper method we ensure that we are working + with the correct connection type. + After twe validate that connection instance supports maintenance notifications + we can safely return the connection instance + as MaintNotificationsAbstractConnection. + """ + if not isinstance(connection, MaintNotificationsAbstractConnection): raise NotImplementedError( "Maintenance notifications are not supported by this connection type" ) + else: + return connection + + @property + def maintenance_state(self) -> MaintenanceState: + con = self._get_maint_notifications_connection_instance(self._conn) + return con.maintenance_state @maintenance_state.setter def maintenance_state(self, state: MaintenanceState): - if isinstance(self._conn, MaintNotificationsAbstractConnection): - self._conn.maintenance_state = state - else: - raise NotImplementedError( - "Maintenance notifications are not supported by this connection type" - ) + con = self._get_maint_notifications_connection_instance(self._conn) + con.maintenance_state = state def getpeername(self): - if isinstance(self._conn, MaintNotificationsAbstractConnection): - return self._conn.getpeername() - else: - raise NotImplementedError( - "Maintenance notifications are not supported by this connection type" - ) + con = self._get_maint_notifications_connection_instance(self._conn) + return con.getpeername() def get_resolved_ip(self): - if isinstance(self._conn, MaintNotificationsAbstractConnection): - return self._conn.get_resolved_ip() - else: - raise NotImplementedError( - "Maintenance notifications are not supported by this connection type" - ) + con = self._get_maint_notifications_connection_instance(self._conn) + return con.get_resolved_ip() def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): - if isinstance(self._conn, MaintNotificationsAbstractConnection): - self._conn.update_current_socket_timeout(relaxed_timeout) - else: - raise NotImplementedError( - "Maintenance notifications are not supported by this connection type" - ) + con = self._get_maint_notifications_connection_instance(self._conn) + con.update_current_socket_timeout(relaxed_timeout) def set_tmp_settings( self, tmp_host_address: Optional[str] = None, tmp_relaxed_timeout: Optional[float] = None, ): - if isinstance(self._conn, MaintNotificationsAbstractConnection): - self._conn.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout) - else: - raise NotImplementedError( - "Maintenance notifications are not supported by this connection type" - ) + con = self._get_maint_notifications_connection_instance(self._conn) + con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout) def reset_tmp_settings( self, reset_host_address: bool = False, reset_relaxed_timeout: bool = False, ): - if isinstance(self._conn, MaintNotificationsAbstractConnection): - self._conn.reset_tmp_settings(reset_host_address, reset_relaxed_timeout) - else: - raise NotImplementedError( - "Maintenance notifications are not supported by this connection type" - ) + con = self._get_maint_notifications_connection_instance(self._conn) + con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout) def _connect(self): self._conn._connect()