diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index c1cc1d310c..537ab2101b 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1187,6 +1187,10 @@ async def disconnect(self, inuse_connections: bool = True): if exc: raise exc + async def aclose(self) -> None: + """Close the pool, disconnecting all connections""" + await self.disconnect() + def set_retry(self, retry: "Retry") -> None: for conn in self._available_connections: conn.retry = retry diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index bd44504cc4..8edc1cb016 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -8,7 +8,7 @@ from redis.asyncio.connection import Connection, to_bool from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt -from .compat import mock +from .compat import aclosing, mock from .conftest import asynccontextmanager from .test_pubsub import wait_for_message @@ -134,6 +134,16 @@ async def test_connection_creation(self): assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs + async def test_aclosing(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = redis.ConnectionPool( + connection_class=DummyConnection, + max_connections=None, + **connection_kwargs, + ) + async with aclosing(pool): + pass + async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: