-
Notifications
You must be signed in to change notification settings - Fork 285
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Implement serverside rate limiter (#4431)
* Spike out serverside rate limiting middleware * Refactor to keep multiple active partitions * Create StrawberryRateLimiter extension * Add rate limiter tests * Update extension to run synchronously * Ensure rate limiter extension works for both sync and async resolvers * Clean up type annotations * Improve tests and reset behavior * Clarify edge cases in testing and documentation * Use pop instead of `del` * Remove extension implementation
- Loading branch information
1 parent
02cb662
commit 85b6a1d
Showing
2 changed files
with
294 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import time | ||
from collections import defaultdict | ||
from functools import partial | ||
from typing import Any, DefaultDict, List, Optional | ||
|
||
from phoenix.exceptions import PhoenixException | ||
|
||
|
||
class UnavailableTokensError(PhoenixException):
    """Raised when a TokenBucket has fewer than one token available for a request."""
|
||
|
||
class TokenBucket:
    """
    An implementation of the token-bucket algorithm for use as a rate limiter.

    Tokens accrue continuously at ``per_second_request_rate`` and are capped at
    ``per_second_request_rate * enforcement_window_seconds``. Each request spends
    one token; a request made with less than one token available is rejected.

    Args:
        per_second_request_rate (float): The allowed request rate.
        enforcement_window_seconds (float): The time window over which the rate
            limit is enforced, in seconds.
    """

    def __init__(
        self,
        per_second_request_rate: float,
        enforcement_window_seconds: float = 1,
    ):
        self.enforcement_window = enforcement_window_seconds
        self.rate = per_second_request_rate

        now = time.time()
        self.last_checked = now  # timestamp of the most recent token refresh
        self.tokens = self.max_tokens()  # buckets start full

    def max_tokens(self) -> float:
        """Return the bucket's capacity: rate multiplied by the enforcement window."""
        return self.rate * self.enforcement_window

    def available_tokens(self) -> float:
        """Accrue tokens for the time elapsed since the last check and return the
        current count, capped at the bucket's capacity."""
        now = time.time()
        time_since_last_checked = now - self.last_checked
        self.tokens = min(self.max_tokens(), self.rate * time_since_last_checked + self.tokens)
        self.last_checked = now
        return self.tokens

    def make_request_if_ready(self) -> None:
        """Spend one token, or raise UnavailableTokensError if fewer than one is available."""
        if self.available_tokens() < 1:
            raise UnavailableTokensError
        self.tokens -= 1
|
||
|
||
class ServerRateLimiter:
    """
    Rate-limits requests per key using a rotating cache of token buckets.

    Buckets live in ``num_partitions`` cyclic partitions, of which
    ``active_partitions`` are live at any moment; the live window advances every
    ``partition_seconds``. A key's bucket is therefore guaranteed to remain cached
    for at least:

        minimum_cache_lifetime = (active_partitions - 1) * partition_seconds

    Inactive partitions are purged on every access. If enough time has passed that
    the cyclic indices could alias ("loop") back onto stale partitions, the entire
    cache is rebuilt instead.
    """

    def __init__(
        self,
        per_second_rate_limit: float = 0.5,
        enforcement_window_seconds: float = 5,
        partition_seconds: float = 60,
        active_partitions: int = 2,
    ):
        self.bucket_factory = partial(
            TokenBucket,
            per_second_request_rate=per_second_rate_limit,
            enforcement_window_seconds=enforcement_window_seconds,
        )
        self.partition_seconds = partition_seconds
        self.active_partitions = active_partitions
        # Two overflow partitions beyond the active window avoid boundary edge cases.
        self.num_partitions = active_partitions + 2
        self._reset_rate_limiters()
        self._last_cleanup_time = time.time()

    def _reset_rate_limiters(self) -> None:
        """Discard every cached bucket and start over with empty partitions."""
        self.cache_partitions: List[DefaultDict[Any, TokenBucket]] = [
            defaultdict(self.bucket_factory) for _ in range(self.num_partitions)
        ]

    def _current_partition_index(self, timestamp: float) -> int:
        """Map a timestamp onto a cyclic partition index."""
        interval_count = int(timestamp // self.partition_seconds)
        return interval_count % self.num_partitions

    def _active_partition_indices(self, current_index: int) -> List[int]:
        """Return the indices of the live partitions, newest first."""
        return [
            (current_index - offset) % self.num_partitions
            for offset in range(self.active_partitions)
        ]

    def _inactive_partition_indices(self, current_index: int) -> List[int]:
        """Return the indices of partitions outside the live window."""
        live = set(self._active_partition_indices(current_index))
        return [index for index in range(self.num_partitions) if index not in live]

    def _cleanup_expired_limiters(self, request_time: float) -> None:
        """Purge stale partitions; rebuild the whole cache if indices may have looped."""
        elapsed = request_time - self._last_cleanup_time
        loop_threshold = (self.num_partitions - 1) * self.partition_seconds
        if elapsed >= loop_threshold:
            # So much time has passed that cyclic indices could point at stale
            # partitions from a previous cycle — reset rather than trust them.
            self._reset_rate_limiters()
            self._last_cleanup_time = request_time
            return

        current_index = self._current_partition_index(request_time)
        for stale_index in self._inactive_partition_indices(current_index):
            self.cache_partitions[stale_index] = defaultdict(self.bucket_factory)
        self._last_cleanup_time = request_time

    def _fetch_token_bucket(self, key: str, request_time: float) -> TokenBucket:
        """Find (or create) the bucket for ``key``, migrating it into the current partition."""
        current_index = self._current_partition_index(request_time)
        found: Optional[TokenBucket] = None
        for index in self._active_partition_indices(current_index):
            partition = self.cache_partitions[index]
            if key in partition:
                found = partition.pop(key)
                break

        current_partition = self.cache_partitions[current_index]
        if found is not None and key not in current_partition:
            current_partition[key] = found
        return current_partition[key]

    def make_request(self, key: str) -> None:
        """Record one request for ``key``, raising UnavailableTokensError when rate-limited."""
        now = time.time()
        self._cleanup_expired_limiters(now)
        self._fetch_token_bucket(key, now).make_request_if_ready()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import time | ||
from contextlib import contextmanager | ||
from typing import Optional | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
from phoenix.server.rate_limiters import ServerRateLimiter, TokenBucket, UnavailableTokensError | ||
|
||
|
||
@contextmanager
def freeze_time(frozen_time: Optional[float] = None):
    """Patch ``time.time`` to return a fixed value (defaulting to "now") within the block."""
    if frozen_time is None:
        frozen_time = time.time()

    with mock.patch("time.time", return_value=frozen_time) as mock_time:
        yield mock_time
|
||
|
||
@contextmanager
def warp_time(start: Optional[float]):
    """
    Patch ``time.time`` and ``time.sleep`` so sleeps advance a fake clock instantly.

    Each ``time.sleep(d)`` records ``d``; the next ``time.time()`` call pops the most
    recently recorded duration and advances the fake clock by it. The clock starts
    at ``start``, or at the real current time when ``start`` is None.
    """
    # Bug fix: resolve a None start BEFORE copying it into current_time; the
    # original copied first, leaving current_time = None and crashing on the
    # first time.time() call when start was None.
    start = time.time() if start is None else start
    sleeps = [0]
    current_time = start

    def instant_sleep(duration):  # renamed from `time` to avoid shadowing the module
        sleeps.append(duration)

    def time_warp():
        nonlocal current_time
        try:
            current_time += sleeps.pop()
        except IndexError:
            pass  # no pending sleeps: the clock stays where it is
        return current_time

    with mock.patch("time.time") as mock_time:
        with mock.patch("time.sleep") as mock_sleep:
            mock_sleep.side_effect = instant_sleep
            mock_time.side_effect = time_warp
            yield
|
||
|
||
def test_token_bucket_gains_tokens_over_time():
    """Tokens should accrue linearly with elapsed time."""
    start = time.time()

    with freeze_time(start):
        bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=30)
        bucket.tokens = 0  # drain the bucket so accrual is measured from zero

    # At 1 token/sec, the available count should equal the seconds elapsed.
    for elapsed in (5, 10):
        with freeze_time(start + elapsed):
            assert bucket.available_tokens() == elapsed
|
||
|
||
def test_token_bucket_can_max_out_on_requests():
    """Token accrual should be capped at the bucket's maximum capacity."""
    start = time.time()

    with freeze_time(start):
        bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=120)
        bucket.tokens = 0  # drain the bucket so accrual is measured from zero

    checkpoints = [(30, 30), (120, 120), (130, 120)]  # (elapsed, expected tokens)
    for elapsed, expected in checkpoints:
        with freeze_time(start + elapsed):
            assert bucket.available_tokens() == expected  # capped at 120
|
||
|
||
def test_token_bucket_spends_tokens():
    """A successful request should cost exactly one token."""
    start = time.time()

    with freeze_time(start):
        bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=10)
        bucket.tokens = 0  # drain the bucket so accrual is measured from zero

    with freeze_time(start + 3):
        before = bucket.available_tokens()
        assert before == 3
        bucket.make_request_if_ready()
        assert bucket.available_tokens() == before - 1
|
||
|
||
def test_token_bucket_cannot_spend_unavailable_tokens():
    """Requesting beyond the available tokens should raise UnavailableTokensError."""
    start = time.time()

    with freeze_time(start):
        bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=2)
        bucket.tokens = 0  # drain the bucket so accrual is measured from zero

    with freeze_time(start + 1):
        assert bucket.available_tokens() == 1
        bucket.make_request_if_ready()  # spends the only token
        with pytest.raises(UnavailableTokensError):
            bucket.make_request_if_ready()  # nothing left to spend
|
||
|
||
def test_rate_limiter_cleans_up_old_partitions():
    """Old partitions are purged on rollover; looping past every partition resets the cache."""
    start = time.time()

    with freeze_time(start):
        limiter = ServerRateLimiter(
            per_second_rate_limit=1,
            enforcement_window_seconds=100,
            partition_seconds=10,
            active_partitions=2,
        )
        for key in ("test_key_1", "test_key_2", "test_key_3", "test_key_4"):
            limiter.make_request(key)
        assert sum(len(partition) for partition in limiter.cache_partitions) == 4

    interval = limiter.partition_seconds
    with freeze_time(start + interval):
        # After one interval the cache rolls over to a second active partition.
        limiter.make_request("test_key_4")  # migrates test_key_4 into the current partition
        limiter.make_request("test_key_5")  # creates test_key_5 in the current partition
        sizes = [len(partition) for partition in limiter.cache_partitions]
        assert sum(sizes) == 5
        assert 2 in sizes  # current partition holds test_key_4 and test_key_5
        assert 3 in sizes  # the three untouched keys remain in the original partition

    with freeze_time(start + interval + limiter.num_partitions * interval):
        limiter.make_request("fresh_key")  # "looping" partitions should reset the cache
        assert sum(len(partition) for partition in limiter.cache_partitions) == 1
|
||
|
||
def test_rate_limiter_caches_token_buckets():
    """Repeated requests for one key should spend from the same cached TokenBucket."""
    start = time.time()

    with freeze_time(start):
        limiter = ServerRateLimiter(
            per_second_rate_limit=0.5,
            enforcement_window_seconds=20,
            partition_seconds=1,
            active_partitions=2,
        )
        for _ in range(3):
            limiter.make_request("test_key")
        token_bucket = next(
            (
                partition["test_key"]
                for partition in limiter.cache_partitions
                if "test_key" in partition
            ),
            None,
        )
        assert token_bucket is not None, "Token bucket for 'test_key' should exist"
        assert token_bucket.tokens == 7  # capacity of 10, minus three requests

    with freeze_time(start + 1):
        assert token_bucket.available_tokens() == 7.5  # 0.5 tokens accrued in one second
        limiter.make_request("test_key")
        assert token_bucket.tokens == 6.5