diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 3fe3ebc47e..df0c17d49c 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -17,7 +17,13 @@ ) from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url +from redis.asyncio.connection import ( + Connection, + DefaultParser, + Encoder, + SSLConnection, + parse_url, +) from redis.asyncio.parser import CommandsParser from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( @@ -42,6 +48,7 @@ ConnectionError, DataError, MasterDownError, + MaxConnectionsError, MovedError, RedisClusterException, ResponseError, @@ -56,44 +63,17 @@ "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] ) -CONNECTION_ALLOWED_KEYS = ( - "client_name", - "db", - "decode_responses", - "encoder_class", - "encoding", - "encoding_errors", - "health_check_interval", - "parser_class", - "password", - "redis_connect_func", - "retry", - "retry_on_timeout", - "socket_connect_timeout", - "socket_keepalive", - "socket_keepalive_options", - "socket_read_size", - "socket_timeout", - "socket_type", - "username", -) - - -def cleanup_kwargs(**kwargs: Any) -> Dict[str, Any]: - """Remove unsupported or disabled keys from kwargs.""" - return {k: v for k, v in kwargs.items() if k in CONNECTION_ALLOWED_KEYS} - class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( DefaultParser.EXCEPTION_CLASSES, { "ASK": AskError, - "TRYAGAIN": TryAgainError, - "MOVED": MovedError, "CLUSTERDOWN": ClusterDownError, "CROSSSLOT": ClusterCrossSlotError, "MASTERDOWN": MasterDownError, + "MOVED": MovedError, + "TRYAGAIN": TryAgainError, }, ) @@ -104,7 +84,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand Pass one of parameters: - - `url` - `host` & `port` - `startup_nodes` @@ -128,9 +107,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand | Port used if **host** is provided :param startup_nodes: | :class:`~.ClusterNode` to used as a startup node - :param cluster_error_retry_attempts: - | Retry command execution attempts when encountering :class:`~.ClusterDownError` - or :class:`~.ConnectionError` :param require_full_coverage: | When set to ``False``: the client will not require a full coverage of the slots. However, if not all slots are covered, and at least one node has @@ -141,6 +117,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand thrown. | See: https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters + :param read_from_replicas: + | Enable read from replicas in READONLY mode. You can read possibly stale data. + When set to true, read commands will be assigned between the primary and + its replications in a Round-Robin manner. :param reinitialize_steps: | Specifies the number of MOVED errors that need to occur before reinitializing the whole cluster topology. If a MOVED error occurs and the cluster does not @@ -149,23 +129,27 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 0. - :param read_from_replicas: - | Enable read from replicas in READONLY mode. You can read possibly stale data. 
- When set to true, read commands will be assigned between the primary and - its replications in a Round-Robin manner. - :param url: - | See :meth:`.from_url` - :param kwargs: - | Extra arguments that will be passed to the - :class:`~redis.asyncio.connection.Connection` instances when created + :param cluster_error_retry_attempts: + | Number of times to retry before raising an error when :class:`~.TimeoutError` + or :class:`~.ConnectionError` or :class:`~.ClusterDownError` are encountered + :param connection_error_retry_attempts: + | Number of times to retry before reinitializing when :class:`~.TimeoutError` + or :class:`~.ConnectionError` are encountered + :param max_connections: + | Maximum number of connections per node. If there are no free connections & the + maximum number of connections are already created, a + :class:`~.MaxConnectionsError` is raised. This error may be retried as defined + by :attr:`connection_error_retry_attempts` + + | Rest of the arguments will be passed to the + :class:`~redis.asyncio.connection.Connection` instances when created :raises RedisClusterException: - if any arguments are invalid. Eg: + if any arguments are invalid or unknown. Eg: - - db kwarg - - db != 0 in url - - unix socket connection - - none of host & url & startup_nodes were provided + - `db` != 0 or None + - `path` argument for unix socket connection + - none of the `host`/`port` & `startup_nodes` were provided """ @@ -178,7 +162,6 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": redis://[[username]:[password]]@localhost:6379/0 rediss://[[username]:[password]]@localhost:6379/0 - unix://[[username]:[password]]@/path/to/socket.sock?db=0 Three URL schemes are supported: @@ -186,32 +169,22 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": - `rediss://` creates a SSL wrapped TCP socket connection. See more at: - - ``unix://``: creates a Unix Domain Socket connection. - - The username, password, hostname, path and all querystring values - are passed through urllib.parse.unquote in order to replace any - percent-encoded values with their corresponding characters. - There are several ways to specify a database number. The first value - found will be used: - - 1. A ``db`` querystring option, e.g. redis://localhost?db=0 - 2. If using the redis:// or rediss:// schemes, the path argument - of the url, e.g. redis://localhost/0 - 3. A ``db`` keyword argument to this function. - - If none of these options are specified, the default db=0 is used. - - All querystring options are cast to their appropriate Python types. - Boolean arguments can be specified with string values "True"/"False" - or "Yes"/"No". Values that cannot be properly cast cause a - ``ValueError`` to be raised. Once parsed, the querystring arguments and - keyword arguments are passed to :class:`~redis.asyncio.connection.Connection` - when created. In the case of conflicting arguments, querystring - arguments always win. + The username, password, hostname, path and all querystring values are passed + through ``urllib.parse.unquote`` in order to replace any percent-encoded values + with their corresponding characters. + All querystring options are cast to their appropriate Python types. Boolean + arguments can be specified with string values "True"/"False" or "Yes"/"No". + Values that cannot be properly cast cause a ``ValueError`` to be raised. Once + parsed, the querystring arguments and keyword arguments are passed to + :class:`~redis.asyncio.connection.Connection` when created. 
+ In the case of conflicting arguments, querystring arguments are used. """ - return cls(url=url, **kwargs) + kwargs.update(parse_url(url)) + if kwargs.pop("connection_class", None) is SSLConnection: + kwargs["ssl"] = True + return cls(**kwargs) __slots__ = ( "_initialize", @@ -219,6 +192,7 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": "cluster_error_retry_attempts", "command_flags", "commands_parser", + "connection_error_retry_attempts", "connection_kwargs", "encoder", "node_flags", @@ -233,87 +207,131 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": def __init__( self, host: Optional[str] = None, - port: int = 6379, + port: Union[str, int] = 6379, + # Cluster related kwargs startup_nodes: Optional[List["ClusterNode"]] = None, - require_full_coverage: bool = False, + require_full_coverage: bool = True, read_from_replicas: bool = False, - cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 10, - url: Optional[str] = None, - **kwargs: Any, + cluster_error_retry_attempts: int = 3, + connection_error_retry_attempts: int = 5, + max_connections: int = 2**31, + # Client related kwargs + db: Union[str, int] = 0, + path: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + client_name: Optional[str] = None, + # Encoding related kwargs + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + # Connection related kwargs + health_check_interval: float = 0, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_timeout: Optional[float] = None, + # SSL related kwargs + ssl: bool = False, + ssl_ca_certs: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_certfile: Optional[str] = None, + ssl_check_hostname: bool = False, + ssl_keyfile: Optional[str] = None, ) -> None: - if not startup_nodes: - startup_nodes = [] + if db: + raise RedisClusterException( + "Argument 'db' must be 0 or None in cluster mode" + ) - if "db" in kwargs: - # Argument 'db' is not possible to use in cluster mode + if path: raise RedisClusterException( - "Argument 'db' is not possible to use in cluster mode" + "Unix domain socket is not supported in cluster mode" ) - # Get the startup node(s) - if url: - url_options = parse_url(url) - if "path" in url_options: - raise RedisClusterException( - "RedisCluster does not currently support Unix Domain " - "Socket connections" - ) - if "db" in url_options and url_options["db"] != 0: - # Argument 'db' is not possible to use in cluster mode - raise RedisClusterException( - "A ``db`` querystring option can only be 0 in cluster mode" - ) - kwargs.update(url_options) - host = kwargs.get("host") - port = kwargs.get("port", port) - elif (not host or not port) and not startup_nodes: - # No startup node was provided + if (not host or not port) and not startup_nodes: raise RedisClusterException( - "RedisCluster requires at least one node to discover the " - "cluster. Please provide one of the followings:\n" - "1. host and port, for example:\n" - " RedisCluster(host='localhost', port=6379)\n" - "2. 
list of startup nodes, for example:\n" - " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," - " ClusterNode('localhost', 6378)])" + "RedisCluster requires at least one node to discover the cluster.\n" + "Please provide one of the following or use RedisCluster.from_url:\n" + ' - host and port: RedisCluster(host="localhost", port=6379)\n' + " - startup_nodes: RedisCluster(startup_nodes=[" + 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])' + ) + + kwargs: Dict[str, Any] = { + "max_connections": max_connections, + "connection_class": Connection, + "parser_class": ClusterParser, + # Client related kwargs + "username": username, + "password": password, + "client_name": client_name, + # Encoding related kwargs + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + # Connection related kwargs + "health_check_interval": health_check_interval, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + "socket_timeout": socket_timeout, + } + + if ssl: + # SSL related kwargs + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_ca_certs": ssl_ca_certs, + "ssl_ca_data": ssl_ca_data, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_certfile": ssl_certfile, + "ssl_check_hostname": ssl_check_hostname, + "ssl_keyfile": ssl_keyfile, + } ) - # Update the connection arguments - # Whenever a new connection is established, RedisCluster's on_connect - # method should be run - kwargs["redis_connect_func"] = self.on_connect - self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs) - self.response_callbacks = kwargs[ - "response_callbacks" - ] = self.__class__.RESPONSE_CALLBACKS.copy() + if read_from_replicas: + # Call our on_connect function to configure READONLY mode + kwargs["redis_connect_func"] = self.on_connect + + kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy() + self.connection_kwargs = kwargs + + if startup_nodes: + passed_nodes = [] + for node in startup_nodes: + passed_nodes.append( + ClusterNode(node.host, node.port, **self.connection_kwargs) + ) + startup_nodes = passed_nodes + else: + startup_nodes = [] if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) - self.nodes_manager = NodesManager( - startup_nodes=startup_nodes, - require_full_coverage=require_full_coverage, - **self.connection_kwargs, - ) - self.encoder = Encoder( - kwargs.get("encoding", "utf-8"), - kwargs.get("encoding_errors", "strict"), - kwargs.get("decode_responses", False), - ) - self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs) + self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas self.reinitialize_steps = reinitialize_steps + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.connection_error_retry_attempts = connection_error_retry_attempts self.reinitialize_counter = 0 self.commands_parser = CommandsParser() self.node_flags = self.__class__.NODE_FLAGS.copy() self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.response_callbacks = kwargs["response_callbacks"] self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy() self.result_callbacks[ "CLUSTER SLOTS" ] = lambda cmd, res, **kwargs: parse_cluster_slots( list(res.values())[0], **kwargs ) + self._initialize = True self._lock = asyncio.Lock() @@ -365,18 
+383,16 @@ def __del__(self) -> None: ... async def on_connect(self, connection: Connection) -> None: - connection.set_parser(ClusterParser) await connection.on_connect() - if self.read_from_replicas: - # Sending READONLY command to server to configure connection as - # readonly. Since each cluster node may change its server type due - # to a failover, we should establish a READONLY connection - # regardless of the server type. If this is a primary connection, - # READONLY would not affect executing write commands. - await connection.send_command("READONLY") - if str_if_bytes(await connection.read_response_without_lock()) != "OK": - raise ConnectionError("READONLY command failed") + # Sending READONLY command to server to configure connection as + # readonly. Since each cluster node may change its server type due + # to a failover, we should establish a READONLY connection + # regardless of the server type. If this is a primary connection, + # READONLY would not affect executing write commands. + await connection.send_command("READONLY") + if str_if_bytes(await connection.read_response_without_lock()) != "OK": + raise ConnectionError("READONLY command failed") def get_nodes(self) -> List["ClusterNode"]: """Get all nodes of the cluster.""" @@ -436,12 +452,12 @@ def get_node_from_key( slot_cache = self.nodes_manager.slots_cache.get(slot) if not slot_cache: raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') - if replica and len(self.nodes_manager.slots_cache[slot]) < 2: - return None - elif replica: + + if replica: + if len(self.nodes_manager.slots_cache[slot]) < 2: + return None node_idx = 1 else: - # primary node_idx = 0 return slot_cache[node_idx] @@ -638,14 +654,14 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: command, dict(zip(keys, values)), **kwargs ) return dict(zip(keys, values)) - except BaseException as e: + except Exception as e: if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. # Try again with the new cluster setup. exception = e else: # All other errors should be raised. - raise e + raise # If it fails the configured number of times then raise exception back # to caller of this method @@ -678,19 +694,30 @@ async def _execute_command( return await target_node.execute_command(*args, **kwargs) except BusyLoadingError: raise - except (ConnectionError, TimeoutError): - # Give the node 0.25 seconds to get back up and retry again - # with same node and configuration. After 5 attempts then try - # to reinitialize the cluster and see if the nodes - # configuration has changed or not + except (ConnectionError, TimeoutError) as e: + # Give the node 0.25 seconds to get back up and retry again with the + # same node and configuration. After the defined number of attempts, try + # to reinitialize the cluster and try again. 
connection_error_retry_counter += 1 - if connection_error_retry_counter < 5: + if ( + connection_error_retry_counter + < self.connection_error_retry_attempts + ): await asyncio.sleep(0.25) else: + if isinstance(e, MaxConnectionsError): + raise # Hard force of reinitialize of the node/slots setup # and try again with the new setup await self.close() raise + except ClusterDownError: + # ClusterDownError can occur during a failover and to get + # self-healed, we will try to reinitialize the cluster layout + # and retry executing the command + await self.close() + await asyncio.sleep(0.25) + raise except MovedError as e: # First, we will try to patch the slots/nodes cache with the # redirected node output and try again. If MovedError exceeds @@ -711,19 +738,12 @@ async def _execute_command( else: self.nodes_manager._moved_exception = e moved = True - except TryAgainError: - if ttl < self.RedisClusterRequestTTL / 2: - await asyncio.sleep(0.05) except AskError as e: redirect_addr = get_node_name(host=e.host, port=e.port) asking = True - except ClusterDownError: - # ClusterDownError can occur during a failover and to get - # self-healed, we will try to reinitialize the cluster layout - # and retry executing the command - await asyncio.sleep(0.25) - await self.close() - raise + except TryAgainError: + if ttl < self.RedisClusterRequestTTL / 2: + await asyncio.sleep(0.05) raise ClusterError("TTL exhausted.") @@ -770,8 +790,9 @@ class ClusterNode: def __init__( self, host: str, - port: int, + port: Union[str, int], server_type: Optional[str] = None, + *, max_connections: int = 2**31, connection_class: Type[Connection] = Connection, **connection_kwargs: Any, @@ -789,9 +810,7 @@ def __init__( self.max_connections = max_connections self.connection_class = connection_class self.connection_kwargs = connection_kwargs - self.response_callbacks = connection_kwargs.pop( - "response_callbacks", RedisCluster.RESPONSE_CALLBACKS - ) + self.response_callbacks = connection_kwargs.pop("response_callbacks", {}) self._connections: List[Connection] = [] self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) @@ -834,21 +853,15 @@ async def disconnect(self) -> None: raise exc def acquire_connection(self) -> Connection: - if self._free: - for _ in range(len(self._free)): - connection = self._free.popleft() - if connection.is_connected: - return connection - self._free.append(connection) - + try: return self._free.popleft() + except IndexError: + if len(self._connections) < self.max_connections: + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection - if len(self._connections) < self.max_connections: - connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) - return connection - - raise ConnectionError("Too many connections") + raise MaxConnectionsError() async def parse_response( self, connection: Connection, command: str, **kwargs: Any @@ -926,12 +939,12 @@ class NodesManager: def __init__( self, startup_nodes: List["ClusterNode"], - require_full_coverage: bool = False, - **kwargs: Any, + require_full_coverage: bool, + connection_kwargs: Dict[str, Any], ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage - self.connection_kwargs = kwargs + self.connection_kwargs = connection_kwargs self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} @@ -1050,6 +1063,7 @@ async def 
initialize(self) -> None: disagreements = [] startup_nodes_reachable = False fully_covered = False + exception = None for startup_node in self.startup_nodes.values(): try: # Make sure cluster mode is enabled on this node @@ -1061,7 +1075,8 @@ async def initialize(self) -> None: ) cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") startup_nodes_reachable = True - except (ConnectionError, TimeoutError): + except (ConnectionError, TimeoutError) as e: + exception = e continue except ResponseError as e: # Isn't a cluster connection, so it won't parse these @@ -1162,7 +1177,7 @@ async def initialize(self) -> None: raise RedisClusterException( "Redis Cluster cannot be connected. Please provide at least " "one reachable node. " - ) + ) from exception # Check if the slots are not fully covered if not fully_covered and self.require_full_coverage: @@ -1327,7 +1342,7 @@ async def execute( await asyncio.sleep(0.25) else: # All other errors should be raised. - raise e + raise # If it fails the configured number of times then raise an exception raise exception diff --git a/redis/cluster.py b/redis/cluster.py index b2d4f3b044..b05cf307db 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -5,7 +5,7 @@ import threading import time from collections import OrderedDict -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict, Tuple, Union from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands @@ -38,7 +38,7 @@ ) -def get_node_name(host: str, port: int) -> str: +def get_node_name(host: str, port: Union[str, int]) -> str: return f"{host}:{port}" diff --git a/redis/exceptions.py b/redis/exceptions.py index d18b354454..8a8bf423eb 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -199,3 +199,7 @@ class SlotNotCoveredError(RedisClusterException): """ pass + + +class MaxConnectionsError(ConnectionError): + ... 
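The hunks above replace the old "Too many connections" ConnectionError in ClusterNode.acquire_connection with the dedicated MaxConnectionsError added to redis/exceptions.py: free connections are reused first, new ones are created lazily until max_connections is reached, and only then does the error surface so _execute_command can decide how to handle it. A minimal sketch of that behaviour (not taken from the patch; the address is a placeholder and nothing is ever released back to the free list):

    from redis.asyncio.cluster import ClusterNode
    from redis.exceptions import MaxConnectionsError

    # Connections are created lazily, so acquire_connection() never touches the
    # network here; any host/port works for illustrating the pool accounting.
    node = ClusterNode("127.0.0.1", 7000, max_connections=2)

    first = node.acquire_connection()   # free list empty -> create connection #1
    second = node.acquire_connection()  # create connection #2 (the limit)
    try:
        node.acquire_connection()       # limit reached and nothing was released
    except MaxConnectionsError:
        # per the _execute_command change, this is retried up to
        # connection_error_retry_attempts times and then re-raised without
        # reinitializing the cluster topology
        pass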
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 8766cbf09b..1365e4daff 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1,9 +1,11 @@ import asyncio import binascii import datetime +import os import sys import warnings -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union +from urllib.parse import urlparse import pytest @@ -16,10 +18,10 @@ else: import pytest_asyncio -from _pytest.fixtures import FixtureRequest, SubRequest +from _pytest.fixtures import FixtureRequest -from redis.asyncio import Connection, RedisCluster -from redis.asyncio.cluster import ClusterNode, NodesManager +from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster +from redis.asyncio.connection import Connection, SSLConnection from redis.asyncio.parser import CommandsParser from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot @@ -28,6 +30,7 @@ ClusterDownError, ConnectionError, DataError, + MaxConnectionsError, MovedError, NoPermissionError, RedisClusterException, @@ -41,6 +44,9 @@ skip_unless_arch_bits, ) +pytestmark = pytest.mark.onlycluster + + default_host = "127.0.0.1" default_port = 7000 default_cluster_slots = [ @@ -50,7 +56,7 @@ @pytest_asyncio.fixture() -async def slowlog(request: SubRequest, r: RedisCluster) -> None: +async def slowlog(r: RedisCluster) -> None: """ Set the slowlog threshold to 0, and the max length to 128. This will force every @@ -146,7 +152,7 @@ def mock_all_nodes_resp(rc: RedisCluster, response: Any) -> RedisCluster: async def moved_redirection_helper( - request: FixtureRequest, create_redis: Callable, failover: bool = False + create_redis: Callable[..., RedisCluster], failover: bool = False ) -> None: """ Test that the client handles MOVED response after a failover. 
@@ -202,7 +208,6 @@ def ok_response(self, *args, **options): assert prev_primary.server_type == REPLICA -@pytest.mark.onlycluster class TestRedisClusterObj: """ Tests for the RedisCluster class @@ -237,10 +242,18 @@ async def test_startup_nodes(self) -> None: await cluster.close() - startup_nodes = [ClusterNode("127.0.0.1", 16379)] - async with RedisCluster(startup_nodes=startup_nodes) as rc: + startup_node = ClusterNode("127.0.0.1", 16379) + async with RedisCluster(startup_nodes=[startup_node], client_name="test") as rc: assert await rc.set("A", 1) assert await rc.get("A") == b"1" + assert all( + [ + name == "test" + for name in ( + await rc.client_getname(target_nodes=rc.ALL_NODES) + ).values() + ] + ) async def test_empty_startup_nodes(self) -> None: """ @@ -253,18 +266,43 @@ async def test_empty_startup_nodes(self) -> None: "RedisCluster requires at least one node to discover the " "cluster" ), str_if_bytes(ex.value) - async def test_from_url(self, r: RedisCluster) -> None: - redis_url = f"redis://{default_host}:{default_port}/0" - with mock.patch.object(RedisCluster, "from_url") as from_url: + async def test_from_url(self, request: FixtureRequest) -> None: + url = request.config.getoption("--redis-url") - async def from_url_mocked(_url, **_kwargs): - return await get_mocked_redis_client(url=_url, **_kwargs) + async with RedisCluster.from_url(url) as rc: + await rc.set("a", 1) + await rc.get("a") == 1 - from_url.side_effect = from_url_mocked - cluster = await RedisCluster.from_url(redis_url) - assert cluster.get_node(host=default_host, port=default_port) is not None + rc = RedisCluster.from_url("rediss://localhost:16379") + assert rc.connection_kwargs["connection_class"] is SSLConnection - await cluster.close() + async def test_max_connections( + self, create_redis: Callable[..., RedisCluster] + ) -> None: + rc = await create_redis(cls=RedisCluster, max_connections=10) + for node in rc.get_nodes(): + assert node.max_connections == 10 + + with mock.patch.object( + Connection, "read_response_without_lock" + ) as read_response_without_lock: + + async def read_response_without_lock_mocked( + *args: Any, **kwargs: Any + ) -> None: + await asyncio.sleep(10) + + read_response_without_lock.side_effect = read_response_without_lock_mocked + + with pytest.raises(MaxConnectionsError): + await asyncio.gather( + *( + rc.ping(target_nodes=RedisCluster.DEFAULT_NODE) + for _ in range(11) + ) + ) + + await rc.close() async def test_execute_command_errors(self, r: RedisCluster) -> None: """ @@ -373,23 +411,23 @@ def ok_response(self, *args, **options): assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK" async def test_moved_redirection( - self, request: FixtureRequest, create_redis: Callable + self, create_redis: Callable[..., RedisCluster] ) -> None: """ Test that the client handles MOVED response. """ - await moved_redirection_helper(request, create_redis, failover=False) + await moved_redirection_helper(create_redis, failover=False) async def test_moved_redirection_after_failover( - self, request: FixtureRequest, create_redis: Callable + self, create_redis: Callable[..., RedisCluster] ) -> None: """ Test that the client handles MOVED response after a failover. 
""" - await moved_redirection_helper(request, create_redis, failover=True) + await moved_redirection_helper(create_redis, failover=True) async def test_refresh_using_specific_nodes( - self, request: FixtureRequest, create_redis: Callable + self, create_redis: Callable[..., RedisCluster] ) -> None: """ Test making calls on specific nodes when the cluster has failed over to @@ -691,7 +729,6 @@ async def test_can_run_concurrent_commands(self, request: FixtureRequest) -> Non await rc.close() -@pytest.mark.onlycluster class TestClusterRedisCommands: """ Tests for RedisCluster unique commands @@ -1918,7 +1955,7 @@ async def test_cluster_randomkey(self, r: RedisCluster) -> None: @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() async def test_acl_log( - self, r: RedisCluster, request: FixtureRequest, create_redis: Callable + self, r: RedisCluster, create_redis: Callable[..., RedisCluster] ) -> None: key = "{cache}:" node = r.get_node_from_key(key) @@ -1963,7 +2000,6 @@ async def test_acl_log( await user_client.close() -@pytest.mark.onlycluster class TestNodesManager: """ Tests for the NodesManager class @@ -2095,7 +2131,7 @@ async def test_empty_startup_nodes(self) -> None: specified """ with pytest.raises(RedisClusterException): - await NodesManager([]).initialize() + await NodesManager([], False, {}).initialize() async def test_wrong_startup_nodes_type(self) -> None: """ @@ -2103,11 +2139,9 @@ async def test_wrong_startup_nodes_type(self) -> None: fail """ with pytest.raises(RedisClusterException): - await NodesManager({}).initialize() + await NodesManager({}, False, {}).initialize() - async def test_init_slots_cache_slots_collision( - self, request: FixtureRequest - ) -> None: + async def test_init_slots_cache_slots_collision(self) -> None: """ Test that if 2 nodes do not agree on the same slots setup it should raise an error. In this test both nodes will say that the first @@ -2236,7 +2270,6 @@ def cmd_init_mock(self, r: ClusterNode) -> None: assert rc.get_node(host=default_host, port=7002) is not None -@pytest.mark.onlycluster class TestClusterPipeline: """Tests for the ClusterPipeline class.""" @@ -2484,3 +2517,116 @@ async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None: *(self.test_multi_key_operation_with_a_single_slot(r) for i in range(100)), *(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)), ) + + +@pytest.mark.ssl +class TestSSL: + """ + Tests for SSL connections. + + This relies on the --redis-ssl-url for building the client and connecting to the + appropriate port. + """ + + ROOT = os.path.join(os.path.dirname(__file__), "../..") + CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys")) + if not os.path.isdir(CERT_DIR): # github actions package validation case + CERT_DIR = os.path.abspath( + os.path.join(ROOT, "..", "docker", "stunnel", "keys") + ) + if not os.path.isdir(CERT_DIR): + raise IOError(f"No SSL certificates found. 
They should be in {CERT_DIR}") + + SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") + SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + + @pytest_asyncio.fixture() + def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]: + ssl_url = request.config.option.redis_ssl_url + ssl_host, ssl_port = urlparse(ssl_url)[1].split(":") + + async def _create_client(mocked: bool = True, **kwargs: Any) -> RedisCluster: + if mocked: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command_mock: + + async def execute_command(self, *args, **kwargs): + if args[0] == "INFO": + return {"cluster_enabled": True} + if args[0] == "CLUSTER SLOTS": + return [[0, 16383, [ssl_host, ssl_port, "ssl_node"]]] + if args[0] == "COMMAND": + return { + "ping": { + "name": "ping", + "arity": -1, + "flags": ["stale", "fast"], + "first_key_pos": 0, + "last_key_pos": 0, + "step_count": 0, + } + } + raise NotImplementedError() + + execute_command_mock.side_effect = execute_command + + rc = await RedisCluster(host=ssl_host, port=ssl_port, **kwargs) + + assert len(rc.get_nodes()) == 1 + node = rc.get_default_node() + assert node.port == int(ssl_port) + return rc + + return await RedisCluster(host=ssl_host, port=ssl_port, **kwargs) + + return _create_client + + async def test_ssl_connection_without_ssl( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + with pytest.raises(RedisClusterException) as e: + await create_client(mocked=False, ssl=False) + e = e.value.__cause__ + assert "Connection closed by server" in str(e) + + async def test_ssl_with_invalid_cert( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + with pytest.raises(RedisClusterException) as e: + await create_client(mocked=False, ssl=True) + e = e.value.__cause__.__context__ + assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e) + + async def test_ssl_connection( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + async with await create_client(ssl=True, ssl_cert_reqs="none") as rc: + assert await rc.ping() + + async def test_validating_self_signed_certificate( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + async with await create_client( + ssl=True, + ssl_ca_certs=self.SERVER_CERT, + ssl_cert_reqs="required", + ssl_certfile=self.SERVER_CERT, + ssl_keyfile=self.SERVER_KEY, + ) as rc: + assert await rc.ping() + + async def test_validating_self_signed_string_certificate( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + with open(self.SERVER_CERT) as f: + cert_data = f.read() + + async with await create_client( + ssl=True, + ssl_ca_data=cert_data, + ssl_cert_reqs="required", + ssl_certfile=self.SERVER_CERT, + ssl_keyfile=self.SERVER_KEY, + ) as rc: + assert await rc.ping()
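End to end, the user-facing change in this patch is that RedisCluster now takes its connection options as explicit keyword arguments (the CONNECTION_ALLOWED_KEYS filtering and the url= constructor argument are gone), and from_url() maps a rediss:// scheme onto ssl=True via SSLConnection. A usage sketch, assuming a cluster is reachable at the placeholder address localhost:16379 behind a self-signed certificate:

    import asyncio

    from redis.asyncio.cluster import RedisCluster

    async def main() -> None:
        # Explicit keyword arguments; unknown names are now rejected outright
        # instead of being silently dropped by cleanup_kwargs().
        async with RedisCluster(
            host="localhost",
            port=16379,
            max_connections=32,
            ssl=True,
            ssl_cert_reqs="none",  # self-signed certificate in this sketch
        ) as rc:
            await rc.set("foo", "bar")
            assert await rc.get("foo") == b"bar"

        # The same through a URL: rediss:// parses to SSLConnection, which
        # from_url() translates into ssl=True before calling the constructor.
        async with RedisCluster.from_url(
            "rediss://localhost:16379", ssl_cert_reqs="none"
        ) as rc:
            await rc.ping()

    asyncio.run(main())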