Skip to content

Commit

Permalink
feat: Implement serverside rate limiter (#4431)
Browse files Browse the repository at this point in the history
* Spike out serverside rate limiting middleware

* Refactor to keep multiple active partitions

* Create StrawberryRateLimiter extension

* Add rate limiter tests

* Update extension to run synchronously

* Ensure rate limiter extension works for both sync and async resolvers

* Clean up type annotations

* Improve tests and reset behavior

* Clarify edge cases in testing and documentation

* Use pop instead of `del`

* Remove extension implementation
  • Loading branch information
anticorrelator authored and Parker-Stafford committed Sep 4, 2024
1 parent 02cb662 commit 85b6a1d
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 0 deletions.
134 changes: 134 additions & 0 deletions src/phoenix/server/rate_limiters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import time
from collections import defaultdict
from functools import partial
from typing import Any, DefaultDict, List, Optional

from phoenix.exceptions import PhoenixException


class UnavailableTokensError(PhoenixException):
pass


class TokenBucket:
"""
An implementation of the token-bucket algorithm for use as a rate limiter.
Args:
per_second_request_rate (float): The allowed request rate.
enforcement_window_minutes (float): The time window over which the rate limit is enforced.
"""

def __init__(
self,
per_second_request_rate: float,
enforcement_window_seconds: float = 1,
):
self.enforcement_window = enforcement_window_seconds
self.rate = per_second_request_rate

now = time.time()
self.last_checked = now
self.tokens = self.max_tokens()

def max_tokens(self) -> float:
return self.rate * self.enforcement_window

def available_tokens(self) -> float:
now = time.time()
time_since_last_checked = now - self.last_checked
self.tokens = min(self.max_tokens(), self.rate * time_since_last_checked + self.tokens)
self.last_checked = now
return self.tokens

def make_request_if_ready(self) -> None:
if self.available_tokens() < 1:
raise UnavailableTokensError
self.tokens -= 1


class ServerRateLimiter:
"""
This rate limiter holds a cache of token buckets that enforce rate limits.
The cache is kept in partitions that rotate every `partition_seconds`. Each user's rate limiter
can be accessed from all active partitions, the number of active partitions is set with
`active_partitions`. This guarantees that a user's rate limiter will sit in the cache for at
least:
minimum_cache_lifetime = (active_partitions - 1) * partition_seconds
Every time the cache is accessed, inactive partitions are purged. If enough time has passed,
the entire cache is purged.
"""

def __init__(
self,
per_second_rate_limit: float = 0.5,
enforcement_window_seconds: float = 5,
partition_seconds: float = 60,
active_partitions: int = 2,
):
self.bucket_factory = partial(
TokenBucket,
per_second_request_rate=per_second_rate_limit,
enforcement_window_seconds=enforcement_window_seconds,
)
self.partition_seconds = partition_seconds
self.active_partitions = active_partitions
self.num_partitions = active_partitions + 2 # two overflow partitions to avoid edge cases
self._reset_rate_limiters()
self._last_cleanup_time = time.time()

def _reset_rate_limiters(self) -> None:
self.cache_partitions: List[DefaultDict[Any, TokenBucket]] = [
defaultdict(self.bucket_factory) for _ in range(self.num_partitions)
]

def _current_partition_index(self, timestamp: float) -> int:
return (
int(timestamp // self.partition_seconds) % self.num_partitions
) # a cyclic bucket index

def _active_partition_indices(self, current_index: int) -> List[int]:
return [(current_index - ii) % self.num_partitions for ii in range(self.active_partitions)]

def _inactive_partition_indices(self, current_index: int) -> List[int]:
active_indices = set(self._active_partition_indices(current_index))
all_indices = set(range(self.num_partitions))
return list(all_indices - active_indices)

def _cleanup_expired_limiters(self, request_time: float) -> None:
time_since_last_cleanup = request_time - self._last_cleanup_time
if time_since_last_cleanup >= ((self.num_partitions - 1) * self.partition_seconds):
# Reset the cache to avoid "looping" back to the same partitions
self._reset_rate_limiters()
self._last_cleanup_time = request_time
return

current_partition_index = self._current_partition_index(request_time)
inactive_indices = self._inactive_partition_indices(current_partition_index)
for ii in inactive_indices:
self.cache_partitions[ii] = defaultdict(self.bucket_factory)
self._last_cleanup_time = request_time

def _fetch_token_bucket(self, key: str, request_time: float) -> TokenBucket:
current_partition_index = self._current_partition_index(request_time)
active_indices = self._active_partition_indices(current_partition_index)
bucket: Optional[TokenBucket] = None
for ii in active_indices:
partition = self.cache_partitions[ii]
if key in partition:
bucket = partition.pop(key)
break

current_partition = self.cache_partitions[current_partition_index]
if key not in current_partition and bucket is not None:
current_partition[key] = bucket
return current_partition[key]

def make_request(self, key: str) -> None:
request_time = time.time()
self._cleanup_expired_limiters(request_time)
rate_limiter = self._fetch_token_bucket(key, request_time)
rate_limiter.make_request_if_ready()
160 changes: 160 additions & 0 deletions tests/server/test_rate_limiters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import time
from contextlib import contextmanager
from typing import Optional
from unittest import mock

import pytest

from phoenix.server.rate_limiters import ServerRateLimiter, TokenBucket, UnavailableTokensError


@contextmanager
def freeze_time(frozen_time: Optional[float] = None):
frozen_time = time.time() if frozen_time is None else frozen_time

with mock.patch("time.time") as mock_time:
mock_time.return_value = frozen_time
yield mock_time


@contextmanager
def warp_time(start: Optional[float]):
sleeps = [0]
current_time = start
start = time.time() if start is None else start

def instant_sleep(time):
nonlocal sleeps
sleeps.append(time)

def time_warp():
try:
nonlocal current_time
nonlocal sleeps
current_time += sleeps.pop()
return current_time
except IndexError:
return current_time

with mock.patch("time.time") as mock_time:
with mock.patch("time.sleep") as mock_sleep:
mock_sleep.side_effect = instant_sleep
mock_time.side_effect = time_warp
yield


def test_token_bucket_gains_tokens_over_time():
start = time.time()

with freeze_time(start):
bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=30)
bucket.tokens = 0 # start at 0

with freeze_time(start + 5):
assert bucket.available_tokens() == 5

with freeze_time(start + 10):
assert bucket.available_tokens() == 10


def test_token_bucket_can_max_out_on_requests():
start = time.time()

with freeze_time(start):
bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=120)
bucket.tokens = 0 # start at 0

