Skip to content

Commit

Permalink
Fix expiry for memcached & mongodb
Browse files Browse the repository at this point in the history
- Ensure expiry is in localtime for mongodb
- Only extend window expiry when using fixed elastic with memcached
- Improve strategy tests by removing mocked time
  • Loading branch information
alisaifee committed Jan 3, 2022
1 parent 036110a commit 9763183
Show file tree
Hide file tree
Showing 13 changed files with 450 additions and 390 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@ on:

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8]
marker: [not integration]
os: [ubuntu-latest]
include:
- python-version: 3.9
marker: ''
os: ubuntu-latest
- python-version: "3.10"
marker: 'not ((redis or redis_sentinel or redis_cluster) and asynchronous)'
os: ubuntu-latest
runs-on: "${{ matrix.os }}"
steps:
- uses: actions/checkout@v2
- uses: docker-practice/actions-setup-docker@master
Expand Down
6 changes: 1 addition & 5 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
*.pyc
*.log
cover/*
.mypy_cache/*
.coverage*
.test_env
.idea
build/
dist/
htmlcov
*egg-info*
*.rdb
redis-git
.python-version
# gae test files
google_appengine
google
.*.swp
12 changes: 9 additions & 3 deletions limits/aio/storage/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,21 @@ async def incr(self, key: str, expiry: int, elastic_expiry=False) -> int:

if elastic_expiry:
await storage.touch(limit_key, exptime=expiry)
await storage.set(
expire_key,
str(expiry + time.time()).encode("utf-8"),
exptime=expiry,
noreply=False,
)

return value
else:
await storage.set(
expire_key,
str(expiry + time.time()).encode("utf-8"),
exptime=expiry,
noreply=False,
)

return value

return 1

async def get_expiry(self, key: str) -> int:
Expand Down
3 changes: 2 additions & 1 deletion limits/aio/storage/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import calendar
import datetime
import functools
import time
Expand Down Expand Up @@ -116,7 +117,7 @@ async def get_expiry(self, key: str) -> int:
counter = await self.database.counters.find_one({"_id": key})
expiry = counter["expireAt"] if counter else datetime.datetime.utcnow()

return int(time.mktime(expiry.timetuple()))
return calendar.timegm(expiry.timetuple())

async def get(self, key: str):
"""
Expand Down
10 changes: 9 additions & 1 deletion limits/storage/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,22 @@ def incr(self, key: str, expiry: int, elastic_expiry=False) -> int:
value = self.storage.incr(key, 1) or 1
if elastic_expiry:
self.call_memcached_func(self.storage.touch, key, expiry)
self.call_memcached_func(
self.storage.set,
key + "/expires",
expiry + time.time(),
expire=expiry,
noreply=False,
)
return value
else:
self.call_memcached_func(
self.storage.set,
key + "/expires",
expiry + time.time(),
expire=expiry,
noreply=False,
)
return value
return 1

def get_expiry(self, key: str) -> int:
Expand Down
3 changes: 2 additions & 1 deletion limits/storage/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import calendar
import datetime
import time
from typing import Any, Dict, Tuple
Expand Down Expand Up @@ -74,7 +75,7 @@ def get_expiry(self, key: str) -> int:
counter = self.counters.find_one({"_id": key})
expiry = counter["expireAt"] if counter else datetime.datetime.utcnow()

return int(time.mktime(expiry.timetuple()))
return calendar.timegm(expiry.timetuple())

def get(self, key: str):
"""
Expand Down
24 changes: 0 additions & 24 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +0,0 @@
import functools
import math
import platform
import time
import unittest


def skip_if_pypy(fn):
return unittest.skipIf(
platform.python_implementation().lower() == "pypy", reason="Skipped for pypy"
)(fn)


def fixed_start(fn):
@functools.wraps(fn)
def __inner(*a, **k):
start = time.time()

while time.time() < math.ceil(start):
time.sleep(0.01)

return fn(*a, **k)

return __inner
6 changes: 5 additions & 1 deletion tests/aio/storage/test_memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
FixedWindowRateLimiter,
)
from limits.storage import storage_from_string
from tests import fixed_start
from tests.utils import fixed_start


@pytest.mark.flaky
Expand All @@ -34,10 +34,12 @@ async def test_fixed_window(self):
per_min = RateLimitItemPerSecond(10)
start = time.time()
count = 0

while time.time() - start < 0.5 and count < 10:
assert await limiter.hit(per_min)
count += 1
assert not await limiter.hit(per_min)

while time.time() - start <= 1:
await asyncio.sleep(0.1)
assert await limiter.hit(per_min)
Expand All @@ -50,10 +52,12 @@ async def test_fixed_window_cluster(self):
per_min = RateLimitItemPerSecond(10)
start = time.time()
count = 0

while time.time() - start < 0.5 and count < 10:
assert await limiter.hit(per_min)
count += 1
assert not await limiter.hit(per_min)

while time.time() - start <= 1:
await asyncio.sleep(0.1)
assert await limiter.hit(per_min)
Expand Down
4 changes: 2 additions & 2 deletions tests/aio/storage/test_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ async def test_init_options(self, mocker):

constructor = mocker.spy(motor.motor_asyncio, "AsyncIOMotorClient")
assert await storage_from_string(
f"async+{self.storage_url}", connectTimeoutMS=1
f"async+{self.storage_url}", socketTimeoutMS=100
).check()
assert constructor.call_args[1]["connectTimeoutMS"] == 1
assert constructor.call_args[1]["socketTimeoutMS"] == 100

@pytest.mark.asyncio
async def test_fixed_window(self):
Expand Down
Loading

0 comments on commit 9763183

Please sign in to comment.