Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""feat: websearch rate limits

Revision ID: 597f5c4be67e
Revises: b9642f45db1d
Create Date: 2025-07-28 09:40:04.424627

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "597f5c4be67e"
down_revision: Union[str, None] = "b9642f45db1d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"users", sa.Column("web_search_ratelimit_day", sa.Integer(), nullable=True)
)
op.add_column(
"users", sa.Column("web_search_ratelimit_hour", sa.Integer(), nullable=True)
)
op.add_column(
"users", sa.Column("web_search_ratelimit_minute", sa.Integer(), nullable=True)
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("users", "web_search_ratelimit_minute")
op.drop_column("users", "web_search_ratelimit_hour")
op.drop_column("users", "web_search_ratelimit_day")
# ### end Alembic commands ###
3 changes: 3 additions & 0 deletions nilai-api/src/nilai_api/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
USER_RATE_LIMIT_MINUTE: int | None = 100
USER_RATE_LIMIT_HOUR: int | None = 1000
USER_RATE_LIMIT_DAY: int | None = 10000
WEB_SEARCH_RATE_LIMIT_MINUTE: int | None = 1
WEB_SEARCH_RATE_LIMIT_HOUR: int | None = 3
WEB_SEARCH_RATE_LIMIT_DAY: int | None = 72

if ENVIRONMENT == "mainnet":
from .mainnet import * # noqa
Expand Down
3 changes: 3 additions & 0 deletions nilai-api/src/nilai_api/config/mainnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@
USER_RATE_LIMIT_MINUTE = None
USER_RATE_LIMIT_HOUR = None
USER_RATE_LIMIT_DAY = None
WEB_SEARCH_RATE_LIMIT_MINUTE = None
WEB_SEARCH_RATE_LIMIT_HOUR = None
WEB_SEARCH_RATE_LIMIT_DAY = None
3 changes: 3 additions & 0 deletions nilai-api/src/nilai_api/config/testnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@
USER_RATE_LIMIT_MINUTE = 10
USER_RATE_LIMIT_HOUR = 100
USER_RATE_LIMIT_DAY = 1000
WEB_SEARCH_RATE_LIMIT_MINUTE = 1
WEB_SEARCH_RATE_LIMIT_HOUR = 3
WEB_SEARCH_RATE_LIMIT_DAY = 72
18 changes: 18 additions & 0 deletions nilai-api/src/nilai_api/db/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
USER_RATE_LIMIT_MINUTE,
USER_RATE_LIMIT_HOUR,
USER_RATE_LIMIT_DAY,
WEB_SEARCH_RATE_LIMIT_MINUTE,
WEB_SEARCH_RATE_LIMIT_HOUR,
WEB_SEARCH_RATE_LIMIT_DAY,
)

logger = logging.getLogger(__name__)
Expand All @@ -38,6 +41,15 @@ class UserModel(Base):
ratelimit_minute: int = Column(
Integer, default=USER_RATE_LIMIT_MINUTE, nullable=True
) # type: ignore
web_search_ratelimit_day: int = Column(
Integer, default=WEB_SEARCH_RATE_LIMIT_DAY, nullable=True
) # type: ignore
web_search_ratelimit_hour: int = Column(
Integer, default=WEB_SEARCH_RATE_LIMIT_HOUR, nullable=True
) # type: ignore
web_search_ratelimit_minute: int = Column(
Integer, default=WEB_SEARCH_RATE_LIMIT_MINUTE, nullable=True
) # type: ignore

def __repr__(self):
return f"<User(userid={self.userid}, name={self.name})>"
Expand All @@ -55,6 +67,9 @@ class UserData(BaseModel):
ratelimit_day: Optional[int] = None
ratelimit_hour: Optional[int] = None
ratelimit_minute: Optional[int] = None
web_search_ratelimit_day: Optional[int] = None
web_search_ratelimit_hour: Optional[int] = None
web_search_ratelimit_minute: Optional[int] = None

model_config = ConfigDict(from_attributes=True)