with freeze_time(start + 30):
assert bucket.available_tokens() == 30

with freeze_time(start + 120):
assert bucket.available_tokens() == 120

with freeze_time(start + 130):
assert bucket.available_tokens() == 120 # should max out at 120


def test_token_bucket_spends_tokens():
start = time.time()

with freeze_time(start):
bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=10)
bucket.tokens = 0 # start at 0

with freeze_time(start + 3):
assert bucket.available_tokens() == 3
bucket.make_request_if_ready()
assert bucket.available_tokens() == 2


def test_token_bucket_cannot_spend_unavailable_tokens():
start = time.time()

with freeze_time(start):
bucket = TokenBucket(per_second_request_rate=1, enforcement_window_seconds=2)
bucket.tokens = 0 # start at 0

with freeze_time(start + 1):
assert bucket.available_tokens() == 1
bucket.make_request_if_ready() # should spend one token
with pytest.raises(UnavailableTokensError):
bucket.make_request_if_ready() # should raise since no tokens left


def test_rate_limiter_cleans_up_old_partitions():
start = time.time()

with freeze_time(start):
limiter = ServerRateLimiter(
per_second_rate_limit=1,
enforcement_window_seconds=100,
partition_seconds=10,
active_partitions=2,
)
limiter.make_request("test_key_1")
limiter.make_request("test_key_2")
limiter.make_request("test_key_3")
limiter.make_request("test_key_4")
partition_sizes = [len(partition) for partition in limiter.cache_partitions]
assert sum(partition_sizes) == 4

interval = limiter.partition_seconds
with freeze_time(start + interval):
# after a partition interval, the cache rolls over to a second active partition
limiter.make_request("test_key_4") # moves test_key_4 to current partition
limiter.make_request("test_key_5") # creates test_key_5 in current partition
partition_sizes = [len(partition) for partition in limiter.cache_partitions]
assert sum(partition_sizes) == 5
assert 2 in partition_sizes # two rate limiters in current cache partition
assert 3 in partition_sizes # three rate limiters remaining in original partition

with freeze_time(start + interval + (limiter.num_partitions * interval)):
limiter.make_request("fresh_key") # when "looping" partitions, cache should be reset
assert sum(len(partition) for partition in limiter.cache_partitions) == 1


def test_rate_limiter_caches_token_buckets():
start = time.time()

with freeze_time(start):
limiter = ServerRateLimiter(
per_second_rate_limit=0.5,
enforcement_window_seconds=20,
partition_seconds=1,
active_partitions=2,
)
limiter.make_request("test_key")
limiter.make_request("test_key")
limiter.make_request("test_key")
token_bucket = None
for partition in limiter.cache_partitions:
if "test_key" in partition:
token_bucket = partition["test_key"]
break
assert token_bucket is not None, "Token bucket for 'test_key' should exist"
assert token_bucket.tokens == 7

with freeze_time(start + 1):
assert token_bucket.available_tokens() == 7.5
limiter.make_request("test_key")
assert token_bucket.tokens == 6.5

0 comments on commit 85b6a1d

Please sign in to comment.