diff --git a/CHANGES b/CHANGES index f0d75a45ce..a5e4f526ef 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Adds capability for cluster mode to await free connection instead of raising. * Move doctests (doc code examples) to main branch * Update `ResponseT` type hint * Allow to control the minimum SSL version diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 40b2948a7f..7af155436f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1,5 +1,4 @@ import asyncio -import collections import random import socket import ssl @@ -7,7 +6,6 @@ from typing import ( Any, Callable, - Deque, Dict, Generator, List, @@ -65,9 +63,9 @@ RedisClusterException, ResponseError, SlotNotCoveredError, - TimeoutError, - TryAgainError, ) +from redis.exceptions import TimeoutError as RedisTimeoutError +from redis.exceptions import TryAgainError from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import ( deprecated_function, @@ -264,6 +262,7 @@ def __init__( socket_timeout: Optional[float] = None, retry: Optional["Retry"] = None, retry_on_error: Optional[List[Type[Exception]]] = None, + wait_for_connections: bool = False, # SSL related kwargs ssl: bool = False, ssl_ca_certs: Optional[str] = None, @@ -326,6 +325,7 @@ def __init__( "socket_timeout": socket_timeout, "retry": retry, "protocol": protocol, + "wait_for_connections": wait_for_connections, # Client cache related kwargs "cache_enabled": cache_enabled, "client_cache": client_cache, @@ -364,7 +364,7 @@ def __init__( ) if not retry_on_error: # Default errors for retrying - retry_on_error = [ConnectionError, TimeoutError] + retry_on_error = [ConnectionError, RedisTimeoutError] self.retry.update_supported_errors(retry_on_error) kwargs.update({"retry": self.retry}) @@ -800,7 +800,7 @@ async def _execute_command( return await target_node.execute_command(*args, **kwargs) except (BusyLoadingError, MaxConnectionsError): raise - except (ConnectionError, TimeoutError): + except (ConnectionError, RedisTimeoutError): # Connection retries are being handled in the node's # Retry object. # Remove the failed node from the startup nodes before we try @@ -962,6 +962,7 @@ class ClusterNode: __slots__ = ( "_connections", "_free", + "acquire_connection_timeout", "connection_class", "connection_kwargs", "host", @@ -970,6 +971,7 @@ class ClusterNode: "port", "response_callbacks", "server_type", + "wait_for_connections", ) def __init__( @@ -980,6 +982,7 @@ def __init__( *, max_connections: int = 2**31, connection_class: Type[Connection] = Connection, + wait_for_connections: bool = False, **connection_kwargs: Any, ) -> None: if host == "localhost": @@ -996,9 +999,11 @@ def __init__( self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.response_callbacks = connection_kwargs.pop("response_callbacks", {}) + self.acquire_connection_timeout = connection_kwargs.get("socket_timeout", 30) self._connections: List[Connection] = [] - self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) + self._free: asyncio.Queue[Connection] = asyncio.Queue() + self.wait_for_connections = wait_for_connections def __repr__(self) -> str: return ( @@ -1039,14 +1044,24 @@ async def disconnect(self) -> None: if exc: raise exc - def acquire_connection(self) -> Connection: + async def acquire_connection(self) -> Connection: try: - return self._free.popleft() - except IndexError: + return self._free.get_nowait() + except asyncio.QueueEmpty: if len(self._connections) < self.max_connections: connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) return connection + elif self.wait_for_connections: + try: + connection = await asyncio.wait_for( + self._free.get(), self.acquire_connection_timeout + ) + return connection + except TimeoutError: + raise RedisTimeoutError( + "Timeout reached waiting for a free connection" + ) raise MaxConnectionsError() @@ -1075,12 +1090,12 @@ async def parse_response( async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection - connection = self.acquire_connection() + connection = await self.acquire_connection() keys = kwargs.pop("keys", None) response_from_cache = await connection._get_from_local_cache(args) if response_from_cache is not None: - self._free.append(connection) + await self._free.put(connection) return response_from_cache else: # Execute command @@ -1094,11 +1109,11 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any: return response finally: # Release connection - self._free.append(connection) + await self._free.put(connection) async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection - connection = self.acquire_connection() + connection = await self.acquire_connection() # Execute command await connection.send_packed_command( @@ -1117,7 +1132,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: ret = True # Release connection - self._free.append(connection) + await self._free.put(connection) return ret diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index c16272bb5b..3889e60542 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -464,6 +464,28 @@ async def read_response_mocked(*args: Any, **kwargs: Any) -> None: await rc.aclose() + async def test_max_connections_waited( + self, create_redis: Callable[..., RedisCluster] + ) -> None: + rc = await create_redis( + cls=RedisCluster, max_connections=10, wait_for_connections=True + ) + for node in rc.get_nodes(): + assert node.max_connections == 10 + + with mock.patch.object(Connection, "read_response") as read_response: + + async def read_response_mocked(*args: Any, **kwargs: Any) -> None: + await asyncio.sleep(1) + + read_response.side_effect = read_response_mocked + + await asyncio.gather( + *(rc.ping(target_nodes=RedisCluster.DEFAULT_NODE) for _ in range(20)) + ) + assert len(rc.get_default_node()._connections) == 10 + await rc.aclose() + async def test_execute_command_errors(self, r: RedisCluster) -> None: """ Test that if no key is provided then exception should be raised.