diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ef15a4b4..a3467a5e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -8,16 +8,19 @@ on: jobs: test: - runs-on: ubuntu-latest strategy: matrix: python-version: [3.7, 3.8] marker: [not integration] + os: [ubuntu-latest] include: - python-version: 3.9 marker: '' + os: ubuntu-latest - python-version: "3.10" marker: 'not ((redis or redis_sentinel or redis_cluster) and asynchronous)' + os: ubuntu-latest + runs-on: "${{ matrix.os }}" steps: - uses: actions/checkout@v2 - uses: docker-practice/actions-setup-docker@master diff --git a/.gitignore b/.gitignore index e175ffc4..a6b6ef6c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.pyc *.log cover/* +.mypy_cache/* .coverage* .test_env .idea @@ -8,10 +9,5 @@ build/ dist/ htmlcov *egg-info* -*.rdb -redis-git .python-version -# gae test files -google_appengine -google .*.swp diff --git a/limits/aio/storage/memcached.py b/limits/aio/storage/memcached.py index f384a4bd..09f28a01 100644 --- a/limits/aio/storage/memcached.py +++ b/limits/aio/storage/memcached.py @@ -88,15 +88,21 @@ async def incr(self, key: str, expiry: int, elastic_expiry=False) -> int: if elastic_expiry: await storage.touch(limit_key, exptime=expiry) + await storage.set( + expire_key, + str(expiry + time.time()).encode("utf-8"), + exptime=expiry, + noreply=False, + ) + + return value + else: await storage.set( expire_key, str(expiry + time.time()).encode("utf-8"), exptime=expiry, noreply=False, ) - - return value - return 1 async def get_expiry(self, key: str) -> int: diff --git a/limits/aio/storage/mongodb.py b/limits/aio/storage/mongodb.py index 65a4c9fa..194e9abe 100644 --- a/limits/aio/storage/mongodb.py +++ b/limits/aio/storage/mongodb.py @@ -1,4 +1,5 @@ import asyncio +import calendar import datetime import functools import time @@ -116,7 +117,7 @@ async def get_expiry(self, key: str) -> int: counter = await self.database.counters.find_one({"_id": key}) expiry = counter["expireAt"] if counter else datetime.datetime.utcnow() - return int(time.mktime(expiry.timetuple())) + return calendar.timegm(expiry.timetuple()) async def get(self, key: str): """ diff --git a/limits/storage/memcached.py b/limits/storage/memcached.py index 65649b9f..a470ad42 100644 --- a/limits/storage/memcached.py +++ b/limits/storage/memcached.py @@ -117,6 +117,15 @@ def incr(self, key: str, expiry: int, elastic_expiry=False) -> int: value = self.storage.incr(key, 1) or 1 if elastic_expiry: self.call_memcached_func(self.storage.touch, key, expiry) + self.call_memcached_func( + self.storage.set, + key + "/expires", + expiry + time.time(), + expire=expiry, + noreply=False, + ) + return value + else: self.call_memcached_func( self.storage.set, key + "/expires", @@ -124,7 +133,6 @@ def incr(self, key: str, expiry: int, elastic_expiry=False) -> int: expire=expiry, noreply=False, ) - return value return 1 def get_expiry(self, key: str) -> int: diff --git a/limits/storage/mongodb.py b/limits/storage/mongodb.py index 2bff16a5..872b8cd0 100644 --- a/limits/storage/mongodb.py +++ b/limits/storage/mongodb.py @@ -1,3 +1,4 @@ +import calendar import datetime import time from typing import Any, Dict, Tuple @@ -74,7 +75,7 @@ def get_expiry(self, key: str) -> int: counter = self.counters.find_one({"_id": key}) expiry = counter["expireAt"] if counter else datetime.datetime.utcnow() - return int(time.mktime(expiry.timetuple())) + return calendar.timegm(expiry.timetuple()) def get(self, key: str): """ diff --git a/tests/__init__.py b/tests/__init__.py index 10d28b42..e69de29b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,24 +0,0 @@ -import functools -import math -import platform -import time -import unittest - - -def skip_if_pypy(fn): - return unittest.skipIf( - platform.python_implementation().lower() == "pypy", reason="Skipped for pypy" - )(fn) - - -def fixed_start(fn): - @functools.wraps(fn) - def __inner(*a, **k): - start = time.time() - - while time.time() < math.ceil(start): - time.sleep(0.01) - - return fn(*a, **k) - - return __inner diff --git a/tests/aio/storage/test_memcached.py b/tests/aio/storage/test_memcached.py index 60911c50..c8e6f52f 100644 --- a/tests/aio/storage/test_memcached.py +++ b/tests/aio/storage/test_memcached.py @@ -10,7 +10,7 @@ FixedWindowRateLimiter, ) from limits.storage import storage_from_string -from tests import fixed_start +from tests.utils import fixed_start @pytest.mark.flaky @@ -34,10 +34,12 @@ async def test_fixed_window(self): per_min = RateLimitItemPerSecond(10) start = time.time() count = 0 + while time.time() - start < 0.5 and count < 10: assert await limiter.hit(per_min) count += 1 assert not await limiter.hit(per_min) + while time.time() - start <= 1: await asyncio.sleep(0.1) assert await limiter.hit(per_min) @@ -50,10 +52,12 @@ async def test_fixed_window_cluster(self): per_min = RateLimitItemPerSecond(10) start = time.time() count = 0 + while time.time() - start < 0.5 and count < 10: assert await limiter.hit(per_min) count += 1 assert not await limiter.hit(per_min) + while time.time() - start <= 1: await asyncio.sleep(0.1) assert await limiter.hit(per_min) diff --git a/tests/aio/storage/test_mongodb.py b/tests/aio/storage/test_mongodb.py index e43fe9d1..6bb0ae8a 100644 --- a/tests/aio/storage/test_mongodb.py +++ b/tests/aio/storage/test_mongodb.py @@ -24,9 +24,9 @@ async def test_init_options(self, mocker): constructor = mocker.spy(motor.motor_asyncio, "AsyncIOMotorClient") assert await storage_from_string( - f"async+{self.storage_url}", connectTimeoutMS=1 + f"async+{self.storage_url}", socketTimeoutMS=100 ).check() - assert constructor.call_args[1]["connectTimeoutMS"] == 1 + assert constructor.call_args[1]["socketTimeoutMS"] == 100 @pytest.mark.asyncio async def test_fixed_window(self): diff --git a/tests/aio/test_strategy.py b/tests/aio/test_strategy.py index 9b07b85c..61b10be4 100644 --- a/tests/aio/test_strategy.py +++ b/tests/aio/test_strategy.py @@ -1,206 +1,97 @@ -import asyncio import time -import hiro import pytest -from limits.aio.storage import ( - MemcachedStorage, - MemoryStorage, - MongoDBStorage, - RedisSentinelStorage, - RedisStorage, -) +from limits.aio.storage import MemcachedStorage from limits.aio.strategies import ( FixedWindowElasticExpiryRateLimiter, FixedWindowRateLimiter, MovingWindowRateLimiter, ) -from limits.limits import RateLimitItemPerMinute, RateLimitItemPerSecond +from limits.limits import RateLimitItemPerSecond +from limits.storage import storage_from_string +from tests.utils import ( + async_all_storage, + async_moving_window_storage, + async_window, + fixed_start, +) @pytest.mark.asynchronous +@pytest.mark.asyncio class TestAsyncWindow: - @pytest.mark.asyncio - async def test_fixed_window(self): - storage = MemoryStorage() + @async_all_storage + @fixed_start + async def test_fixed_window(self, uri, args, fixture): + storage = storage_from_string(uri, **args) limiter = FixedWindowRateLimiter(storage) - with hiro.Timeline().freeze() as timeline: - start = int(time.time()) - limit = RateLimitItemPerSecond(10, 2) - assert all([await limiter.hit(limit) for _ in range(0, 10)]) - timeline.forward(1) - assert not await limiter.hit(limit) - assert (await limiter.get_window_stats(limit))[1] == 0 - assert (await limiter.get_window_stats(limit))[0] == start + 2 - timeline.forward(1) - assert (await limiter.get_window_stats(limit))[1] == 10 - assert await limiter.hit(limit) - - @pytest.mark.asyncio - async def test_fixed_window_with_elastic_expiry_in_memory(self): - storage = MemoryStorage() - limiter = FixedWindowElasticExpiryRateLimiter(storage) - with hiro.Timeline().freeze() as timeline: - start = int(time.time()) - limit = RateLimitItemPerSecond(10, 2) - assert all([await limiter.hit(limit) for _ in range(0, 10)]) - timeline.forward(1) - assert not await limiter.hit(limit) - assert (await limiter.get_window_stats(limit))[1] == 0 - # three extensions to the expiry - assert (await limiter.get_window_stats(limit))[0] == start + 3 - timeline.forward(1) - assert not await limiter.hit(limit) - timeline.forward(3) - start = int(time.time()) - assert await limiter.hit(limit) - assert (await limiter.get_window_stats(limit))[1] == 9 - assert (await limiter.get_window_stats(limit))[0] == start + 2 - - @pytest.mark.flaky - @pytest.mark.asyncio - @pytest.mark.memcached - async def test_fixed_window_with_elastic_expiry_memcached(self, memcached): - storage = MemcachedStorage("async+memcached://localhost:22122") - limiter = FixedWindowElasticExpiryRateLimiter(storage) limit = RateLimitItemPerSecond(10, 2) - - for _ in range(0, 10): - assert await limiter.hit(limit) - await asyncio.sleep(1) - assert not await limiter.hit(limit) - await asyncio.sleep(1) + async with async_window(1) as (start, end): + assert all([await limiter.hit(limit) for _ in range(0, 10)]) assert not await limiter.hit(limit) assert (await limiter.get_window_stats(limit))[1] == 0 + assert (await limiter.get_window_stats(limit))[0] == start + 2 - @pytest.mark.asyncio - @pytest.mark.mongodb - async def test_fixed_window_with_elastic_expiry_mongo(self, mongodb): - storage = MongoDBStorage("async+mongodb://localhost:37017") + @async_all_storage + async def test_fixed_window_with_elastic_expiry(self, uri, args, fixture): + storage = storage_from_string(uri, **args) limiter = FixedWindowElasticExpiryRateLimiter(storage) limit = RateLimitItemPerSecond(10, 2) - - for _ in range(0, 10): - assert await limiter.hit(limit) - await asyncio.sleep(1) - assert not await limiter.hit(limit) - await asyncio.sleep(1) - assert not await limiter.hit(limit) + async with async_window(1) as (start, end): + assert all([await limiter.hit(limit) for _ in range(0, 10)]) + assert not await limiter.hit(limit) assert (await limiter.get_window_stats(limit))[1] == 0 + assert (await limiter.get_window_stats(limit))[0] == start + 2 + async with async_window(3) as (start, end): + assert not await limiter.hit(limit) + assert await limiter.hit(limit) + assert (await limiter.get_window_stats(limit))[1] == 9 + assert (await limiter.get_window_stats(limit))[0] == end + 2 - @pytest.mark.asyncio - @pytest.mark.redis - async def test_fixed_window_with_elastic_expiry_redis(self, redis_basic): - storage = RedisStorage("async+redis://localhost:7379") - limiter = FixedWindowElasticExpiryRateLimiter(storage) + @async_moving_window_storage + async def test_moving_window(self, uri, args, fixture): + storage = storage_from_string(uri, **args) + limiter = MovingWindowRateLimiter(storage) limit = RateLimitItemPerSecond(10, 2) - for _ in range(0, 10): - assert await limiter.hit(limit) - await asyncio.sleep(1) - assert not await limiter.hit(limit) - await asyncio.sleep(1) - assert not await limiter.hit(limit) + # 5 hits in the first 100ms + async with async_window(0.1): + assert all([await limiter.hit(limit) for i in range(5)]) + # 5 hits in the last 100ms + async with async_window(2, delay=1.8): + assert all([await limiter.hit(limit) for i in range(5)]) + # 11th fails + assert not await limiter.hit(limit) + # 5 more succeed since there were only 5 in the last 2 seconds + assert all([await limiter.hit(limit) for i in range(5)]) assert (await limiter.get_window_stats(limit))[1] == 0 + assert (await limiter.get_window_stats(limit))[0] == int(time.time() + 2) - @pytest.mark.asyncio - @pytest.mark.redis_sentinel - async def test_fixed_window_with_elastic_expiry_redis_sentinel( - self, redis_sentinel - ): - storage = RedisSentinelStorage( - "async+redis+sentinel://localhost:26379/localhost-redis-sentinel" - ) - limiter = FixedWindowElasticExpiryRateLimiter(storage) - limit = RateLimitItemPerSecond(10, 2) - - for _ in range(0, 10): - assert await limiter.hit(limit) - await asyncio.sleep(1) - assert not await limiter.hit(limit) - await asyncio.sleep(1) + @pytest.mark.memcached + async def test_moving_window_memcached(self, memcached): + storage = MemcachedStorage("memcached://localhost:22122") + with pytest.raises(NotImplementedError): + MovingWindowRateLimiter(storage) + + @async_all_storage + async def test_test_fixed_window(self, uri, args, fixture): + storage = storage_from_string(uri, **args) + limiter = FixedWindowRateLimiter(storage) + limit = RateLimitItemPerSecond(2, 1) + assert await limiter.hit(limit) + assert await limiter.test(limit) + assert await limiter.hit(limit) + assert not await limiter.test(limit) assert not await limiter.hit(limit) - assert (await limiter.get_window_stats(limit))[1] == 0 - - @pytest.mark.asyncio - async def test_moving_window_in_memory(self): - storage = MemoryStorage() - limiter = MovingWindowRateLimiter(storage) - with hiro.Timeline().freeze() as timeline: - limit = RateLimitItemPerMinute(10) - - for i in range(0, 5): - assert await limiter.hit(limit) - assert await limiter.hit(limit) - assert (await limiter.get_window_stats(limit))[1] == 10 - ((i + 1) * 2) - timeline.forward(10) - assert (await limiter.get_window_stats(limit))[1] == 0 - assert not await limiter.hit(limit) - timeline.forward(20) - assert (await limiter.get_window_stats(limit))[1] == 2 - assert (await limiter.get_window_stats(limit))[0] == int(time.time() + 30) - timeline.forward(31) - assert (await limiter.get_window_stats(limit))[1] == 10 - @pytest.mark.asyncio - @pytest.mark.mongodb - async def test_moving_window_mongo(self, mongodb): - storage = MongoDBStorage("async+mongodb://localhost:37017") + @async_moving_window_storage + async def test_test_moving_window(self, uri, args, fixture): + storage = storage_from_string(uri, **args) + limit = RateLimitItemPerSecond(2, 1) limiter = MovingWindowRateLimiter(storage) - with hiro.Timeline().freeze() as timeline: - limit = RateLimitItemPerMinute(10) - - for i in range(0, 5): - assert await limiter.hit(limit) - assert await limiter.hit(limit) - assert (await limiter.get_window_stats(limit))[1] == 10 - ((i + 1) * 2) - timeline.forward(10) - assert (await limiter.get_window_stats(limit))[1] == 0 - assert not await limiter.hit(limit) - timeline.forward(20) - assert (await limiter.get_window_stats(limit))[1] == 2 - assert (await limiter.get_window_stats(limit))[0] == int(time.time() + 30) - timeline.forward(31) - assert (await limiter.get_window_stats(limit))[1] == 10 - - @pytest.mark.asyncio - @pytest.mark.redis - async def test_moving_window_redis(self, redis_basic): - storage = RedisStorage("async+redis://localhost:7379") - limiter = MovingWindowRateLimiter(storage) - limit = RateLimitItemPerSecond(10, 2) - - for i in range(0, 10): - assert await limiter.hit(limit) - assert (await limiter.get_window_stats(limit))[1] == 10 - (i + 1) - await asyncio.sleep(2 * 0.095) - assert not await limiter.hit(limit) - await asyncio.sleep(0.4) assert await limiter.hit(limit) + assert await limiter.test(limit) assert await limiter.hit(limit) - assert (await limiter.get_window_stats(limit))[1] == 0 - - @pytest.mark.asyncio - async def test_test_fixed_window(self): - with hiro.Timeline().freeze(): - store = MemoryStorage() - limiter = FixedWindowRateLimiter(store) - limit = RateLimitItemPerSecond(2, 1) - assert await limiter.hit(limit) - assert await limiter.test(limit) - assert await limiter.hit(limit) - assert not await limiter.test(limit) - assert not await limiter.hit(limit) - - @pytest.mark.asyncio - async def test_test_moving_window(self): - with hiro.Timeline().freeze(): - store = MemoryStorage() - limit = RateLimitItemPerSecond(2, 1) - limiter = MovingWindowRateLimiter(store) - assert await limiter.hit(limit) - assert await limiter.test(limit) - assert await limiter.hit(limit) - assert not await limiter.test(limit) - assert not await limiter.hit(limit) + assert not await limiter.test(limit) + assert not await limiter.hit(limit) diff --git a/tests/storage/test_memcached.py b/tests/storage/test_memcached.py index fd3a8c5e..2b35cf31 100644 --- a/tests/storage/test_memcached.py +++ b/tests/storage/test_memcached.py @@ -9,7 +9,7 @@ FixedWindowElasticExpiryRateLimiter, FixedWindowRateLimiter, ) -from tests import fixed_start +from tests.utils import fixed_start @pytest.mark.memcached diff --git a/tests/test_strategy.py b/tests/test_strategy.py index d759ff4b..95716e43 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -1,184 +1,65 @@ -import threading +import math import time -import hiro import pytest -from limits.limits import RateLimitItemPerMinute, RateLimitItemPerSecond -from limits.storage import ( - MemcachedStorage, - MemoryStorage, - MongoDBStorage, - RedisSentinelStorage, - RedisStorage, -) +from limits.limits import RateLimitItemPerSecond +from limits.storage import MemcachedStorage, storage_from_string from limits.strategies import ( FixedWindowElasticExpiryRateLimiter, FixedWindowRateLimiter, MovingWindowRateLimiter, ) +from tests.utils import all_storage, fixed_start, moving_window_storage, window class TestWindow: - def test_fixed_window(self): - storage = MemoryStorage() + @all_storage + @fixed_start + def test_fixed_window(self, uri, args, fixture): + storage = storage_from_string(uri, **args) limiter = FixedWindowRateLimiter(storage) - with hiro.Timeline().freeze() as timeline: - start = int(time.time()) - limit = RateLimitItemPerSecond(10, 2) - assert all([limiter.hit(limit) for _ in range(0, 10)]) - timeline.forward(1) - assert not limiter.hit(limit) - assert limiter.get_window_stats(limit)[1] == 0 - assert limiter.get_window_stats(limit)[0] == start + 2 - timeline.forward(1) - assert limiter.get_window_stats(limit)[1] == 10 - assert limiter.hit(limit) - - @pytest.mark.flaky - def test_fixed_window_with_elastic_expiry_in_memory(self): - storage = MemoryStorage() - limiter = FixedWindowElasticExpiryRateLimiter(storage) - with hiro.Timeline().freeze() as timeline: - start = int(time.time()) - limit = RateLimitItemPerSecond(10, 2) - assert all([limiter.hit(limit) for _ in range(0, 10)]) - timeline.forward(1) - assert not limiter.hit(limit) - assert limiter.get_window_stats(limit)[1] == 0 - # three extensions to the expiry - assert limiter.get_window_stats(limit)[0] == start + 3 - timeline.forward(1) - assert not limiter.hit(limit) - timeline.forward(3) - start = int(time.time()) - assert limiter.hit(limit) - assert limiter.get_window_stats(limit)[1] == 9 - assert limiter.get_window_stats(limit)[0] == start + 2 - - @pytest.mark.flaky - @pytest.mark.memcached - def test_fixed_window_with_elastic_expiry_memcache(self, memcached): - storage = MemcachedStorage("memcached://localhost:22122") - limiter = FixedWindowElasticExpiryRateLimiter(storage) limit = RateLimitItemPerSecond(10, 2) - assert all([limiter.hit(limit) for _ in range(0, 10)]) - time.sleep(1) - assert not limiter.hit(limit) - time.sleep(1) - assert not limiter.hit(limit) - - @pytest.mark.flaky - @pytest.mark.memcached - def test_fixed_window_with_elastic_expiry_memcache_concurrency(self, memcached): - storage = MemcachedStorage("memcached://localhost:22122") - limiter = FixedWindowElasticExpiryRateLimiter(storage) - start = int(time.time()) - limit = RateLimitItemPerSecond(10, 2) - - def _c(): - for i in range(0, 5): - limiter.hit(limit) - - t1, t2 = threading.Thread(target=_c), threading.Thread(target=_c) - t1.start(), t2.start() - t1.join(), t2.join() - assert limiter.get_window_stats(limit)[1] == 0 - assert start + 2 <= limiter.get_window_stats(limit)[0] <= start + 3 - assert storage.get(limit.key_for()) == 10 - - @pytest.mark.mongodb - def test_fixed_window_with_elastic_expiry_mongo(self, mongodb): - storage = MongoDBStorage("mongodb://localhost:37017") - limiter = FixedWindowElasticExpiryRateLimiter(storage) - limit = RateLimitItemPerSecond(10, 2) - assert all([limiter.hit(limit) for _ in range(0, 10)]) - time.sleep(1) - assert not limiter.hit(limit) - time.sleep(1) - assert not limiter.hit(limit) - assert limiter.get_window_stats(limit)[1] == 0 - - @pytest.mark.redis - def test_fixed_window_with_elastic_expiry_redis(self, redis_basic): - storage = RedisStorage("redis://localhost:7379") - limiter = FixedWindowElasticExpiryRateLimiter(storage) - limit = RateLimitItemPerSecond(10, 2) - assert all([limiter.hit(limit) for _ in range(0, 10)]) - time.sleep(1) - assert not limiter.hit(limit) - time.sleep(1) + with window(1) as (start, end): + assert all([limiter.hit(limit) for _ in range(0, 10)]) assert not limiter.hit(limit) assert limiter.get_window_stats(limit)[1] == 0 + assert limiter.get_window_stats(limit)[0] == math.floor(start + 2) - @pytest.mark.redis_sentinel - def test_fixed_window_with_elastic_expiry_redis_sentinel(self, redis_sentinel): - storage = RedisSentinelStorage( - "redis+sentinel://localhost:26379", service_name="localhost-redis-sentinel" - ) + @all_storage + def test_fixed_window_with_elastic_expiry(self, uri, args, fixture): + storage = storage_from_string(uri, **args) limiter = FixedWindowElasticExpiryRateLimiter(storage) limit = RateLimitItemPerSecond(10, 2) - assert all([limiter.hit(limit) for _ in range(0, 10)]) - time.sleep(1) - assert not limiter.hit(limit) - time.sleep(1) - assert not limiter.hit(limit) + with window(1) as (start, end): + assert all([limiter.hit(limit) for _ in range(0, 10)]) + assert not limiter.hit(limit) assert limiter.get_window_stats(limit)[1] == 0 - - def test_moving_window_in_memory(self): - storage = MemoryStorage() - limiter = MovingWindowRateLimiter(storage) - with hiro.Timeline().freeze() as timeline: - limit = RateLimitItemPerMinute(10) - - for i in range(0, 5): - assert limiter.hit(limit) - assert limiter.hit(limit) - assert limiter.get_window_stats(limit)[1] == 10 - ((i + 1) * 2) - timeline.forward(10) - assert limiter.get_window_stats(limit)[1] == 0 + assert limiter.get_window_stats(limit)[0] == start + 2 + with window(3) as (start, end): assert not limiter.hit(limit) - timeline.forward(20) - assert limiter.get_window_stats(limit)[1] == 2 - assert limiter.get_window_stats(limit)[0] == int(time.time() + 30) - timeline.forward(31) - assert limiter.get_window_stats(limit)[1] == 10 - - @pytest.mark.redis - def test_moving_window_redis(self, redis_basic): - storage = RedisStorage("redis://localhost:7379") - limiter = MovingWindowRateLimiter(storage) - limit = RateLimitItemPerSecond(10, 2) - - for i in range(0, 10): - assert limiter.hit(limit) - assert limiter.get_window_stats(limit)[1] == 10 - (i + 1) - time.sleep(2 * 0.095) - assert not limiter.hit(limit) - time.sleep(0.4) assert limiter.hit(limit) - assert limiter.hit(limit) - assert limiter.get_window_stats(limit)[1] == 0 + assert limiter.get_window_stats(limit)[1] == 9 + assert limiter.get_window_stats(limit)[0] == end + 2 - @pytest.mark.mongodb - def test_moving_window_mongo(self, mongodb): - storage = MongoDBStorage("mongodb://localhost:37017") + @moving_window_storage + def test_moving_window(self, uri, args, fixture): + storage = storage_from_string(uri, **args) limiter = MovingWindowRateLimiter(storage) - with hiro.Timeline().freeze() as timeline: - limit = RateLimitItemPerMinute(10) + limit = RateLimitItemPerSecond(10, 2) - for i in range(0, 5): - assert limiter.hit(limit) - assert limiter.hit(limit) - assert limiter.get_window_stats(limit)[1] == 10 - ((i + 1) * 2) - timeline.forward(10) - assert limiter.get_window_stats(limit)[1] == 0 + # 5 hits in the first 100ms + with window(0.1): + assert all(limiter.hit(limit) for i in range(5)) + # 5 hits in the last 100ms + with window(2, delay=1.8): + assert all(limiter.hit(limit) for i in range(5)) + # 11th fails assert not limiter.hit(limit) - timeline.forward(20) - assert limiter.get_window_stats(limit)[1] == 2 - assert limiter.get_window_stats(limit)[0] == int(time.time() + 30) - timeline.forward(31) - assert limiter.get_window_stats(limit)[1] == 10 + # 5 more succeed since there were only 5 in the last 2 seconds + assert all(limiter.hit(limit) for i in range(5)) + assert limiter.get_window_stats(limit)[1] == 0 + assert limiter.get_window_stats(limit)[0] == int(time.time() + 2) @pytest.mark.memcached def test_moving_window_memcached(self, memcached): @@ -186,24 +67,24 @@ def test_moving_window_memcached(self, memcached): with pytest.raises(NotImplementedError): MovingWindowRateLimiter(storage) - def test_test_fixed_window(self): - with hiro.Timeline().freeze(): - store = MemoryStorage() - limiter = FixedWindowRateLimiter(store) - limit = RateLimitItemPerSecond(2, 1) - assert limiter.hit(limit), store - assert limiter.test(limit), store - assert limiter.hit(limit), store - assert not limiter.test(limit), store - assert not limiter.hit(limit), store + @all_storage + def test_test_fixed_window(self, uri, args, fixture): + storage = storage_from_string(uri, **args) + limiter = FixedWindowRateLimiter(storage) + limit = RateLimitItemPerSecond(2, 1) + assert limiter.hit(limit) + assert limiter.test(limit) + assert limiter.hit(limit) + assert not limiter.test(limit) + assert not limiter.hit(limit) - def test_test_moving_window(self): - with hiro.Timeline().freeze(): - store = MemoryStorage() - limit = RateLimitItemPerSecond(2, 1) - limiter = MovingWindowRateLimiter(store) - assert limiter.hit(limit), store - assert limiter.test(limit), store - assert limiter.hit(limit), store - assert not limiter.test(limit), store - assert not limiter.hit(limit), store + @moving_window_storage + def test_test_moving_window(self, uri, args, fixture): + storage = storage_from_string(uri, **args) + limit = RateLimitItemPerSecond(2, 1) + limiter = MovingWindowRateLimiter(storage) + assert limiter.hit(limit) + assert limiter.test(limit) + assert limiter.hit(limit) + assert not limiter.test(limit) + assert not limiter.hit(limit) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..024cd3b1 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,293 @@ +import asyncio +import contextlib +import functools +import math +import time +from typing import Optional + +import pytest + + +def fixed_start(fn): + @functools.wraps(fn) + def __inner(*a, **k): + start = time.time() + + while time.time() < math.ceil(start): + time.sleep(0.01) + + return fn(*a, **k) + + return __inner + + +@contextlib.contextmanager +def window(delay_end: float, delay: Optional[float] = None): + start = time.time() + + if delay is not None: + while time.time() - start < delay: + time.sleep(0.001) + yield (int(start), int(start + delay_end)) + + while time.time() - start < delay_end: + time.sleep(0.001) + + +@contextlib.asynccontextmanager +async def async_window(delay_end: float, delay: Optional[float] = None): + start = time.time() + + if delay is not None: + while time.time() - start < delay: + await asyncio.sleep(0.001) + + yield (int(start), int(start + delay_end)) + + while time.time() - start < delay_end: + await asyncio.sleep(0.001) + + +all_storage = pytest.mark.parametrize( + "uri, args, fixture", + [ + ("memory://", {}, None), + pytest.param( + "redis://localhost:7379", + {}, + pytest.lazy_fixture("redis_basic"), + marks=pytest.mark.redis, + ), + pytest.param( + "redis+unix:///tmp/limits.redis.sock", + {}, + pytest.lazy_fixture("redis_uds"), + marks=pytest.mark.redis, + ), + pytest.param( + "redis+unix://:password/tmp/limits.redis.sock", + {}, + pytest.lazy_fixture("redis_uds"), + marks=pytest.mark.redis, + ), + pytest.param( + "memcached://localhost:22122", + {}, + pytest.lazy_fixture("memcached"), + marks=pytest.mark.memcached, + ), + pytest.param( + "memcached://localhost:22122,localhost:22123", + {}, + pytest.lazy_fixture("memcached_cluster"), + marks=pytest.mark.memcached, + ), + pytest.param( + "redis+sentinel://localhost:26379", + {"service_name": "localhost-redis-sentinel"}, + pytest.lazy_fixture("redis_sentinel"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "redis+sentinel://localhost:26379/localhost-redis-sentinel", + {}, + pytest.lazy_fixture("redis_sentinel"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "redis+sentinel://:sekret@localhost:36379/localhost-redis-sentinel", + {}, + pytest.lazy_fixture("redis_sentinel_auth"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "redis+cluster://localhost:7001/", + {}, + pytest.lazy_fixture("redis_cluster"), + marks=pytest.mark.redis_cluster, + ), + pytest.param( + "mongodb://localhost:37017/", + {}, + pytest.lazy_fixture("mongodb"), + marks=pytest.mark.mongodb, + ), + ], +) + +moving_window_storage = pytest.mark.parametrize( + "uri, args, fixture", + [ + ("memory://", {}, None), + pytest.param( + "redis://localhost:7379", + {}, + pytest.lazy_fixture("redis_basic"), + marks=pytest.mark.redis, + ), + pytest.param( + "redis+unix:///tmp/limits.redis.sock", + {}, + pytest.lazy_fixture("redis_uds"), + marks=pytest.mark.redis, + ), + pytest.param( + "redis+unix://:password/tmp/limits.redis.sock", + {}, + pytest.lazy_fixture("redis_uds"), + marks=pytest.mark.redis, + ), + pytest.param( + "redis+sentinel://localhost:26379", + {"service_name": "localhost-redis-sentinel"}, + pytest.lazy_fixture("redis_sentinel"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "redis+sentinel://localhost:26379/localhost-redis-sentinel", + {}, + pytest.lazy_fixture("redis_sentinel"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "redis+sentinel://:sekret@localhost:36379/localhost-redis-sentinel", + {}, + pytest.lazy_fixture("redis_sentinel_auth"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "redis+cluster://localhost:7001/", + {}, + pytest.lazy_fixture("redis_cluster"), + marks=pytest.mark.redis_cluster, + ), + pytest.param( + "mongodb://localhost:37017/", + {}, + pytest.lazy_fixture("mongodb"), + marks=pytest.mark.mongodb, + ), + ], +) + +async_all_storage = pytest.mark.parametrize( + "uri, args, fixture", + [ + ("async+memory://", {}, None), + pytest.param( + "async+redis://localhost:7379", + {}, + pytest.lazy_fixture("redis_basic"), + marks=pytest.mark.redis, + ), + pytest.param( + "async+redis+unix:///tmp/limits.redis.sock", + {}, + pytest.lazy_fixture("redis_uds"), + marks=pytest.mark.redis, + ), + pytest.param( + "async+redis+unix://:password/tmp/limits.redis.sock", + {}, + pytest.lazy_fixture("redis_uds"), + marks=pytest.mark.redis, + ), + pytest.param( + "async+memcached://localhost:22122", + {}, + pytest.lazy_fixture("memcached"), + marks=pytest.mark.memcached, + ), + pytest.param( + "async+memcached://localhost:22122,localhost:22123", + {}, + pytest.lazy_fixture("memcached_cluster"), + marks=pytest.mark.memcached, + ), + pytest.param( + "async+redis+sentinel://localhost:26379", + {"service_name": "localhost-redis-sentinel"}, + pytest.lazy_fixture("redis_sentinel"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "async+redis+sentinel://localhost:26379/localhost-redis-sentinel", + {}, + pytest.lazy_fixture("redis_sentinel"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "async+redis+sentinel://:sekret@localhost:36379/localhost-redis-sentinel", + {}, + pytest.lazy_fixture("redis_sentinel_auth"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "async+redis+cluster://localhost:7001/", + {}, + pytest.lazy_fixture("redis_cluster"), + marks=pytest.mark.redis_cluster, + ), + pytest.param( + "async+mongodb://localhost:37017/", + {}, + pytest.lazy_fixture("mongodb"), + marks=pytest.mark.mongodb, + ), + ], +) + +async_moving_window_storage = pytest.mark.parametrize( + "uri, args, fixture", + [ + ("async+memory://", {}, None), + pytest.param( + "async+redis://localhost:7379", + {}, + pytest.lazy_fixture("redis_basic"), + marks=pytest.mark.redis, + ), + pytest.param( + "async+redis+unix:///tmp/limits.redis.sock", + {}, + pytest.lazy_fixture("redis_uds"), + marks=pytest.mark.redis, + ), + pytest.param( + "async+redis+unix://:password/tmp/limits.redis.sock", + {}, + pytest.lazy_fixture("redis_uds"), + marks=pytest.mark.redis, + ), + pytest.param( + "async+redis+sentinel://localhost:26379", + {"service_name": "localhost-redis-sentinel"}, + pytest.lazy_fixture("redis_sentinel"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "async+redis+sentinel://localhost:26379/localhost-redis-sentinel", + {}, + pytest.lazy_fixture("redis_sentinel"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "async+redis+sentinel://:sekret@localhost:36379/localhost-redis-sentinel", + {}, + pytest.lazy_fixture("redis_sentinel_auth"), + marks=pytest.mark.redis_sentinel, + ), + pytest.param( + "async+redis+cluster://localhost:7001/", + {}, + pytest.lazy_fixture("redis_cluster"), + marks=pytest.mark.redis_cluster, + ), + pytest.param( + "async+mongodb://localhost:37017/", + {}, + pytest.lazy_fixture("mongodb"), + marks=pytest.mark.mongodb, + ), + ], +)