Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement serverside rate limiter #4431

Merged
merged 11 commits into from
Aug 30, 2024
134 changes: 134 additions & 0 deletions src/phoenix/server/rate_limiters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import time
from collections import defaultdict
from functools import partial
from typing import Any, DefaultDict, List, Optional

from phoenix.exceptions import PhoenixException


class UnavailableTokensError(PhoenixException):
    """Raised when a TokenBucket has fewer than one token available for a request."""

    pass


class TokenBucket:
    """
    An implementation of the token-bucket algorithm for use as a rate limiter.

    Tokens accrue continuously at `per_second_request_rate` and are capped at
    `per_second_request_rate * enforcement_window_seconds`. Each request spends one
    token; a request made with fewer than one token available is rejected.

    Args:
        per_second_request_rate (float): The allowed request rate.
        enforcement_window_seconds (float): The time window over which the rate limit
            is enforced.
    """

    def __init__(
        self,
        per_second_request_rate: float,
        enforcement_window_seconds: float = 1,
    ):
        self.enforcement_window = enforcement_window_seconds
        self.rate = per_second_request_rate

        now = time.time()
        self.last_checked = now
        self.tokens = self.max_tokens()  # start with a full bucket

    def max_tokens(self) -> float:
        """Return the bucket's capacity: rate * enforcement window."""
        return self.rate * self.enforcement_window

    def available_tokens(self) -> float:
        """Accrue tokens for the time elapsed since the last check and return the balance."""
        now = time.time()
        time_since_last_checked = now - self.last_checked
        # Refill proportionally to elapsed time, never exceeding capacity.
        self.tokens = min(self.max_tokens(), self.rate * time_since_last_checked + self.tokens)
        self.last_checked = now
        return self.tokens

    def make_request_if_ready(self) -> None:
        """Spend one token, raising UnavailableTokensError if fewer than one is available."""
        if self.available_tokens() < 1:
            raise UnavailableTokensError
        self.tokens -= 1


class ServerRateLimiter:
    """
    This rate limiter holds a cache of token buckets that enforce rate limits.

    The cache is kept in partitions that rotate every `partition_seconds`. Each user's rate limiter
    can be accessed from all active partitions, the number of active partitions is set with
    `active_partitions`. This guarantees that a user's rate limiter will sit in the cache for at
    least:

        minimum_cache_lifetime = (active_partitions - 1) * partition_seconds

    Every time the cache is accessed, inactive partitions are purged. If enough time has passed,
    the entire cache is purged.
    """

    def __init__(
        self,
        per_second_rate_limit: float = 0.5,
        enforcement_window_seconds: float = 5,
        partition_seconds: float = 60,
        active_partitions: int = 2,
    ):
        # Factory so each new cache key lazily receives its own pre-configured bucket.
        self.bucket_factory = partial(
            TokenBucket,
            per_second_request_rate=per_second_rate_limit,
            enforcement_window_seconds=enforcement_window_seconds,
        )
        self.partition_seconds = partition_seconds
        self.active_partitions = active_partitions
        self.num_partitions = active_partitions + 2  # two overflow partitions to avoid edge cases
        self._reset_rate_limiters()
        self._last_cleanup_time = time.time()

    def _reset_rate_limiters(self) -> None:
        """Discard all cached buckets, recreating every partition empty."""
        self.cache_partitions: List[DefaultDict[Any, TokenBucket]] = [
            defaultdict(self.bucket_factory) for _ in range(self.num_partitions)
        ]

    def _current_partition_index(self, timestamp: float) -> int:
        """Map a timestamp to its partition slot."""
        return (
            int(timestamp // self.partition_seconds) % self.num_partitions
        )  # a cyclic bucket index

    def _active_partition_indices(self, current_index: int) -> List[int]:
        """Return the current index plus the preceding `active_partitions - 1` slots."""
        return [(current_index - ii) % self.num_partitions for ii in range(self.active_partitions)]

    def _inactive_partition_indices(self, current_index: int) -> List[int]:
        """Return every partition slot that is not currently active."""
        active_indices = set(self._active_partition_indices(current_index))
        all_indices = set(range(self.num_partitions))
        return list(all_indices - active_indices)

    def _cleanup_expired_limiters(self, request_time: float) -> None:
        """Purge inactive partitions; purge everything if the cycle may have wrapped."""
        time_since_last_cleanup = request_time - self._last_cleanup_time
        if time_since_last_cleanup >= ((self.num_partitions - 1) * self.partition_seconds):
            # Reset the cache to avoid "looping" back to the same partitions
            self._reset_rate_limiters()
            self._last_cleanup_time = request_time
            return

        current_partition_index = self._current_partition_index(request_time)
        inactive_indices = self._inactive_partition_indices(current_partition_index)
        for ii in inactive_indices:
            self.cache_partitions[ii] = defaultdict(self.bucket_factory)
        self._last_cleanup_time = request_time

    def _fetch_token_bucket(self, key: str, request_time: float) -> TokenBucket:
        """Find (or create) the bucket for `key`, moving it into the current partition."""
        current_partition_index = self._current_partition_index(request_time)
        active_indices = self._active_partition_indices(current_partition_index)
        bucket: Optional[TokenBucket] = None
        for ii in active_indices:
            partition = self.cache_partitions[ii]
            if key in partition:
                bucket = partition.pop(key)
                break

        current_partition = self.cache_partitions[current_partition_index]
        if key not in current_partition and bucket is not None:
            current_partition[key] = bucket
        return current_partition[key]  # defaultdict creates a fresh bucket on a miss

    def make_request(self, key: str) -> None:
        """Record a request for `key`, raising UnavailableTokensError if rate-limited."""
        request_time = time.time()
        self._cleanup_expired_limiters(request_time)
        rate_limiter = self._fetch_token_bucket(key, request_time)
        rate_limiter.make_request_if_ready()
160 changes: 160 additions & 0 deletions tests/server/test_rate_limiters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import time
from contextlib import contextmanager
from typing import Optional
from unittest import mock

import pytest

from phoenix.server.rate_limiters import ServerRateLimiter, TokenBucket, UnavailableTokensError


@contextmanager
def freeze_time(frozen_time: Optional[float] = None):
    """Patch time.time to return a single fixed timestamp for the duration of the context."""
    pinned = time.time() if frozen_time is None else frozen_time

    with mock.patch("time.time") as mocked_clock:
        mocked_clock.return_value = pinned
        yield mocked_clock


@contextmanager
def warp_time(start: Optional[float]):
    """
    Patch time.time and time.sleep so that sleeps advance the mocked clock instantly.

    Each time.sleep(d) queues d seconds; each time.time() call pops the most recent
    queued duration and advances the clock by it (or stands still once the queue is
    empty).

    Args:
        start (Optional[float]): Initial clock value; defaults to the real current time.
    """
    # Resolve the None default BEFORE seeding the clock; the original assigned
    # current_time = start first, leaving it None and crashing on the first tick.
    start = time.time() if start is None else start
    sleeps = [0]
    current_time = start

    def instant_sleep(duration):  # replaces time.sleep; renamed to avoid shadowing `time`
        nonlocal sleeps
        sleeps.append(duration)

    def time_warp():
        try:
            nonlocal current_time
            nonlocal sleeps
            current_time += sleeps.pop()
            return current_time
        except IndexError:
            return current_time

    with mock.patch("time.time") as mock_time:
        with mock.patch("time.sleep") as mock_sleep:
            mock_sleep.side_effect = instant_sleep
            mock_time.side_effect = time_warp
            yield


def test_token_bucket_gains_tokens_over_time():
    t0 = time.time()

    with freeze_time(t0):
        tb = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=30)
        tb.tokens = 0  # drain the bucket so refill is observable

    # At 1 token/sec, available tokens should equal seconds elapsed.
    for elapsed in (5, 10):
        with freeze_time(t0 + elapsed):
            assert tb.available_tokens() == elapsed