Expand All @@ -72,6 +87,9 @@ def from_sqlalchemy(cls, user: UserModel) -> "UserData":
ratelimit_day=user.ratelimit_day,
ratelimit_hour=user.ratelimit_hour,
ratelimit_minute=user.ratelimit_minute,
web_search_ratelimit_day=user.web_search_ratelimit_day,
web_search_ratelimit_hour=user.web_search_ratelimit_hour,
web_search_ratelimit_minute=user.web_search_ratelimit_minute,
)


Expand Down
46 changes: 41 additions & 5 deletions nilai-api/src/nilai_api/rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,22 @@ async def setup_redis_conn(redis_url):
return client, lua_sha


async def _extract_coroutine_result(maybe_future, request: Request):
if iscoroutine(maybe_future):
return await maybe_future
else:
return maybe_future


class UserRateLimits(BaseModel):
subscription_holder: str
day_limit: int | None
hour_limit: int | None
minute_limit: int | None
token_rate_limit: TokenRateLimits | None
web_search_day_limit: int | None
web_search_hour_limit: int | None
web_search_minute_limit: int | None


def get_user_limits(
Expand All @@ -62,6 +72,9 @@ def get_user_limits(
hour_limit=auth_info.user.ratelimit_hour,
minute_limit=auth_info.user.ratelimit_minute,
token_rate_limit=auth_info.token_rate_limit,
web_search_day_limit=auth_info.user.web_search_ratelimit_day,
web_search_hour_limit=auth_info.user.web_search_ratelimit_hour,
web_search_minute_limit=auth_info.user.web_search_ratelimit_minute,
)


Expand All @@ -73,15 +86,18 @@ def __init__(
[Request], Tuple[int, str] | Awaitable[Tuple[int, str]]
]
| None = None,
web_search_extractor: Callable[[Request], bool | Awaitable[bool]] | None = None,
):
"""
concurrent: Maximum number of concurrent requests allowed for a single path
concurrent_extractor: A callable that extracts the concurrent limit and key from the request
web_search_extractor: A callable that extracts the web_search flag from the request

concurrent and concurrent_extractor are mutually exclusive
"""
self.max_concurrent = concurrent
self.concurrent_extractor = concurrent_extractor
self.web_search_extractor = web_search_extractor

