diff --git a/nilai-api/alembic/versions/597f5c4be67e_feat_websearch_rate_limits.py b/nilai-api/alembic/versions/597f5c4be67e_feat_websearch_rate_limits.py new file mode 100644 index 00000000..32addf62 --- /dev/null +++ b/nilai-api/alembic/versions/597f5c4be67e_feat_websearch_rate_limits.py @@ -0,0 +1,41 @@ +"""feat: websearch rate limits + +Revision ID: 597f5c4be67e +Revises: b9642f45db1d +Create Date: 2025-07-28 09:40:04.424627 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "597f5c4be67e" +down_revision: Union[str, None] = "b9642f45db1d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "users", sa.Column("web_search_ratelimit_day", sa.Integer(), nullable=True) + ) + op.add_column( + "users", sa.Column("web_search_ratelimit_hour", sa.Integer(), nullable=True) + ) + op.add_column( + "users", sa.Column("web_search_ratelimit_minute", sa.Integer(), nullable=True) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("users", "web_search_ratelimit_minute") + op.drop_column("users", "web_search_ratelimit_hour") + op.drop_column("users", "web_search_ratelimit_day") + # ### end Alembic commands ### diff --git a/nilai-api/src/nilai_api/config/__init__.py b/nilai-api/src/nilai_api/config/__init__.py index 3c9cea8b..9a0e466e 100644 --- a/nilai-api/src/nilai_api/config/__init__.py +++ b/nilai-api/src/nilai_api/config/__init__.py @@ -33,6 +33,9 @@ USER_RATE_LIMIT_MINUTE: int | None = 100 USER_RATE_LIMIT_HOUR: int | None = 1000 USER_RATE_LIMIT_DAY: int | None = 10000 +WEB_SEARCH_RATE_LIMIT_MINUTE: int | None = 1 +WEB_SEARCH_RATE_LIMIT_HOUR: int | None = 3 +WEB_SEARCH_RATE_LIMIT_DAY: int | None = 72 if ENVIRONMENT == "mainnet": from .mainnet import * # noqa diff --git a/nilai-api/src/nilai_api/config/mainnet.py b/nilai-api/src/nilai_api/config/mainnet.py index 3c1552f6..031416ee 100644 --- a/nilai-api/src/nilai_api/config/mainnet.py +++ b/nilai-api/src/nilai_api/config/mainnet.py @@ -16,3 +16,6 @@ USER_RATE_LIMIT_MINUTE = None USER_RATE_LIMIT_HOUR = None USER_RATE_LIMIT_DAY = None +WEB_SEARCH_RATE_LIMIT_MINUTE = None +WEB_SEARCH_RATE_LIMIT_HOUR = None +WEB_SEARCH_RATE_LIMIT_DAY = None diff --git a/nilai-api/src/nilai_api/config/testnet.py b/nilai-api/src/nilai_api/config/testnet.py index 7e6bdf1b..8d7d4619 100644 --- a/nilai-api/src/nilai_api/config/testnet.py +++ b/nilai-api/src/nilai_api/config/testnet.py @@ -19,3 +19,6 @@ USER_RATE_LIMIT_MINUTE = 10 USER_RATE_LIMIT_HOUR = 100 USER_RATE_LIMIT_DAY = 1000 +WEB_SEARCH_RATE_LIMIT_MINUTE = 1 +WEB_SEARCH_RATE_LIMIT_HOUR = 3 +WEB_SEARCH_RATE_LIMIT_DAY = 72 diff --git a/nilai-api/src/nilai_api/db/users.py b/nilai-api/src/nilai_api/db/users.py index 7eb71c94..91a5d3f5 100644 --- a/nilai-api/src/nilai_api/db/users.py +++ b/nilai-api/src/nilai_api/db/users.py @@ -14,6 +14,9 @@ USER_RATE_LIMIT_MINUTE, USER_RATE_LIMIT_HOUR, USER_RATE_LIMIT_DAY, + WEB_SEARCH_RATE_LIMIT_MINUTE, + WEB_SEARCH_RATE_LIMIT_HOUR, + WEB_SEARCH_RATE_LIMIT_DAY, ) logger = logging.getLogger(__name__) @@ -38,6 +41,15 @@ class UserModel(Base): ratelimit_minute: int = Column( Integer, default=USER_RATE_LIMIT_MINUTE, nullable=True ) # type: ignore + web_search_ratelimit_day: int = Column( + Integer, default=WEB_SEARCH_RATE_LIMIT_DAY, nullable=True + ) # type: ignore + web_search_ratelimit_hour: int = Column( + Integer, default=WEB_SEARCH_RATE_LIMIT_HOUR, nullable=True + ) # type: ignore + web_search_ratelimit_minute: int = Column( + Integer, default=WEB_SEARCH_RATE_LIMIT_MINUTE, nullable=True + ) # type: ignore def __repr__(self): return f"" @@ -55,6 +67,9 @@ class UserData(BaseModel): ratelimit_day: Optional[int] = None ratelimit_hour: Optional[int] = None ratelimit_minute: Optional[int] = None + web_search_ratelimit_day: Optional[int] = None + web_search_ratelimit_hour: Optional[int] = None + web_search_ratelimit_minute: Optional[int] = None model_config = ConfigDict(from_attributes=True) @@ -72,6 +87,9 @@ def from_sqlalchemy(cls, user: UserModel) -> "UserData": ratelimit_day=user.ratelimit_day, ratelimit_hour=user.ratelimit_hour, ratelimit_minute=user.ratelimit_minute, + web_search_ratelimit_day=user.web_search_ratelimit_day, + web_search_ratelimit_hour=user.web_search_ratelimit_hour, + web_search_ratelimit_minute=user.web_search_ratelimit_minute, ) diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index af7971b0..0074e6fd 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -39,12 +39,22 @@ async def setup_redis_conn(redis_url): return client, lua_sha +async def _extract_coroutine_result(maybe_future, request: Request): + if iscoroutine(maybe_future): + return await maybe_future + else: + return maybe_future + + class UserRateLimits(BaseModel): subscription_holder: str day_limit: int | None hour_limit: int | None minute_limit: int | None token_rate_limit: TokenRateLimits | None + web_search_day_limit: int | None + web_search_hour_limit: int | None + web_search_minute_limit: int | None def get_user_limits( @@ -62,6 +72,9 @@ def get_user_limits( hour_limit=auth_info.user.ratelimit_hour, minute_limit=auth_info.user.ratelimit_minute, token_rate_limit=auth_info.token_rate_limit, + web_search_day_limit=auth_info.user.web_search_ratelimit_day, + web_search_hour_limit=auth_info.user.web_search_ratelimit_hour, + web_search_minute_limit=auth_info.user.web_search_ratelimit_minute, ) @@ -73,15 +86,18 @@ def __init__( [Request], Tuple[int, str] | Awaitable[Tuple[int, str]] ] | None = None, + web_search_extractor: Callable[[Request], bool | Awaitable[bool]] | None = None, ): """ concurrent: Maximum number of concurrent requests allowed for a single path concurrent_extractor: A callable that extracts the concurrent limit and key from the request + web_search_extractor: A callable that extracts the web_search flag from the request concurrent and concurrent_extractor are mutually exclusive """ self.max_concurrent = concurrent self.concurrent_extractor = concurrent_extractor + self.web_search_extractor = web_search_extractor async def __call__( self, @@ -90,6 +106,7 @@ async def __call__( ): redis = request.state.redis redis_rate_limit_command = request.state.redis_rate_limit_command + await self.check_bucket( redis, redis_rate_limit_command, @@ -130,6 +147,27 @@ async def __call__( limit.ms_remaining, ) + if self.web_search_extractor: + web_search_enabled = await _extract_coroutine_result( + self.web_search_extractor(request), request + ) + + if web_search_enabled: + web_search_limits = [ + (user_limits.web_search_minute_limit, MINUTE_MS, "minute"), + (user_limits.web_search_hour_limit, HOUR_MS, "hour"), + (user_limits.web_search_day_limit, DAY_MS, "day"), + ] + + for limit, milliseconds, time_unit in web_search_limits: + await self.check_bucket( + redis, + redis_rate_limit_command, + f"web_search_{time_unit}:{user_limits.subscription_holder}", + limit, + milliseconds, + ) + key = await self.check_concurrent_and_increment(redis, request) try: yield @@ -164,11 +202,9 @@ async def check_concurrent_and_increment( return None if self.concurrent_extractor: - maybe_future = self.concurrent_extractor(request) - if iscoroutine(maybe_future): - max_concurrent, key = await maybe_future - else: - max_concurrent, key = maybe_future # type: ignore + max_concurrent, key = await _extract_coroutine_result( + self.concurrent_extractor(request), request + ) else: max_concurrent, key = self.max_concurrent, request.url.path diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 22a47c25..592ef1a0 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -116,6 +116,16 @@ async def chat_completion_concurrent_rate_limit(request: Request) -> Tuple[int, return limit, key +async def chat_completion_web_search_rate_limit(request: Request) -> bool: + """Extract web_search flag from request body for rate limiting.""" + body = await request.json() + try: + chat_request = ChatRequest(**body) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid request body") + return getattr(chat_request, "web_search", False) + + @router.post("/v1/chat/completions", tags=["Chat"], response_model=None) async def chat_completion( req: ChatRequest = Body( @@ -127,7 +137,12 @@ async def chat_completion( ], ) ), - _=Depends(RateLimit(concurrent_extractor=chat_completion_concurrent_rate_limit)), + _rate_limit=Depends( + RateLimit( + concurrent_extractor=chat_completion_concurrent_rate_limit, + web_search_extractor=chat_completion_web_search_rate_limit, + ) + ), auth_info: AuthenticationInfo = Depends(get_auth_info), ) -> Union[SignedChatCompletion, StreamingResponse]: """ diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py index 759d233d..9a658a92 100644 --- a/tests/unit/nilai_api/test_rate_limiting.py +++ b/tests/unit/nilai_api/test_rate_limiting.py @@ -48,6 +48,9 @@ async def test_concurrent_rate_limit(req): hour_limit=None, minute_limit=None, token_rate_limit=None, + web_search_day_limit=None, + web_search_hour_limit=None, + web_search_minute_limit=None, ) futures = [consume_generator(rate_limit(req, user_limits)) for _ in range(5)] @@ -74,6 +77,9 @@ async def test_concurrent_rate_limit(req): hour_limit=None, minute_limit=None, token_rate_limit=None, + web_search_day_limit=None, + web_search_hour_limit=None, + web_search_minute_limit=None, ), UserRateLimits( subscription_holder=random_id(), @@ -81,6 +87,9 @@ async def test_concurrent_rate_limit(req): hour_limit=11, minute_limit=None, token_rate_limit=None, + web_search_day_limit=None, + web_search_hour_limit=None, + web_search_minute_limit=None, ), UserRateLimits( subscription_holder=random_id(), @@ -88,6 +97,9 @@ async def test_concurrent_rate_limit(req): hour_limit=None, minute_limit=12, token_rate_limit=None, + web_search_day_limit=None, + web_search_hour_limit=None, + web_search_minute_limit=None, ), UserRateLimits( subscription_holder=random_id(), @@ -103,6 +115,9 @@ async def test_concurrent_rate_limit(req): ) ] ), + web_search_day_limit=None, + web_search_hour_limit=None, + web_search_minute_limit=None, ), ], ) @@ -115,3 +130,48 @@ async def test_user_limit(req, user_limits): futures = [consume_generator(rate_limit(req, user_limits)) for _ in range(3)] with pytest.raises(HTTPException): await asyncio.gather(*futures) + + +@pytest.mark.asyncio +async def test_web_search_rate_limits(redis_client): + """Verify that a user is rate limited for web-search requests across all time windows.""" + + # Build a dummy authenticated user + apikey = random_id() + + # Mock the incoming request with web_search enabled + mock_request = MagicMock(spec=Request) + mock_request.state.redis = redis_client[0] + mock_request.state.redis_rate_limit_command = redis_client[1] + + async def json_body(): + return { + "model": "meta-llama/Llama-3.2-1B-Instruct", + "messages": [{"role": "user", "content": "hi"}], + "web_search": True, + } + + mock_request.json = json_body + + # Create rate limit with web search enabled + async def web_search_extractor(request): + return True + + rate_limit = RateLimit(web_search_extractor=web_search_extractor) + user_limits = UserRateLimits( + subscription_holder=apikey, + day_limit=None, + hour_limit=None, + minute_limit=None, + token_rate_limit=None, + web_search_day_limit=72, + web_search_hour_limit=3, + web_search_minute_limit=1, + ) + + # First request should succeed (minute limit: 1, hour limit: 3, day limit: 72) + await consume_generator(rate_limit(mock_request, user_limits)) + + # Second request should be rejected due to minute limit (1 per minute) + with pytest.raises(HTTPException): + await consume_generator(rate_limit(mock_request, user_limits))