diff --git a/src/phoenix/server/rate_limiters.py b/src/phoenix/server/rate_limiters.py
new file mode 100644
index 0000000000..859a84a424
--- /dev/null
+++ b/src/phoenix/server/rate_limiters.py
@@ -0,0 +1,134 @@
+import time
+from collections import defaultdict
+from functools import partial
+from typing import Any, DefaultDict, List, Optional
+
+from phoenix.exceptions import PhoenixException
+
+
+class UnavailableTokensError(PhoenixException):
+    pass
+
+
+class TokenBucket:
+    """
+    An implementation of the token-bucket algorithm for use as a rate limiter.
+
+    Args:
+        per_second_request_rate (float): The allowed request rate.
+        enforcement_window_seconds (float): The time window over which the rate limit is enforced.
+    """
+
+    def __init__(
+        self,
+        per_second_request_rate: float,
+        enforcement_window_seconds: float = 1,
+    ):
+        self.enforcement_window = enforcement_window_seconds
+        self.rate = per_second_request_rate
+
+        now = time.time()
+        self.last_checked = now
+        self.tokens = self.max_tokens()
+
+    def max_tokens(self) -> float:
+        return self.rate * self.enforcement_window
+
+    def available_tokens(self) -> float:
+        now = time.time()
+        time_since_last_checked = now - self.last_checked
+        self.tokens = min(self.max_tokens(), self.rate * time_since_last_checked + self.tokens)
+        self.last_checked = now
+        return self.tokens
+
+    def make_request_if_ready(self) -> None:
+        if self.available_tokens() < 1:
+            raise UnavailableTokensError
+        self.tokens -= 1
+
+
+class ServerRateLimiter:
+    """
+    This rate limiter holds a cache of token buckets that enforce rate limits.
+
+    The cache is kept in partitions that rotate every `partition_seconds`. Each user's rate limiter
+    can be accessed from all active partitions, the number of active partitions is set with
+    `active_partitions`. This guarantees that a user's rate limiter will sit in the cache for at
+    least:
+
+        minimum_cache_lifetime = (active_partitions - 1) * partition_seconds
+
+    Every time the cache is accessed, inactive partitions are purged. If enough time has passed,
+    the entire cache is purged.
+    """
+
+    def __init__(
+        self,
+        per_second_rate_limit: float = 0.5,
+        enforcement_window_seconds: float = 5,
+        partition_seconds: float = 60,
+        active_partitions: int = 2,
+    ):
+        self.bucket_factory = partial(
+            TokenBucket,
+            per_second_request_rate=per_second_rate_limit,
+            enforcement_window_seconds=enforcement_window_seconds,
+        )
+        self.partition_seconds = partition_seconds
+        self.active_partitions = active_partitions
+        self.num_partitions = active_partitions + 2  # two overflow partitions to avoid edge cases
+        self._reset_rate_limiters()
+        self._last_cleanup_time = time.time()
+
+    def _reset_rate_limiters(self) -> None:
+        self.cache_partitions: List[DefaultDict[Any, TokenBucket]] = [
+            defaultdict(self.bucket_factory) for _ in range(self.num_partitions)
+        ]
+
+    def _current_partition_index(self, timestamp: float) -> int:
+        return (
+            int(timestamp // self.partition_seconds) % self.num_partitions
+        )  # a cyclic bucket index
+
+    def _active_partition_indices(self, current_index: int) -> List[int]:
+        return [(current_index - ii) % self.num_partitions for ii in range(self.active_partitions)]
+
+    def _inactive_partition_indices(self, current_index: int) -> List[int]:
+        active_indices = set(self._active_partition_indices(current_index))
+        all_indices = set(range(self.num_partitions))
+        return list(all_indices - active_indices)
+
+    def _cleanup_expired_limiters(self, request_time: float) -> None:
+        time_since_last_cleanup = request_time - self._last_cleanup_time
+        if time_since_last_cleanup >= ((self.num_partitions - 1) * self.partition_seconds):
+            # Reset the cache to avoid "looping" back to the same partitions
+            self._reset_rate_limiters()
+            self._last_cleanup_time = request_time
+            return
+
+        current_partition_index = self._current_partition_index(request_time)
+        inactive_indices = self._inactive_partition_indices(current_partition_index)
+        for ii in inactive_indices:
+            self.cache_partitions[ii] = defaultdict(self.bucket_factory)
+        self._last_cleanup_time = request_time
+
+    def _fetch_token_bucket(self, key: str, request_time: float) -> TokenBucket:
+        current_partition_index = self._current_partition_index(request_time)
+        active_indices = self._active_partition_indices(current_partition_index)
+        bucket: Optional[TokenBucket] = None
+        for ii in active_indices:
+            partition = self.cache_partitions[ii]
+            if key in partition:
+                bucket = partition.pop(key)
+                break
+
+        current_partition = self.cache_partitions[current_partition_index]
+        if key not in current_partition and bucket is not None:
+            current_partition[key] = bucket
+        return current_partition[key]
+
+    def make_request(self, key: str) -> None:
+        request_time = time.time()
+        self._cleanup_expired_limiters(request_time)
+        rate_limiter = self._fetch_token_bucket(key, request_time)
+        rate_limiter.make_request_if_ready()
diff --git a/tests/server/test_rate_limiters.py b/tests/server/test_rate_limiters.py
new file mode 100644
index 0000000000..1aa4f750cf
--- /dev/null
+++ b/tests/server/test_rate_limiters.py
@@ -0,0 +1,163 @@
+import time
+from contextlib import contextmanager
+from typing import Optional
+from unittest import mock
+
+import pytest
+
+from phoenix.server.rate_limiters import ServerRateLimiter, TokenBucket, UnavailableTokensError
+
+
+@contextmanager
+def freeze_time(frozen_time: Optional[float] = None):
+    """Pin ``time.time()`` to a fixed value for the duration of the context."""
+    frozen_time = time.time() if frozen_time is None else frozen_time
+
+    with mock.patch("time.time") as mock_time:
+        mock_time.return_value = frozen_time
+        yield mock_time
+
+
+@contextmanager
+def warp_time(start: Optional[float]):
+    """Patch time so that each ``time.sleep`` instantly advances a fake clock."""
+    # Resolve the None default first: deriving current_time from an unresolved None would crash.
+    start = time.time() if start is None else start
+    current_time = start
+    sleeps = [0]
+
+    def instant_sleep(duration):
+        nonlocal sleeps
+        sleeps.append(duration)
+
+    def time_warp():
+        try:
+            nonlocal current_time
+            nonlocal sleeps
+            current_time += sleeps.pop()
+            return current_time
+        except IndexError:
+            return current_time
+
+    with mock.patch("time.time") as mock_time:
+        with mock.patch("time.sleep") as mock_sleep:
+            mock_sleep.side_effect = instant_sleep
+            mock_time.side_effect = time_warp
+            yield
+
+
+def test_token_bucket_gains_tokens_over_time():
+    start = time.time()
+
+    with freeze_time(start):
+        bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=30)
+        bucket.tokens = 0  # start at 0
+
+    with freeze_time(start + 5):
+        assert bucket.available_tokens() == 5
+
+    with freeze_time(start + 10):
+        assert bucket.available_tokens() == 10
+
+
+def test_token_bucket_can_max_out_on_requests():
+    start = time.time()
+
+    with freeze_time(start):
+        bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=120)
+        bucket.tokens = 0  # start at 0
+
+    with freeze_time(start + 30):
+        assert bucket.available_tokens() == 30
+
+    with freeze_time(start + 120):
+        assert bucket.available_tokens() == 120
+
+    with freeze_time(start + 130):
+        assert bucket.available_tokens() == 120  # should max out at 120
+
+
+def test_token_bucket_spends_tokens():
+    start = time.time()
+
+    with freeze_time(start):
+        bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=10)
+        bucket.tokens = 0  # start at 0
+
+    with freeze_time(start + 3):
+        assert bucket.available_tokens() == 3
+        bucket.make_request_if_ready()
+        assert bucket.available_tokens() == 2
+
+
+def test_token_bucket_cannot_spend_unavailable_tokens():
+    start = time.time()
+
+    with freeze_time(start):
+        bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=2)
+        bucket.tokens = 0  # start at 0
+
+    with freeze_time(start + 1):
+        assert bucket.available_tokens() == 1
+        bucket.make_request_if_ready()  # should spend one token
+        with pytest.raises(UnavailableTokensError):
+            bucket.make_request_if_ready()  # should raise since no tokens left
+
+
+def test_rate_limiter_cleans_up_old_partitions():
+    start = time.time()
+
+    with freeze_time(start):
+        limiter = ServerRateLimiter(
+            per_second_rate_limit=1,
+            enforcement_window_seconds=100,
+            partition_seconds=10,
+            active_partitions=2,
+        )
+        limiter.make_request("test_key_1")
+        limiter.make_request("test_key_2")
+        limiter.make_request("test_key_3")
+        limiter.make_request("test_key_4")
+        partition_sizes = [len(partition) for partition in limiter.cache_partitions]
+        assert sum(partition_sizes) == 4
+
+    interval = limiter.partition_seconds
+    with freeze_time(start + interval):
+        # after a partition interval, the cache rolls over to a second active partition
+        limiter.make_request("test_key_4")  # moves test_key_4 to current partition
+        limiter.make_request("test_key_5")  # creates test_key_5 in current partition
+        partition_sizes = [len(partition) for partition in limiter.cache_partitions]
+        assert sum(partition_sizes) == 5
+        assert 2 in partition_sizes  # two rate limiters in current cache partition
+        assert 3 in partition_sizes  # three rate limiters remaining in original partition
+
+    with freeze_time(start + interval + (limiter.num_partitions * interval)):
+        limiter.make_request("fresh_key")  # when "looping" partitions, cache should be reset
+        assert sum(len(partition) for partition in limiter.cache_partitions) == 1
+
+
+def test_rate_limiter_caches_token_buckets():
+    start = time.time()
+
+    with freeze_time(start):
+        limiter = ServerRateLimiter(
+            per_second_rate_limit=0.5,
+            enforcement_window_seconds=20,
+            partition_seconds=1,
+            active_partitions=2,
+        )
+        limiter.make_request("test_key")
+        limiter.make_request("test_key")
+        limiter.make_request("test_key")
+        token_bucket = None
+        for partition in limiter.cache_partitions:
+            if "test_key" in partition:
+                token_bucket = partition["test_key"]
+                break
+        assert token_bucket is not None, "Token bucket for 'test_key' should exist"
+        assert token_bucket.tokens == 7
+
+    with freeze_time(start + 1):
+        assert token_bucket.available_tokens() == 7.5
+        limiter.make_request("test_key")
+        assert token_bucket.tokens == 6.5