async def __call__(
self,
Expand All @@ -90,6 +106,7 @@ async def __call__(
):
redis = request.state.redis
redis_rate_limit_command = request.state.redis_rate_limit_command

await self.check_bucket(
redis,
redis_rate_limit_command,
Expand Down Expand Up @@ -130,6 +147,27 @@ async def __call__(
limit.ms_remaining,
)

if self.web_search_extractor:
web_search_enabled = await _extract_coroutine_result(
self.web_search_extractor(request), request
)

if web_search_enabled:
web_search_limits = [
(user_limits.web_search_minute_limit, MINUTE_MS, "minute"),
(user_limits.web_search_hour_limit, HOUR_MS, "hour"),
(user_limits.web_search_day_limit, DAY_MS, "day"),
]

for limit, milliseconds, time_unit in web_search_limits:
await self.check_bucket(
redis,
redis_rate_limit_command,
f"web_search_{time_unit}:{user_limits.subscription_holder}",
limit,
milliseconds,
)

key = await self.check_concurrent_and_increment(redis, request)
try:
yield
Expand Down Expand Up @@ -164,11 +202,9 @@ async def check_concurrent_and_increment(
return None

if self.concurrent_extractor:
maybe_future = self.concurrent_extractor(request)
if iscoroutine(maybe_future):
max_concurrent, key = await maybe_future
else:
max_concurrent, key = maybe_future # type: ignore
max_concurrent, key = await _extract_coroutine_result(
self.concurrent_extractor(request), request
)
else:
max_concurrent, key = self.max_concurrent, request.url.path

Expand Down
17 changes: 16 additions & 1 deletion nilai-api/src/nilai_api/routers/private.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ async def chat_completion_concurrent_rate_limit(request: Request) -> Tuple[int,
return limit, key


async def chat_completion_web_search_rate_limit(request: Request) -> bool:
"""Extract web_search flag from request body for rate limiting."""
body = await request.json()
try:
chat_request = ChatRequest(**body)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid request body")
return getattr(chat_request, "web_search", False)


@router.post("/v1/chat/completions", tags=["Chat"], response_model=None)
async def chat_completion(
req: ChatRequest = Body(
Expand All @@ -127,7 +137,12 @@ async def chat_completion(
],
)
),
_=Depends(RateLimit(concurrent_extractor=chat_completion_concurrent_rate_limit)),
_rate_limit=Depends(
RateLimit(
concurrent_extractor=chat_completion_concurrent_rate_limit,
web_search_extractor=chat_completion_web_search_rate_limit,
)
),
auth_info: AuthenticationInfo = Depends(get_auth_info),
) -> Union[SignedChatCompletion, StreamingResponse]:
"""
Expand Down
60 changes: 60 additions & 0 deletions tests/unit/nilai_api/test_rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ async def test_concurrent_rate_limit(req):
hour_limit=None,
minute_limit=None,
token_rate_limit=None,
web_search_day_limit=None,
web_search_hour_limit=None,
web_search_minute_limit=None,
)

futures = [consume_generator(rate_limit(req, user_limits)) for _ in range(5)]
Expand All @@ -74,20 +77,29 @@ async def test_concurrent_rate_limit(req):
hour_limit=None,
minute_limit=None,
token_rate_limit=None,
web_search_day_limit=None,
web_search_hour_limit=None,
web_search_minute_limit=None,
),
UserRateLimits(
subscription_holder=random_id(),
day_limit=None,
hour_limit=11,
minute_limit=None,
token_rate_limit=None,
web_search_day_limit=None,
web_search_hour_limit=None,
web_search_minute_limit=None,
),
UserRateLimits(
subscription_holder=random_id(),
day_limit=None,
hour_limit=None,
minute_limit=12,
token_rate_limit=None,
web_search_day_limit=None,
web_search_hour_limit=None,
web_search_minute_limit=None,
),
UserRateLimits(
subscription_holder=random_id(),
Expand All @@ -103,6 +115,9 @@ async def test_concurrent_rate_limit(req):
)
]
),
web_search_day_limit=None,
web_search_hour_limit=None,
web_search_minute_limit=None,
),
],
)
Expand All @@ -115,3 +130,48 @@ async def test_user_limit(req, user_limits):
futures = [consume_generator(rate_limit(req, user_limits)) for _ in range(3)]
with pytest.raises(HTTPException):
await asyncio.gather(*futures)


@pytest.mark.asyncio
async def test_web_search_rate_limits(redis_client):
"""Verify that a user is rate limited for web-search requests across all time windows."""

# Build a dummy authenticated user
apikey = random_id()

# Mock the incoming request with web_search enabled
mock_request = MagicMock(spec=Request)
mock_request.state.redis = redis_client[0]
mock_request.state.redis_rate_limit_command = redis_client[1]

async def json_body():
return {
"model": "meta-llama/Llama-3.2-1B-Instruct",
"messages": [{"role": "user", "content": "hi"}],
"web_search": True,
}

mock_request.json = json_body

# Create rate limit with web search enabled
async def web_search_extractor(request):
return True

rate_limit = RateLimit(web_search_extractor=web_search_extractor)
user_limits = UserRateLimits(
subscription_holder=apikey,
day_limit=None,
hour_limit=None,
minute_limit=None,
token_rate_limit=None,
web_search_day_limit=72,
web_search_hour_limit=3,
web_search_minute_limit=1,
)

# First request should succeed (minute limit: 1, hour limit: 3, day limit: 72)
await consume_generator(rate_limit(mock_request, user_limits))

# Second request should be rejected due to minute limit (1 per minute)
with pytest.raises(HTTPException):
await consume_generator(rate_limit(mock_request, user_limits))
Loading