From 2565198eeaee217bffd54221ef58c3ca9f303fde Mon Sep 17 00:00:00 2001 From: blefo Date: Thu, 13 Nov 2025 13:22:18 +0100 Subject: [PATCH 1/2] feat: add global RPS limit for Brave API calls and refactor web search handlers to support rate limiting --- nilai-api/src/nilai_api/config/web_search.py | 2 +- .../src/nilai_api/handlers/web_search.py | 52 ++++-- nilai-api/src/nilai_api/rate_limiting.py | 27 ++- .../src/nilai_api/routers/endpoints/chat.py | 3 +- .../nilai_api/routers/endpoints/responses.py | 3 +- .../nilai_api/routers/test_nildb_endpoints.py | 36 +++- tests/unit/nilai_api/test_rate_limiting.py | 161 +++++++++++++----- tests/unit/nilai_api/test_web_search.py | 85 ++++++++- 8 files changed, 288 insertions(+), 81 deletions(-) diff --git a/nilai-api/src/nilai_api/config/web_search.py b/nilai-api/src/nilai_api/config/web_search.py index 889ee13f..8a522943 100644 --- a/nilai-api/src/nilai_api/config/web_search.py +++ b/nilai-api/src/nilai_api/config/web_search.py @@ -15,4 +15,4 @@ class WebSearchSettings(BaseModel): max_concurrent_requests: int = Field( default=20, description="Maximum concurrent requests" ) - rps: int = Field(default=20, description="Requests per second limit") + rps: Optional[int] = Field(default=20, description="Requests per second limit") diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 5bcaf5a3..35700923 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -4,9 +4,11 @@ from functools import lru_cache from typing import List, Dict, Any +from fastapi import HTTPException, status, Request +from nilai_api.rate_limiting import RateLimit + import httpx import trafilatura -from fastapi import HTTPException, status from nilai_api.config import CONFIG from nilai_common.api_models import ( @@ -90,7 +92,7 @@ def _get_http_client() -> httpx.AsyncClient: ) -async def _make_brave_api_request(query: str) -> Dict[str, Any]: +async def _make_brave_api_request(query: str, request: Request) -> Dict[str, Any]: """Make an API request to the Brave Search API. Args: @@ -108,6 +110,8 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: detail="Missing BRAVE_SEARCH_API key in environment", ) + await RateLimit.check_brave_rps(request) + q = " ".join(query.split()) params = {**_BRAVE_API_PARAMS_BASE, "q": q} @@ -125,8 +129,15 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: params.get("lang"), params.get("count"), ) + resp = await client.get(CONFIG.web_search.api_path, headers=headers, params=params) + if resp.status_code == 429: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Web search rate limit exceeded", + ) + if resp.status_code >= 400: logger.error("Brave API error: %s - %s", resp.status_code, resp.text) raise HTTPException( @@ -225,7 +236,7 @@ async def _fetch_and_extract_page_content( return None -async def perform_web_search_async(query: str) -> WebSearchContext: +async def perform_web_search_async(query: str, request: Request) -> WebSearchContext: """Perform an asynchronous web search using the Brave Search API. Fetches only the exact page for each Brave URL and extracts its @@ -240,7 +251,7 @@ async def perform_web_search_async(query: str) -> WebSearchContext: logger.debug("Web search query: %s", query) try: - data = await _make_brave_api_request(query) + data = await _make_brave_api_request(query, request) initial_results = _parse_brave_results(data) except HTTPException: logger.exception("Brave API request failed") @@ -358,36 +369,38 @@ async def _generate_topic_query( return None -async def _perform_search(query: str) -> WebSearchContext: +async def _perform_search(query: str, request: Request) -> WebSearchContext: """Execute a web search with error handling. Args: query: Search query string + request: FastAPI request object for rate limiting Returns: WebSearchContext with results, or empty context if search fails """ try: - return await perform_web_search_async(query) + return await perform_web_search_async(query, request) except Exception: logger.exception("Search failed for query '%s'", query) return WebSearchContext(prompt="", sources=[]) async def enhance_messages_with_web_search( - req: ChatRequest, query: str + req: ChatRequest, query: str, request: Request ) -> WebSearchEnhancedMessages: """Enhance chat messages with web search context for a single query. Args: req: ChatRequest containing conversation messages query: Search query to retrieve web search results for + request: FastAPI request object for rate limiting Returns: WebSearchEnhancedMessages with web search context added to system messages and source information """ - ctx = await perform_web_search_async(query) + ctx = await perform_web_search_async(query, request) query_source = Source(source=WEB_SEARCH_QUERY_SOURCE, content=query) web_search_content = _build_single_search_content(query, ctx.prompt) @@ -469,7 +482,7 @@ async def generate_search_query_from_llm( async def _execute_web_search_workflow( - user_query: str, model_name: str, client: Any + user_query: str, model_name: str, client: Any, request: Request ) -> tuple[List[TopicQuery], List[WebSearchContext]] | tuple[None, None]: """Execute the complete multi-topic web search workflow. @@ -480,6 +493,7 @@ async def _execute_web_search_workflow( user_query: User's query to analyze and search for model_name: Name of the LLM model to use for topic analysis and query generation client: LLM client instance for API calls + request: FastAPI request object for rate limiting Returns: Tuple of (topic_queries, contexts) if successful, or (None, None) if no topics @@ -508,7 +522,7 @@ async def _execute_web_search_workflow( ) return None, None - search_tasks = [_perform_search(tq.query) for tq in topic_queries] + search_tasks = [_perform_search(tq.query, request) for tq in topic_queries] contexts = await asyncio.gather(*search_tasks) return topic_queries, contexts @@ -519,7 +533,7 @@ async def _execute_web_search_workflow( async def handle_web_search( - req_messages: ChatRequest, model_name: str, client: Any + req_messages: ChatRequest, model_name: str, client: Any, request: Request ) -> WebSearchEnhancedMessages: logger.info("Handle web search start") logger.debug( @@ -534,14 +548,16 @@ async def handle_web_search( try: topic_queries, contexts = await _execute_web_search_workflow( - user_query, model_name, client + user_query, model_name, client, request ) if topic_queries is None or contexts is None: concise_query = await generate_search_query_from_llm( user_query, model_name, client ) - return await enhance_messages_with_web_search(req_messages, concise_query) + return await enhance_messages_with_web_search( + req_messages, concise_query, request + ) return await enhance_messages_with_multi_web_search( req_messages, topic_queries, contexts @@ -628,7 +644,7 @@ async def enhance_messages_with_multi_web_search( async def enhance_input_with_web_search( - req: ResponseRequest, query: str + req: ResponseRequest, query: str, request: Request ) -> WebSearchEnhancedInput: """Enhance response input with web search context for a single query. @@ -640,7 +656,7 @@ async def enhance_input_with_web_search( WebSearchEnhancedInput with web search context added to instructions and source information """ - ctx = await perform_web_search_async(query) + ctx = await perform_web_search_async(query, request) query_source = Source(source=WEB_SEARCH_QUERY_SOURCE, content=query) web_search_instructions = _build_single_search_content(query, ctx.prompt) @@ -692,7 +708,7 @@ async def enhance_input_with_multi_web_search( async def handle_web_search_for_responses( - req: ResponseRequest, model_name: str, client: Any + req: ResponseRequest, model_name: str, client: Any, request: Request ) -> WebSearchEnhancedInput: """Handle web search enhancement for response requests. @@ -724,14 +740,14 @@ async def handle_web_search_for_responses( try: topic_queries, contexts = await _execute_web_search_workflow( - user_query, model_name, client + user_query, model_name, client, request ) if topic_queries is None or contexts is None: concise_query = await generate_search_query_from_llm( user_query, model_name, client ) - return await enhance_input_with_web_search(req, concise_query) + return await enhance_input_with_web_search(req, concise_query, request) return await enhance_input_with_multi_web_search(req, topic_queries, contexts) diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index c2d03273..023b341b 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -177,13 +177,6 @@ async def __call__( user_limits.rate_limits.web_search_rate_limit_day, DAY_MS, ) - await self.check_bucket( - redis, - redis_rate_limit_command, - "web_search_rps", - CONFIG.web_search.rps, - 1000, - ) await self.check_bucket( redis, redis_rate_limit_command, @@ -241,6 +234,26 @@ async def check_concurrent_and_increment( ) return key + @staticmethod + async def check_brave_rps(request: Request) -> None: + """ + Global RPS limit for Brave API calls, across all users. + """ + redis = request.state.redis + redis_rate_limit_command = request.state.redis_rate_limit_command + + limit = CONFIG.web_search.rps + if not limit or limit <= 0: + return + + await RateLimit.check_bucket( + redis, + redis_rate_limit_command, + "brave_rps_global", + limit, + 1000, + ) + @staticmethod async def concurrent_decrement(redis: Redis, key: str | None): if key is None: diff --git a/nilai-api/src/nilai_api/routers/endpoints/chat.py b/nilai-api/src/nilai_api/routers/endpoints/chat.py index 7e1bc424..04e58106 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/chat.py +++ b/nilai-api/src/nilai_api/routers/endpoints/chat.py @@ -59,6 +59,7 @@ async def chat_completion_web_search_rate_limit(request: Request) -> bool: @chat_completion_router.post("/v1/chat/completions", tags=["Chat"], response_model=None) async def chat_completion( + request: Request, req: ChatRequest = Body( ChatRequest( model="meta-llama/Llama-3.2-1B-Instruct", @@ -188,7 +189,7 @@ async def chat_completion( if req.web_search: logger.info(f"[chat] web_search start request_id={request_id}") t_ws = time.monotonic() - web_search_result = await handle_web_search(req, model_name, client) + web_search_result = await handle_web_search(req, model_name, client, request) messages = web_search_result.messages sources = web_search_result.sources logger.info( diff --git a/nilai-api/src/nilai_api/routers/endpoints/responses.py b/nilai-api/src/nilai_api/routers/endpoints/responses.py index a5af3dc4..2beca823 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/responses.py +++ b/nilai-api/src/nilai_api/routers/endpoints/responses.py @@ -60,6 +60,7 @@ async def responses_web_search_rate_limit(request: Request) -> bool: "/v1/responses", tags=["Responses"], response_model=SignedResponse ) async def create_response( + request: Request, req: ResponseRequest = Body( { "model": "openai/gpt-oss-20b", @@ -171,7 +172,7 @@ async def create_response( logger.info(f"[responses] web_search start request_id={request_id}") t_ws = time.monotonic() web_search_result = await handle_web_search_for_responses( - req, model_name, client + req, model_name, client, request ) input_items = web_search_result.input instructions = web_search_result.instructions diff --git a/tests/unit/nilai_api/routers/test_nildb_endpoints.py b/tests/unit/nilai_api/routers/test_nildb_endpoints.py index 7036fecc..ffa88d04 100644 --- a/tests/unit/nilai_api/routers/test_nildb_endpoints.py +++ b/tests/unit/nilai_api/routers/test_nildb_endpoints.py @@ -1,6 +1,6 @@ import pytest from unittest.mock import patch, MagicMock, AsyncMock -from fastapi import HTTPException, status +from fastapi import HTTPException, status, Request from nilai_api.auth.common import AuthenticationInfo, PromptDocument from nilai_api.db.users import RateLimits, UserData, UserModel @@ -240,8 +240,11 @@ async def test_chat_completion_with_prompt_document_injection(self): # Mock handle_tool_workflow to return the response and token counts mock_handle_tool_workflow.return_value = (mock_response, 0, 0) + # Create a mock Request object + mock_request = MagicMock(spec=Request) + # Call the function (this will test the prompt injection logic) - await chat_completion(req=request, auth_info=mock_auth_info) + await chat_completion(mock_request, req=request, auth_info=mock_auth_info) mock_get_prompt.assert_called_once_with(mock_prompt_document) @@ -284,8 +287,13 @@ async def test_chat_completion_prompt_document_extraction_error(self): mock_get_prompt.side_effect = Exception("Unable to extract prompt") + # Create a mock Request object + mock_request = MagicMock(spec=Request) + with pytest.raises(HTTPException) as exc_info: - await chat_completion(req=request, auth_info=mock_auth_info) + await chat_completion( + mock_request, req=request, auth_info=mock_auth_info + ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert ( @@ -390,8 +398,11 @@ async def test_chat_completion_without_prompt_document(self): # Mock handle_tool_workflow to return the response and token counts mock_handle_tool_workflow.return_value = (mock_response, 0, 0) + # Create a mock Request object + mock_request = MagicMock(spec=Request) + # Call the function - await chat_completion(req=request, auth_info=mock_auth_info) + await chat_completion(mock_request, req=request, auth_info=mock_auth_info) # Should not call get_prompt_from_nildb when no prompt document mock_get_prompt.assert_not_called() @@ -477,7 +488,10 @@ async def test_responses_with_prompt_document_injection(self): mock_handle_tool_workflow.return_value = (mock_response, 0, 0) - await create_response(req=request, auth_info=mock_auth_info) + # Create a mock Request object + mock_request = MagicMock(spec=Request) + + await create_response(mock_request, req=request, auth_info=mock_auth_info) mock_get_prompt.assert_called_once_with(mock_prompt_document) @@ -518,8 +532,13 @@ async def test_responses_prompt_document_extraction_error(self): mock_get_prompt.side_effect = Exception("Unable to extract prompt") + # Create a mock Request object + mock_request = MagicMock(spec=Request) + with pytest.raises(HTTPException) as exc_info: - await create_response(req=request, auth_info=mock_auth_info) + await create_response( + mock_request, req=request, auth_info=mock_auth_info + ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert ( @@ -604,7 +623,10 @@ async def test_responses_without_prompt_document(self): mock_handle_tool_workflow.return_value = (mock_response, 0, 0) - await create_response(req=request, auth_info=mock_auth_info) + # Create a mock Request object + mock_request = MagicMock(spec=Request) + + await create_response(mock_request, req=request, auth_info=mock_auth_info) mock_get_prompt.assert_not_called() diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py index 27a5c1bc..21c93f22 100644 --- a/tests/unit/nilai_api/test_rate_limiting.py +++ b/tests/unit/nilai_api/test_rate_limiting.py @@ -73,45 +73,6 @@ async def test_concurrent_rate_limit(req): await asyncio.gather(*futures) -@pytest.mark.asyncio -async def test_web_search_rps_limit(redis_client): - mock_request = MagicMock(spec=Request) - mock_request.state.redis = redis_client[0] - mock_request.state.redis_rate_limit_command = redis_client[1] - # Ensure a clean slate for the global RPS key used by the limiter - await redis_client[0].delete("web_search_rps") - - async def web_search_extractor(_): - return True - - rate_limit = RateLimit(web_search_extractor=web_search_extractor) - user_limits = UserRateLimits( - subscription_holder=random_id(), - token_rate_limit=None, - rate_limits=RateLimits( - user_rate_limit_day=None, - user_rate_limit_hour=None, - user_rate_limit_minute=None, - web_search_rate_limit_day=None, - web_search_rate_limit_hour=None, - web_search_rate_limit_minute=None, - user_rate_limit=None, - web_search_rate_limit=None, - ), - ) - - old_rps = CONFIG.web_search.rps - CONFIG.web_search.rps = 2 - try: - await consume_generator(rate_limit(mock_request, user_limits)) - await consume_generator(rate_limit(mock_request, user_limits)) - with pytest.raises(HTTPException): - await consume_generator(rate_limit(mock_request, user_limits)) - finally: - CONFIG.web_search.rps = old_rps - await redis_client[0].delete("web_search_rps") - - @pytest.mark.asyncio @pytest.mark.parametrize( "user_limits", @@ -239,3 +200,125 @@ async def web_search_extractor(request): # Second request should be rejected due to minute limit (1 per minute) with pytest.raises(HTTPException): await consume_generator(rate_limit(mock_request, user_limits)) + + +@pytest.mark.asyncio +async def test_check_brave_rps_limit(redis_client): + """Test that check_brave_rps enforces global RPS limit across all users.""" + mock_request = MagicMock(spec=Request) + mock_request.state.redis = redis_client[0] + mock_request.state.redis_rate_limit_command = redis_client[1] + + await redis_client[0].delete("brave_rps_global") + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = 3 + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + + with pytest.raises(HTTPException) as exc_info: + await rate_limit.check_brave_rps(mock_request) + + assert exc_info.value.status_code == 429 + assert "Too Many Requests" in str(exc_info.value.detail) + finally: + CONFIG.web_search.rps = old_rps + await redis_client[0].delete("brave_rps_global") + + +@pytest.mark.asyncio +async def test_check_brave_rps_disabled(redis_client): + """Test that check_brave_rps does nothing when limit is disabled.""" + mock_request = MagicMock(spec=Request) + mock_request.state.redis = redis_client[0] + mock_request.state.redis_rate_limit_command = redis_client[1] + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = None + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + finally: + CONFIG.web_search.rps = old_rps + + +@pytest.mark.asyncio +async def test_check_brave_rps_zero_limit(redis_client): + """Test that check_brave_rps does nothing when limit is 0 or negative.""" + mock_request = MagicMock(spec=Request) + mock_request.state.redis = redis_client[0] + mock_request.state.redis_rate_limit_command = redis_client[1] + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = 0 + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + finally: + CONFIG.web_search.rps = old_rps + + +@pytest.mark.asyncio +async def test_check_brave_rps_global_key(redis_client): + """Test that check_brave_rps uses the correct global key across different requests.""" + mock_request_1 = MagicMock(spec=Request) + mock_request_1.state.redis = redis_client[0] + mock_request_1.state.redis_rate_limit_command = redis_client[1] + + mock_request_2 = MagicMock(spec=Request) + mock_request_2.state.redis = redis_client[0] + mock_request_2.state.redis_rate_limit_command = redis_client[1] + + await redis_client[0].delete("brave_rps_global") + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = 2 + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request_1) + await rate_limit.check_brave_rps(mock_request_2) + + with pytest.raises(HTTPException): + await rate_limit.check_brave_rps(mock_request_1) + finally: + CONFIG.web_search.rps = old_rps + await redis_client[0].delete("brave_rps_global") + + +@pytest.mark.asyncio +async def test_check_brave_rps_reset_after_window(redis_client): + """Test that check_brave_rps resets after the 1 second window expires.""" + mock_request = MagicMock(spec=Request) + mock_request.state.redis = redis_client[0] + mock_request.state.redis_rate_limit_command = redis_client[1] + + await redis_client[0].delete("brave_rps_global") + + old_rps = CONFIG.web_search.rps + CONFIG.web_search.rps = 2 + try: + rate_limit = RateLimit() + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + + with pytest.raises(HTTPException): + await rate_limit.check_brave_rps(mock_request) + + await asyncio.sleep(1.1) + + await rate_limit.check_brave_rps(mock_request) + await rate_limit.check_brave_rps(mock_request) + finally: + CONFIG.web_search.rps = old_rps + await redis_client[0].delete("brave_rps_global") diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py index ef94a1f5..72af39f9 100644 --- a/tests/unit/nilai_api/test_web_search.py +++ b/tests/unit/nilai_api/test_web_search.py @@ -1,10 +1,12 @@ import pytest -from unittest.mock import patch -from fastapi import HTTPException +from unittest.mock import patch, MagicMock, AsyncMock +from fastapi import HTTPException, Request from nilai_api.handlers.web_search import ( perform_web_search_async, enhance_messages_with_web_search, + _make_brave_api_request, ) +from nilai_api.rate_limiting import RateLimit from nilai_common import MessageAdapter, ChatRequest from nilai_common.api_models import ( WebSearchContext, @@ -32,6 +34,8 @@ async def test_perform_web_search_async_success(): } } + mock_request = MagicMock(spec=Request) + with ( patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), patch( @@ -39,7 +43,7 @@ async def test_perform_web_search_async_success(): return_value=mock_data, ), ): - ctx = await perform_web_search_async("AI developments") + ctx = await perform_web_search_async("AI developments", mock_request) assert ctx.sources is not None assert len(ctx.sources) == 2 @@ -59,6 +63,7 @@ async def test_perform_web_search_async_success(): async def test_perform_web_search_async_no_results(): """Test web search with no results returns 404""" mock_data = {"web": {"results": []}} + mock_request = MagicMock(spec=Request) with ( patch("nilai_api.handlers.web_search.CONFIG.web_search.api_key", "test-key"), @@ -68,7 +73,7 @@ async def test_perform_web_search_async_no_results(): ), pytest.raises(HTTPException) as exc_info, ): - await perform_web_search_async("nonexistent query") + await perform_web_search_async("nonexistent query", mock_request) assert exc_info.value.status_code == 404 @@ -100,6 +105,9 @@ async def test_perform_web_search_async_concurrent_queries(): } } + mock_request_1 = MagicMock(spec=Request) + mock_request_2 = MagicMock(spec=Request) + with ( patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), patch( @@ -111,8 +119,8 @@ async def test_perform_web_search_async_concurrent_queries(): # Run two concurrent web searches results = await asyncio.gather( - perform_web_search_async("AI news"), - perform_web_search_async("Machine learning"), + perform_web_search_async("AI news", mock_request_1), + perform_web_search_async("Machine learning", mock_request_2), ) # Verify both searches completed successfully @@ -146,6 +154,7 @@ async def test_enhance_messages_with_web_search(): MessageAdapter.new_message(role="user", content="What is the latest AI news?"), ] req = ChatRequest(model="dummy", messages=original_messages) + mock_request = MagicMock(spec=Request) with patch("nilai_api.handlers.web_search.perform_web_search_async") as mock_search: mock_search.return_value = WebSearchContext( @@ -155,7 +164,7 @@ async def test_enhance_messages_with_web_search(): ], ) - enhanced = await enhance_messages_with_web_search(req, "AI news") + enhanced = await enhance_messages_with_web_search(req, "AI news", mock_request) assert len(enhanced.messages) == 2 assert enhanced.messages[0]["role"] == "system" @@ -166,3 +175,65 @@ async def test_enhance_messages_with_web_search(): assert enhanced.sources[0].content == "AI news" assert enhanced.sources[1].source == "https://example.com" assert enhanced.sources[1].content == "OpenAI announces GPT-5" + + +@pytest.mark.asyncio +async def test_make_brave_api_request_calls_rps_limit(): + """Test that _make_brave_api_request calls check_brave_rps for rate limiting.""" + mock_request = MagicMock(spec=Request) + mock_data = { + "web": { + "results": [ + { + "title": "Test Result", + "description": "Test description", + "url": "https://example.com/test", + } + ] + } + } + + with ( + patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), + patch( + "nilai_api.handlers.web_search.CONFIG.web_search.api_path", + "https://api.brave.com/v1/web/search", + ), + patch("nilai_api.handlers.web_search._get_http_client") as mock_client, + patch.object( + RateLimit, "check_brave_rps", new_callable=AsyncMock + ) as mock_check_rps, + ): + mock_http_client = AsyncMock() + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = AsyncMock(return_value=mock_data) + mock_http_client.get = AsyncMock(return_value=mock_response) + mock_client.return_value = mock_http_client + + result = await _make_brave_api_request("test query", mock_request) + + mock_check_rps.assert_called_once_with(mock_request) + assert result == mock_data + + +@pytest.mark.asyncio +async def test_make_brave_api_request_rps_limit_exceeded(): + """Test that _make_brave_api_request raises 429 when RPS limit is exceeded.""" + mock_request = MagicMock(spec=Request) + + with ( + patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), + patch.object( + RateLimit, "check_brave_rps", new_callable=AsyncMock + ) as mock_check_rps, + ): + mock_check_rps.side_effect = HTTPException( + status_code=429, detail="Too Many Requests" + ) + + with pytest.raises(HTTPException) as exc_info: + await _make_brave_api_request("test query", mock_request) + + assert exc_info.value.status_code == 429 + mock_check_rps.assert_called_once_with(mock_request) From 08362d4082b6b3412f724fdf6cc6c056131643f1 Mon Sep 17 00:00:00 2001 From: blefo Date: Thu, 13 Nov 2025 13:29:54 +0100 Subject: [PATCH 2/2] fix: tests + updated docstring in web_search handler --- nilai-api/src/nilai_api/handlers/web_search.py | 12 ++++++++++++ tests/unit/nilai_api/test_web_search.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 35700923..c0205816 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -97,6 +97,7 @@ async def _make_brave_api_request(query: str, request: Request) -> Dict[str, Any Args: query: The search query string to execute + request: FastAPI request object for rate limiting Returns: Dict containing the raw API response data @@ -242,6 +243,16 @@ async def perform_web_search_async(query: str, request: Request) -> WebSearchCon Fetches only the exact page for each Brave URL and extracts its main content with trafilatura. If extraction fails, falls back to the Brave snippet. + + Args: + query: The search query string to execute + request: FastAPI request object for rate limiting + + Returns: + WebSearchContext with formatted search results and source information + + Raises: + HTTPException: If no results are found (404) or if the API request fails """ if not (query and query.strip()): logger.warning("Empty or invalid query provided for web search") @@ -721,6 +732,7 @@ async def handle_web_search_for_responses( req: ResponseRequest containing input to process model_name: Name of the LLM model to use for query generation client: LLM client instance for making API calls + request: FastAPI request object for rate limiting Returns: WebSearchEnhancedInput with web search context added, or original diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py index 72af39f9..8f71631b 100644 --- a/tests/unit/nilai_api/test_web_search.py +++ b/tests/unit/nilai_api/test_web_search.py @@ -207,7 +207,7 @@ async def test_make_brave_api_request_calls_rps_limit(): mock_http_client = AsyncMock() mock_response = AsyncMock() mock_response.status_code = 200 - mock_response.json = AsyncMock(return_value=mock_data) + mock_response.json = MagicMock(return_value=mock_data) mock_http_client.get = AsyncMock(return_value=mock_response) mock_client.return_value = mock_http_client