diff --git a/nilai-api/src/nilai_api/config/web_search.py b/nilai-api/src/nilai_api/config/web_search.py index 889ee13f..8a522943 100644 --- a/nilai-api/src/nilai_api/config/web_search.py +++ b/nilai-api/src/nilai_api/config/web_search.py @@ -15,4 +15,4 @@ class WebSearchSettings(BaseModel): max_concurrent_requests: int = Field( default=20, description="Maximum concurrent requests" ) - rps: int = Field(default=20, description="Requests per second limit") + rps: Optional[int] = Field(default=20, description="Requests per second limit") diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 5bcaf5a3..c0205816 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -4,9 +4,11 @@ from functools import lru_cache from typing import List, Dict, Any +from fastapi import HTTPException, status, Request +from nilai_api.rate_limiting import RateLimit + import httpx import trafilatura -from fastapi import HTTPException, status from nilai_api.config import CONFIG from nilai_common.api_models import ( @@ -90,11 +92,12 @@ def _get_http_client() -> httpx.AsyncClient: ) -async def _make_brave_api_request(query: str) -> Dict[str, Any]: +async def _make_brave_api_request(query: str, request: Request) -> Dict[str, Any]: """Make an API request to the Brave Search API. Args: query: The search query string to execute + request: FastAPI request object for rate limiting Returns: Dict containing the raw API response data @@ -108,6 +111,8 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: detail="Missing BRAVE_SEARCH_API key in environment", ) + await RateLimit.check_brave_rps(request) + q = " ".join(query.split()) params = {**_BRAVE_API_PARAMS_BASE, "q": q} @@ -125,8 +130,15 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: params.get("lang"), params.get("count"), ) + resp = await client.get(CONFIG.web_search.api_path, headers=headers, params=params) + if resp.status_code == 429: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Web search rate limit exceeded", + ) + if resp.status_code >= 400: logger.error("Brave API error: %s - %s", resp.status_code, resp.text) raise HTTPException( @@ -225,12 +237,22 @@ async def _fetch_and_extract_page_content( return None -async def perform_web_search_async(query: str) -> WebSearchContext: +async def perform_web_search_async(query: str, request: Request) -> WebSearchContext: """Perform an asynchronous web search using the Brave Search API. Fetches only the exact page for each Brave URL and extracts its main content with trafilatura. If extraction fails, falls back to the Brave snippet. + + Args: + query: The search query string to execute + request: FastAPI request object for rate limiting + + Returns: + WebSearchContext with formatted search results and source information + + Raises: + HTTPException: If no results are found (404) or if the API request fails """ if not (query and query.strip()): logger.warning("Empty or invalid query provided for web search") @@ -240,7 +262,7 @@ async def perform_web_search_async(query: str) -> WebSearchContext: logger.debug("Web search query: %s", query) try: - data = await _make_brave_api_request(query) + data = await _make_brave_api_request(query, request) initial_results = _parse_brave_results(data) except HTTPException: logger.exception("Brave API request failed") @@ -358,36 +380,38 @@ async def _generate_topic_query( return None -async def _perform_search(query: str) -> WebSearchContext: +async def _perform_search(query: str, request: Request) -> WebSearchContext: """Execute a web search with error handling. Args: query: Search query string + request: FastAPI request object for rate limiting Returns: WebSearchContext with results, or empty context if search fails """ try: - return await perform_web_search_async(query) + return await perform_web_search_async(query, request) except Exception: logger.exception("Search failed for query '%s'", query) return WebSearchContext(prompt="", sources=[]) async def enhance_messages_with_web_search( - req: ChatRequest, query: str + req: ChatRequest, query: str, request: Request ) -> WebSearchEnhancedMessages: """Enhance chat messages with web search context for a single query. Args: req: ChatRequest containing conversation messages query: Search query to retrieve web search results for + request: FastAPI request object for rate limiting Returns: WebSearchEnhancedMessages with web search context added to system messages and source information """ - ctx = await perform_web_search_async(query) + ctx = await perform_web_search_async(query, request) query_source = Source(source=WEB_SEARCH_QUERY_SOURCE, content=query) web_search_content = _build_single_search_content(query, ctx.prompt) @@ -469,7 +493,7 @@ async def generate_search_query_from_llm( async def _execute_web_search_workflow( - user_query: str, model_name: str, client: Any + user_query: str, model_name: str, client: Any, request: Request ) -> tuple[List[TopicQuery], List[WebSearchContext]] | tuple[None, None]: """Execute the complete multi-topic web search workflow. @@ -480,6 +504,7 @@ async def _execute_web_search_workflow( user_query: User's query to analyze and search for model_name: Name of the LLM model to use for topic analysis and query generation client: LLM client instance for API calls + request: FastAPI request object for rate limiting Returns: Tuple of (topic_queries, contexts) if successful, or (None, None) if no topics @@ -508,7 +533,7 @@ async def _execute_web_search_workflow( ) return None, None - search_tasks = [_perform_search(tq.query) for tq in topic_queries] + search_tasks = [_perform_search(tq.query, request) for tq in topic_queries] contexts = await asyncio.gather(*search_tasks) return topic_queries, contexts @@ -519,7 +544,7 @@ async def _execute_web_search_workflow( async def handle_web_search( - req_messages: ChatRequest, model_name: str, client: Any + req_messages: ChatRequest, model_name: str, client: Any, request: Request ) -> WebSearchEnhancedMessages: logger.info("Handle web search start") logger.debug( @@ -534,14 +559,16 @@ async def handle_web_search( try: topic_queries, contexts = await _execute_web_search_workflow( - user_query, model_name, client + user_query, model_name, client, request ) if topic_queries is None or contexts is None: concise_query = await generate_search_query_from_llm( user_query, model_name, client ) - return await enhance_messages_with_web_search(req_messages, concise_query) + return await enhance_messages_with_web_search( + req_messages, concise_query, request + ) return await enhance_messages_with_multi_web_search( req_messages, topic_queries, contexts @@ -628,7 +655,7 @@ async def enhance_messages_with_multi_web_search( async def enhance_input_with_web_search( - req: ResponseRequest, query: str + req: ResponseRequest, query: str, request: Request ) -> WebSearchEnhancedInput: """Enhance response input with web search context for a single query. @@ -640,7 +667,7 @@ async def enhance_input_with_web_search( WebSearchEnhancedInput with web search context added to instructions and source information """ - ctx = await perform_web_search_async(query) + ctx = await perform_web_search_async(query, request) query_source = Source(source=WEB_SEARCH_QUERY_SOURCE, content=query) web_search_instructions = _build_single_search_content(query, ctx.prompt) @@ -692,7 +719,7 @@ async def enhance_input_with_multi_web_search( async def handle_web_search_for_responses( - req: ResponseRequest, model_name: str, client: Any + req: ResponseRequest, model_name: str, client: Any, request: Request ) -> WebSearchEnhancedInput: """Handle web search enhancement for response requests. @@ -705,6 +732,7 @@ async def handle_web_search_for_responses( req: ResponseRequest containing input to process model_name: Name of the LLM model to use for query generation client: LLM client instance for making API calls + request: FastAPI request object for rate limiting Returns: WebSearchEnhancedInput with web search context added, or original @@ -724,14 +752,14 @@ async def handle_web_search_for_responses( try: topic_queries, contexts = await _execute_web_search_workflow( - user_query, model_name, client + user_query, model_name, client, request ) if topic_queries is None or contexts is None: concise_query = await generate_search_query_from_llm( user_query, model_name, client ) - return await enhance_input_with_web_search(req, concise_query) + return await enhance_input_with_web_search(req, concise_query, request) return await enhance_input_with_multi_web_search(req, topic_queries, contexts) diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index c2d03273..023b341b 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -177,13 +177,6 @@ async def __call__( user_limits.rate_limits.web_search_rate_limit_day, DAY_MS, ) - await self.check_bucket( - redis, - redis_rate_limit_command, - "web_search_rps", - CONFIG.web_search.rps, - 1000, - ) await self.check_bucket( redis, redis_rate_limit_command, @@ -241,6 +234,26 @@ async def check_concurrent_and_increment( ) return key + @staticmethod + async def check_brave_rps(request: Request) -> None: + """ + Global RPS limit for Brave API calls, across all users. + """ + redis = request.state.redis + redis_rate_limit_command = request.state.redis_rate_limit_command + + limit = CONFIG.web_search.rps + if not limit or limit <= 0: + return + + await RateLimit.check_bucket( + redis, + redis_rate_limit_command, + "brave_rps_global", + limit, + 1000, + ) + @staticmethod async def concurrent_decrement(redis: Redis, key: str | None): if key is None: diff --git a/nilai-api/src/nilai_api/routers/endpoints/chat.py b/nilai-api/src/nilai_api/routers/endpoints/chat.py index 7e1bc424..04e58106 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/chat.py +++ b/nilai-api/src/nilai_api/routers/endpoints/chat.py @@ -59,6 +59,7 @@ async def chat_completion_web_search_rate_limit(request: Request) -> bool: @chat_completion_router.post("/v1/chat/completions", tags=["Chat"], response_model=None) async def chat_completion( + request: Request, req: ChatRequest = Body( ChatRequest( model="meta-llama/Llama-3.2-1B-Instruct", @@ -188,7 +189,7 @@ async def chat_completion( if req.web_search: logger.info(f"[chat] web_search start request_id={request_id}") t_ws = time.monotonic() - web_search_result = await handle_web_search(req, model_name, client) + web_search_result = await handle_web_search(req, model_name, client, request) messages = web_search_result.messages sources = web_search_result.sources logger.info( diff --git a/nilai-api/src/nilai_api/routers/endpoints/responses.py b/nilai-api/src/nilai_api/routers/endpoints/responses.py index a5af3dc4..2beca823 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/responses.py +++ b/nilai-api/src/nilai_api/routers/endpoints/responses.py @@ -60,6 +60,7 @@ async def responses_web_search_rate_limit(request: Request) -> bool: "/v1/responses", tags=["Responses"], response_model=SignedResponse ) async def create_response( + request: Request, req: ResponseRequest = Body( { "model": "openai/gpt-oss-20b", @@ -171,7 +172,7 @@ async def create_response( logger.info(f"[responses] web_search start request_id={request_id}") t_ws = time.monotonic() web_search_result = await handle_web_search_for_responses( - req, model_name, client + req, model_name, client, request ) input_items = web_search_result.input instructions = web_search_result.instructions diff --git a/tests/unit/nilai_api/routers/test_nildb_endpoints.py b/tests/unit/nilai_api/routers/test_nildb_endpoints.py index 7036fecc..ffa88d04 100644 --- a/tests/unit/nilai_api/routers/test_nildb_endpoints.py +++ b/tests/unit/nilai_api/routers/test_nildb_endpoints.py @@ -1,6 +1,6 @@ import pytest from unittest.mock import patch, MagicMock, AsyncMock -from fastapi import HTTPException, status +from fastapi import HTTPException, status, Request from nilai_api.auth.common import AuthenticationInfo, PromptDocument from nilai_api.db.users import RateLimits, UserData, UserModel @@ -240,8 +240,11 @@ async def test_chat_completion_with_prompt_document_injection(self): # Mock handle_tool_workflow to return the response and token counts mock_handle_tool_workflow.return_value = (mock_response, 0, 0) + # Create a mock Request object + mock_request = MagicMock(spec=Request) + # Call the function (this will test the prompt injection logic) - await chat_completion(req=request, auth_info=mock_auth_info) + await chat_completion(mock_request, req=request, auth_info=mock_auth_info) mock_get_prompt.assert_called_once_with(mock_prompt_document) @@ -284,8 +287,13 @@ async def test_chat_completion_prompt_document_extraction_error(self): mock_get_prompt.side_effect = Exception("Unable to extract prompt") + # Create a mock Request object + mock_request = MagicMock(spec=Request) + with pytest.raises(HTTPException) as exc_info: - await chat_completion(req=request, auth_info=mock_auth_info) + await chat_completion( + mock_request, req=request, auth_info=mock_auth_info + ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert ( @@ -390,8 +398,11 @@ async def test_chat_completion_without_prompt_document(self): # Mock handle_tool_workflow to return the response and token counts mock_handle_tool_workflow.return_value = (mock_response, 0, 0) + # Create a mock Request object + mock_request = MagicMock(spec=Request) + # Call the function - await chat_completion(req=request, auth_info=mock_auth_info) + await chat_completion(mock_request, req=request, auth_info=mock_auth_info) # Should not call get_prompt_from_nildb when no prompt document mock_get_prompt.assert_not_called() @@ -477,7 +488,10 @@ async def test_responses_with_prompt_document_injection(self): mock_handle_tool_workflow.return_value = (mock_response, 0, 0) - await create_response(req=request, auth_info=mock_auth_info) + # Create a mock Request object + mock_request = MagicMock(spec=Request) + + await create_response(mock_request, req=request, auth_info=mock_auth_info) mock_get_prompt.assert_called_once_with(mock_prompt_document) @@ -518,8 +532,13 @@ async def test_responses_prompt_document_extraction_error(self): mock_get_prompt.side_effect = Exception("Unable to extract prompt") + # Create a mock Request object + mock_request = MagicMock(spec=Request) + with pytest.raises(HTTPException) as exc_info: - await create_response(req=request, auth_info=mock_auth_info) + await create_response( + mock_request, req=request, auth_info=mock_auth_info + ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert ( @@ -604,7 +623,10 @@ async def test_responses_without_prompt_document(self): mock_handle_tool_workflow.return_value = (mock_response, 0, 0) - await create_response(req=request, auth_info=mock_auth_info) + # Create a mock Request object + mock_request = MagicMock(spec=Request) + + await create_response(mock_request, req=request, auth_info=mock_auth_info) mock_get_prompt.assert_not_called() diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py index 27a5c1bc..21c93f22 100644 --- a/tests/unit/nilai_api/test_rate_limiting.py +++ b/tests/unit/nilai_api/test_rate_limiting.py @@ -73,45 +73,6 @@ async def test_concurrent_rate_limit(req): await asyncio.gather(*futures) -@pytest.mark.asyncio -async def test_web_search_rps_limit(redis_client): - mock_request = MagicMock(spec=Request) - mock_request.state.redis = redis_client[0] - mock_request.state.redis_rate_limit_command = redis_client[1] - # Ensure a clean slate for the global RPS key used by the limiter - await redis_client[0].delete("web_search_rps") - - async def web_search_extractor(_): - return True - - rate_limit = RateLimit(web_search_extractor=web_search_extractor) - user_limits = UserRateLimits( - subscription_holder=random_id(), - token_rate_limit=None, - rate_limits=RateLimits( - user_rate_limit_day=None, - user_rate_limit_hour=None, - user_rate_limit_minute=None, - web_search_rate_limit_day=None, - web_search_rate_limit_hour=None, - web_search_rate_limit_minute=None, - user_rate_limit=None, - web_search_rate_limit=None, - ), - ) - - old_rps = CONFIG.web_search.rps - CONFIG.web_search.rps = 2 - try: - await consume_generator(rate_limit(mock_request, user_limits)) - await consume_generator(rate_limit(mock_request, user_limits)) - with pytest.raises(HTTPException): - await consume_generator(rate_limit(mock_request, user_limits)) - finally: - CONFIG.web_search.rps = old_rps - await redis_client[0].delete("web_search_rps") - - @pytest.mark.asyncio @pytest.mark.parametrize( "user_limits", @@ -239,3 +200,125 @@ async def web_search_extractor(request): # Second request should be rejected due to minute limit (1 per minute) with pytest.raises(HTTPException): await consume_generator(rate_limit(mock_request, user_limits)) + + +@pytest.mark.asyncio +async def test_check_brave_rps_limit(redis_client): + """Test that check_brave_rps enforces global RPS limit across all users.""" + mock_request = MagicMock(spec=Request) + mock_request.state.redis = redis_client[0] + mock_request.state.redis_rate_limit_command = redis_client[1] + + await redis_client[0].delete("brave_rps_global") + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = 3 + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + + with pytest.raises(HTTPException) as exc_info: + await rate_limit.check_brave_rps(mock_request) + + assert exc_info.value.status_code == 429 + assert "Too Many Requests" in str(exc_info.value.detail) + finally: + CONFIG.web_search.rps = old_rps + await redis_client[0].delete("brave_rps_global") + + +@pytest.mark.asyncio +async def test_check_brave_rps_disabled(redis_client): + """Test that check_brave_rps does nothing when limit is disabled.""" + mock_request = MagicMock(spec=Request) + mock_request.state.redis = redis_client[0] + mock_request.state.redis_rate_limit_command = redis_client[1] + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = None + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + finally: + CONFIG.web_search.rps = old_rps + + +@pytest.mark.asyncio +async def test_check_brave_rps_zero_limit(redis_client): + """Test that check_brave_rps does nothing when limit is 0 or negative.""" + mock_request = MagicMock(spec=Request) + mock_request.state.redis = redis_client[0] + mock_request.state.redis_rate_limit_command = redis_client[1] + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = 0 + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + finally: + CONFIG.web_search.rps = old_rps + + +@pytest.mark.asyncio +async def test_check_brave_rps_global_key(redis_client): + """Test that check_brave_rps uses the correct global key across different requests.""" + mock_request_1 = MagicMock(spec=Request) + mock_request_1.state.redis = redis_client[0] + mock_request_1.state.redis_rate_limit_command = redis_client[1] + + mock_request_2 = MagicMock(spec=Request) + mock_request_2.state.redis = redis_client[0] + mock_request_2.state.redis_rate_limit_command = redis_client[1] + + await redis_client[0].delete("brave_rps_global") + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = 2 + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request_1) + await rate_limit.check_brave_rps(mock_request_2) + + with pytest.raises(HTTPException): + await rate_limit.check_brave_rps(mock_request_1) + finally: + CONFIG.web_search.rps = old_rps + await redis_client[0].delete("brave_rps_global") + + +@pytest.mark.asyncio +async def test_check_brave_rps_reset_after_window(redis_client): + """Test that check_brave_rps resets after the 1 second window expires.""" + mock_request = MagicMock(spec=Request) + mock_request.state.redis = redis_client[0] + mock_request.state.redis_rate_limit_command = redis_client[1] + + await redis_client[0].delete("brave_rps_global") + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = 2 + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + + with pytest.raises(HTTPException): + await rate_limit.check_brave_rps(mock_request) + + await asyncio.sleep(1.1) + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + finally: + CONFIG.web_search.rps = old_rps + await redis_client[0].delete("brave_rps_global") diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py index ef94a1f5..8f71631b 100644 --- a/tests/unit/nilai_api/test_web_search.py +++ b/tests/unit/nilai_api/test_web_search.py @@ -1,10 +1,12 @@ import pytest -from unittest.mock import patch -from fastapi import HTTPException +from unittest.mock import patch, MagicMock, AsyncMock +from fastapi import HTTPException, Request from nilai_api.handlers.web_search import ( perform_web_search_async, enhance_messages_with_web_search, + _make_brave_api_request, ) +from nilai_api.rate_limiting import RateLimit from nilai_common import MessageAdapter, ChatRequest from nilai_common.api_models import ( WebSearchContext, @@ -32,6 +34,8 @@ async def test_perform_web_search_async_success(): } } + mock_request = MagicMock(spec=Request) + with ( patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), patch( @@ -39,7 +43,7 @@ async def test_perform_web_search_async_success(): return_value=mock_data, ), ): - ctx = await perform_web_search_async("AI developments") + ctx = await perform_web_search_async("AI developments", mock_request) assert ctx.sources is not None assert len(ctx.sources) == 2 @@ -59,6 +63,7 @@ async def test_perform_web_search_async_success(): async def test_perform_web_search_async_no_results(): """Test web search with no results returns 404""" mock_data = {"web": {"results": []}} + mock_request = MagicMock(spec=Request) with ( patch("nilai_api.handlers.web_search.CONFIG.web_search.api_key", "test-key"), @@ -68,7 +73,7 @@ async def test_perform_web_search_async_no_results(): ), pytest.raises(HTTPException) as exc_info, ): - await perform_web_search_async("nonexistent query") + await perform_web_search_async("nonexistent query", mock_request) assert exc_info.value.status_code == 404 @@ -100,6 +105,9 @@ async def test_perform_web_search_async_concurrent_queries(): } } + mock_request_1 = MagicMock(spec=Request) + mock_request_2 = MagicMock(spec=Request) + with ( patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), patch( @@ -111,8 +119,8 @@ async def test_perform_web_search_async_concurrent_queries(): # Run two concurrent web searches results = await asyncio.gather( - perform_web_search_async("AI news"), - perform_web_search_async("Machine learning"), + perform_web_search_async("AI news", mock_request_1), + perform_web_search_async("Machine learning", mock_request_2), ) # Verify both searches completed successfully @@ -146,6 +154,7 @@ async def test_enhance_messages_with_web_search(): MessageAdapter.new_message(role="user", content="What is the latest AI news?"), ] req = ChatRequest(model="dummy", messages=original_messages) + mock_request = MagicMock(spec=Request) with patch("nilai_api.handlers.web_search.perform_web_search_async") as mock_search: mock_search.return_value = WebSearchContext( @@ -155,7 +164,7 @@ async def test_enhance_messages_with_web_search(): ], ) - enhanced = await enhance_messages_with_web_search(req, "AI news") + enhanced = await enhance_messages_with_web_search(req, "AI news", mock_request) assert len(enhanced.messages) == 2 assert enhanced.messages[0]["role"] == "system" @@ -166,3 +175,65 @@ async def test_enhance_messages_with_web_search(): assert enhanced.sources[0].content == "AI news" assert enhanced.sources[1].source == "https://example.com" assert enhanced.sources[1].content == "OpenAI announces GPT-5" + + +@pytest.mark.asyncio +async def test_make_brave_api_request_calls_rps_limit(): + """Test that _make_brave_api_request calls check_brave_rps for rate limiting.""" + mock_request = MagicMock(spec=Request) + mock_data = { + "web": { + "results": [ + { + "title": "Test Result", + "description": "Test description", + "url": "https://example.com/test", + } + ] + } + } + + with ( + patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), + patch( + "nilai_api.handlers.web_search.CONFIG.web_search.api_path", + "https://api.brave.com/v1/web/search", + ), + patch("nilai_api.handlers.web_search._get_http_client") as mock_client, + patch.object( + RateLimit, "check_brave_rps", new_callable=AsyncMock + ) as mock_check_rps, + ): + mock_http_client = AsyncMock() + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=mock_data) + mock_http_client.get = AsyncMock(return_value=mock_response) + mock_client.return_value = mock_http_client + + result = await _make_brave_api_request("test query", mock_request) + + mock_check_rps.assert_called_once_with(mock_request) + assert result == mock_data + + +@pytest.mark.asyncio +async def test_make_brave_api_request_rps_limit_exceeded(): + """Test that _make_brave_api_request raises 429 when RPS limit is exceeded.""" + mock_request = MagicMock(spec=Request) + + with ( + patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), + patch.object( + RateLimit, "check_brave_rps", new_callable=AsyncMock + ) as mock_check_rps, + ): + mock_check_rps.side_effect = HTTPException( + status_code=429, detail="Too Many Requests" + ) + + with pytest.raises(HTTPException) as exc_info: + await _make_brave_api_request("test query", mock_request) + + assert exc_info.value.status_code == 429 + mock_check_rps.assert_called_once_with(mock_request)