diff --git a/redis/connection.py b/redis/connection.py index 00d293a238..0b081bd79b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1151,6 +1151,10 @@ def disconnect(self, inuse_connections=True): for connection in connections: connection.disconnect() + def close(self) -> None: + """Close the pool, disconnecting all connections""" + self.disconnect() + def set_retry(self, retry: "Retry") -> None: self.connection_kwargs.update({"retry": retry}) for conn in self._available_connections: diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index ab0fc6be98..ef70a8ff35 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,6 +1,7 @@ import os import re import time +from contextlib import closing from threading import Thread from unittest import mock @@ -51,6 +52,16 @@ def test_connection_creation(self): assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs + def test_closing(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = redis.ConnectionPool( + connection_class=DummyConnection, + max_connections=None, + **connection_kwargs, + ) + with closing(pool): + pass + def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs)