Skip to content

Commit

Permalink
Allow passing in an explicit connection pool for redis (#110)
Browse files Browse the repository at this point in the history
* Allow explicit connection pool for redis
  • Loading branch information
alisaifee authored Mar 11, 2022
1 parent b22ccbb commit 41ba006
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 6 deletions.
6 changes: 6 additions & 0 deletions doc/source/storage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ or :code:`redis+unix:///path/to/socket?db=n` (for database `n`).
If the database is password protected the password can be provided in the url, for example
:code:`redis://:foobared@localhost:6379` or :code:`redis+unix//:foobered/path/to/socket` if using a UDS..

For scenarios where a redis connection pool is already available and can be reused, it can be provided
in :paramref:`~limits.storage.storage_from_string.options`, for example::

pool = redis.connections.BlockingConnectionPool.from_url("redis://.....")
storage_from_string("redis://", connection_pool=pool)
Depends on: :pypi:`redis`

Redis over SSL
Expand Down
13 changes: 11 additions & 2 deletions limits/aio/storage/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,18 @@ class RedisStorage(RedisInteractor, Storage, MovingWindowSupport):
"""
DEPENDENCIES = ["coredis"]

def __init__(self, uri: str, **options) -> None:
def __init__(
self, uri: str, connection_pool: Optional[Any] = None, **options
) -> None:
"""
:param uri: uri of the form `async+redis://[:password]@host:port`,
`async+redis://[:password]@host:port/db`,
`async+rediss://[:password]@host:port`, `async+unix:///path/to/sock` etc.
This uri is passed directly to :func:`coredis.StrictRedis.from_url` with
the initial `a` removed, except for the case of `redis+unix` where it
is replaced with `unix`.
:param connection_pool: if provided, the redis client is initialized with
the connection pool and any other params passed as :paramref:`options`
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`coredis.StrictRedis`
:raise ConfigurationError: when the redis library is not available
Expand All @@ -148,7 +152,12 @@ def __init__(self, uri: str, **options) -> None:
super().__init__()

self.dependency = self.dependencies["coredis"]
self.storage = self.dependency.StrictRedis.from_url(uri, **options)
if connection_pool:
self.storage = self.dependency.StrictRedis(
connection_pool=connection_pool, **options
)
else:
self.storage = self.dependency.StrictRedis.from_url(uri, **options)

self.initialize_storage(uri)

Expand Down
16 changes: 13 additions & 3 deletions limits/storage/redis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Any, Tuple
from typing import Any, Optional, Tuple

from ..util import get_package_data
from .base import MovingWindowSupport, Storage
Expand Down Expand Up @@ -111,13 +111,20 @@ class RedisStorage(RedisInteractor, Storage, MovingWindowSupport):

DEPENDENCIES = ["redis"]

def __init__(self, uri: str, **options):
def __init__(
self,
uri: str,
connection_pool: Optional[Any] = None,
**options,
):
"""
:param uri: uri of the form ``redis://[:password]@host:port``,
``redis://[:password]@host:port/db``,
``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock`` etc.
This uri is passed directly to :func:`redis.from_url` except for the
case of ``redis+unix://`` where it is replaced with ``unix://``.
:param connection_pool: if provided, the redis client is initialized with
the connection pool and any other params passed as :paramref:`options`
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`redis.Redis`
:raise ConfigurationError: when the :pypi:`redis` library is not available
Expand All @@ -126,7 +133,10 @@ def __init__(self, uri: str, **options):
redis = self.dependencies["redis"]
uri = uri.replace("redis+unix", "unix")

self.storage = redis.from_url(uri, **options)
if not connection_pool:
self.storage = redis.from_url(uri, **options)
else:
self.storage = redis.Redis(connection_pool=connection_pool, **options)
self.initialize_storage(uri)

def initialize_storage(self, _uri: str):
Expand Down
10 changes: 10 additions & 0 deletions tests/aio/storage/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ async def test_moving_window_clear(self):
assert await limiter.hit(per_min)

@pytest.mark.asyncio
@pytest.mark.flaky
async def test_moving_window_expiry(self):
limiter = MovingWindowRateLimiter(self.storage)
limit = RateLimitItemPerSecond(2)
Expand Down Expand Up @@ -97,6 +98,15 @@ async def test_init_options(self, mocker):
from_url.spy_return.connection_pool.connection_kwargs["stream_timeout"] == 1
)

@pytest.mark.asyncio
async def test_custom_connection_pool(self):
import coredis

pool = coredis.BlockingConnectionPool.from_url(self.storage_url)
storage = storage_from_string("async+redis://", connection_pool=pool)

assert await storage.check()


@pytest.mark.redis
class TestAsyncRedisAuthStorage(AsyncSharedRedisTests):
Expand Down
6 changes: 6 additions & 0 deletions tests/storage/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def test_init_options(self, mocker):
assert storage_from_string(self.storage_url, socket_timeout=1).check()
assert from_url.call_args[1]["socket_timeout"] == 1

def test_custom_connection_pool(self):
pool = redis.connection.BlockingConnectionPool.from_url(self.storage_url)
storage = storage_from_string("redis://", connection_pool=pool)

assert storage.check()


@pytest.mark.redis
class TestRedisUnixSocketStorage(SharedRedisTests):
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async def async_window(delay_end: float, delay: Optional[float] = None):
"redis+sentinel://localhost:26379",
{"service_name": "localhost-redis-sentinel"},
pytest.lazy_fixture("redis_sentinel"),
marks=pytest.mark.redis_sentinel,
marks=[pytest.mark.redis_sentinel, pytest.mark.flaky],
),
pytest.param(
"redis+sentinel://localhost:26379/localhost-redis-sentinel",
Expand Down

0 comments on commit 41ba006

Please sign in to comment.