def test_token_bucket_can_max_out_on_requests():
    t0 = time.time()

    with freeze_time(t0):
        tb = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=120)
        tb.tokens = 0  # drain the bucket

    # (seconds elapsed, expected tokens): accrual is capped at the 120-token window.
    for elapsed, expected in ((30, 30), (120, 120), (130, 120)):
        with freeze_time(t0 + elapsed):
            assert tb.available_tokens() == expected


def test_token_bucket_spends_tokens():
    t0 = time.time()

    with freeze_time(t0):
        tb = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=10)
        tb.tokens = 0  # drain the bucket

    with freeze_time(t0 + 3):
        assert tb.available_tokens() == 3
        tb.make_request_if_ready()  # spends exactly one token
        assert tb.available_tokens() == 2


def test_token_bucket_cannot_spend_unavailable_tokens():
    t0 = time.time()

    with freeze_time(t0):
        tb = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=2)
        tb.tokens = 0  # drain the bucket

    with freeze_time(t0 + 1):
        assert tb.available_tokens() == 1
        tb.make_request_if_ready()  # consumes the only available token
        with pytest.raises(UnavailableTokensError):
            tb.make_request_if_ready()  # nothing left to spend


def test_rate_limiter_cleans_up_old_partitions():
    t0 = time.time()

    with freeze_time(t0):
        limiter = ServerRateLimiter(
            per_second_rate_limit=1,
            enforcement_window_seconds=100,
            partition_seconds=10,
            active_partitions=2,
        )
        for key in ("test_key_1", "test_key_2", "test_key_3", "test_key_4"):
            limiter.make_request(key)
        sizes = [len(partition) for partition in limiter.cache_partitions]
        assert sum(sizes) == 4

    interval = limiter.partition_seconds
    with freeze_time(t0 + interval):
        # after a partition interval, the cache rolls over to a second active partition
        limiter.make_request("test_key_4")  # moves test_key_4 to current partition
        limiter.make_request("test_key_5")  # creates test_key_5 in current partition
        sizes = [len(partition) for partition in limiter.cache_partitions]
        assert sum(sizes) == 5
        assert 2 in sizes  # two rate limiters in current cache partition
        assert 3 in sizes  # three rate limiters remaining in original partition

    with freeze_time(t0 + interval + limiter.num_partitions * interval):
        limiter.make_request("fresh_key")  # when "looping" partitions, cache should be reset
        assert sum(len(partition) for partition in limiter.cache_partitions) == 1


def test_rate_limiter_caches_token_buckets():
    t0 = time.time()

    with freeze_time(t0):
        limiter = ServerRateLimiter(
            per_second_rate_limit=0.5,
            enforcement_window_seconds=20,
            partition_seconds=1,
            active_partitions=2,
        )
        for _ in range(3):
            limiter.make_request("test_key")
        bucket = next(
            (partition["test_key"] for partition in limiter.cache_partitions if "test_key" in partition),
            None,
        )
        assert bucket is not None, "Token bucket for 'test_key' should exist"
        assert bucket.tokens == 7

    with freeze_time(t0 + 1):
        assert bucket.available_tokens() == 7.5
        limiter.make_request("test_key")
        assert bucket.tokens == 6.5
Loading