diff --git a/README.md b/README.md
index 271f8213..3f28151e 100644
--- a/README.md
+++ b/README.md
@@ -108,17 +108,19 @@ up -d
 docker run -d --name redis \
   -p 6379:6379 \
   redis:latest
-
-# Start PostgreSQL
-docker run -d --name postgres \
-  -e POSTGRES_USER=${POSTGRES_USER} \
-  -e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} \
-  -e POSTGRES_DB=${POSTGRES_DB} \
-  -p 5432:5432 \
-  --network frontend_net \
-  --volume postgres_data:/var/lib/postgresql/data \
-  postgres:16
-```
+   ```
+
+2. **Start PostgreSQL**
+   ```shell
+   docker run -d --name postgres \
+     -e POSTGRES_USER=${POSTGRES_USER} \
+     -e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} \
+     -e POSTGRES_DB=${POSTGRES_DB} \
+     -p 5432:5432 \
+     --network frontend_net \
+     --volume postgres_data:/var/lib/postgresql/data \
+     postgres:16
+   ```
 
-2. **Run API Server**
+3. **Run API Server**
    ```shell
@@ -191,11 +193,12 @@ To configure vLLM for **local execution on macOS**, execute the following steps:
 ```shell
 # Clone vLLM repository (root folder)
 git clone https://github.com/vllm-project/vllm.git
+cd vllm
 git checkout v0.7.3 # We use v0.7.3
 # Build vLLM OpenAI (vllm folder)
-cd vllm
 docker build -f Dockerfile.arm -t vllm/vllm-openai . --shm-size=4g
-# Build nilai attestation container
+
+# Build nilai attestation container (root folder)
 docker build -t nillion/nilai-attestation:latest -f docker/attestation.Dockerfile .
 # Build vLLM docker container (root folder)
 docker build -t nillion/nilai-vllm:latest -f docker/vllm.Dockerfile .
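Before wiring up the API server, it is worth confirming that both containers accept connections. A throwaway sanity check along these lines works, assuming the `redis` and `psycopg2-binary` client packages are installed locally (neither ships with this repo) and the same environment variables as above are exported:

```python
# Hypothetical connectivity check; mirrors the ports and credentials used in
# the `docker run` commands above. Not part of the repository.
import os

import psycopg2
import redis

redis.Redis(host="localhost", port=6379).ping()  # raises if Redis is down

conn = psycopg2.connect(
    host="localhost",
    port=5432,
    user=os.environ["POSTGRES_USER"],
    password=os.environ["POSTGRES_PASSWORD"],
    dbname=os.environ["POSTGRES_DB"],
)
conn.close()
print("Redis and PostgreSQL are reachable")
```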
diff --git a/nilai-api/pyproject.toml b/nilai-api/pyproject.toml
index 079934aa..b91b3e54 100644
--- a/nilai-api/pyproject.toml
+++ b/nilai-api/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
     "accelerate>=1.1.1",
     "alembic>=1.14.1",
     "cryptography>=43.0.1",
+    "duckduckgo-search>=8.1.1",
     "fastapi[standard]>=0.115.5",
     "gunicorn>=23.0.0",
     "nilai-common",
diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py
new file mode 100644
index 00000000..f93f4dc8
--- /dev/null
+++ b/nilai-api/src/nilai_api/handlers/web_search.py
@@ -0,0 +1,96 @@
+import asyncio
+import logging
+from typing import List
+
+from duckduckgo_search import DDGS
+from fastapi import HTTPException, status
+
+from nilai_common.api_model import Source
+from nilai_common import Message
+from nilai_common.api_model import EnhancedMessages, WebSearchContext
+
+logger = logging.getLogger(__name__)
+
+
+def perform_web_search_sync(query: str) -> WebSearchContext:
+    """Synchronously query DuckDuckGo and build a contextual prompt.
+
+    The function sends *query* to DuckDuckGo, extracts up to three text results,
+    formats them into a single prompt, and returns that prompt together with the
+    metadata (URL and snippet) of each result.
+    """
+    if not query or not query.strip():
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Web search requested with an empty query",
+        )
+
+    try:
+        with DDGS() as ddgs:
+            raw_results = list(ddgs.text(query, max_results=3, region="us-en"))
+
+        if not raw_results:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="Web search failed, service currently unavailable",
+            )
+
+        snippets: List[str] = []
+        sources: List[Source] = []
+
+        for result in raw_results:
+            if result.get("title") and result.get("body"):
+                title = result["title"]
+                body = result["body"][:500]
+                snippets.append(f"{title}: {body}")
+                sources.append(Source(source=result["href"], content=body))
+
+        prompt = (
+            "You have access to the following current information from web search:\n"
+            + "\n".join(snippets)
+        )
+
+        return WebSearchContext(prompt=prompt, sources=sources)
+
+    except HTTPException:
+        raise
+    except Exception as exc:
+        logger.error("Error performing web search: %s", exc)
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Web search failed, service currently unavailable",
+        ) from exc
+
+
+async def get_web_search_context(query: str) -> WebSearchContext:
+    """Non-blocking wrapper around *perform_web_search_sync*."""
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(None, perform_web_search_sync, query)
+
+
+async def enhance_messages_with_web_search(
+    messages: List[Message], query: str
+) -> EnhancedMessages:
+    ctx = await get_web_search_context(query)
+    enhanced = [Message(role="system", content=ctx.prompt)] + messages
+    return EnhancedMessages(messages=enhanced, sources=ctx.sources)
+
+
+async def handle_web_search(req_messages: List[Message]) -> EnhancedMessages:
+    """Handle web search for the given messages.
+
+    Only the last user message is used as the query.
+    """
+
+    user_query = ""
+    for message in reversed(req_messages):
+        if message.role == "user":
+            user_query = message.content
+            break
+
+    if not user_query:
+        return EnhancedMessages(messages=req_messages, sources=[])
+    try:
+        return await enhance_messages_with_web_search(req_messages, user_query)
+    except Exception:
+        return EnhancedMessages(messages=req_messages, sources=[])
diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py
index dbc9fc42..22a47c25 100644
--- a/nilai-api/src/nilai_api/routers/private.py
+++ b/nilai-api/src/nilai_api/routers/private.py
@@ -5,6 +5,7 @@
 from typing import AsyncGenerator, Optional, Union, List, Tuple
 from nilai_api.attestation import get_attestation_report
 from nilai_api.handlers.nilrag import handle_nilrag
+from nilai_api.handlers.web_search import handle_web_search
 from fastapi import APIRouter, Body, Depends, HTTPException, status, Request
 from fastapi.responses import StreamingResponse
 
@@ -23,8 +24,9 @@
     Message,
     ModelMetadata,
     SignedChatCompletion,
-    Usage,
     Nonce,
+    Source,
+    Usage,
 )
 
 from openai import AsyncOpenAI, OpenAI
@@ -139,6 +141,7 @@ async def chat_completion(
     - Must include non-empty list of messages
     - Must specify a model
     - Supports multiple message formats (system, user, assistant)
+    - Optional web_search parameter to enhance context with current information
 
     ### Response Components
     - Model-generated text completion
@@ -147,10 +150,18 @@ async def chat_completion(
 
     ### Processing Steps
     1. Validate input request parameters
-    2. Prepare messages for model processing
-    3. Generate AI model response
-    4. Track and update token usage
-    5. Cryptographically sign the response
+    2. If web_search is enabled, perform web search and enhance context
+    3. Prepare messages for model processing
+    4. Generate AI model response
+    5. Track and update token usage
+    6. Cryptographically sign the response
+
+    ### Web Search Feature
+    When web_search=True, the system will:
+    - Extract the user's query from the last user message
+    - Perform a web search using DuckDuckGo API
+    - Enhance the conversation context with current information
+    - Add search results as a system message for better responses
 
     ### Potential HTTP Errors
     - **400 Bad Request**:
@@ -161,13 +172,14 @@ async def chat_completion(
 
     ### Example
     ```python
-    # Generate a chat completion
+    # Generate a chat completion with web search
     request = ChatRequest(
         model="meta-llama/Llama-3.2-1B-Instruct",
         messages=[
             {"role": "system", "content": "You are a helpful assistant"},
-            {"role": "user", "content": "Hello, who are you?"}
-        ]
+            {"role": "user", "content": "What's the latest news about AI?"}
+        ],
+        web_search=True,
     )
     response = await chat_completion(request, user)
     """
@@ -194,6 +205,13 @@ async def chat_completion(
     if req.nilrag:
         await handle_nilrag(req)
 
+    messages = req.messages
+    sources: Optional[List[Source]] = None
+    if req.web_search:
+        web_search_result = await handle_web_search(messages)
+        messages = web_search_result.messages
+        sources = web_search_result.sources
+
     if req.stream:
         client = AsyncOpenAI(base_url=model_url, api_key="")
 
@@ -202,7 +220,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
             try:
                 response = await client.chat.completions.create(
                     model=req.model,
-                    messages=req.messages,  # type: ignore
+                    messages=messages,  # type: ignore
                     stream=req.stream,  # type: ignore
                     top_p=req.top_p,
                     temperature=req.temperature,
@@ -255,7 +273,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
         client = OpenAI(base_url=model_url, api_key="")
         response = client.chat.completions.create(
             model=req.model,
-            messages=req.messages,  # type: ignore
+            messages=messages,  # type: ignore
             stream=req.stream,
             top_p=req.top_p,
             temperature=req.temperature,
@@ -266,7 +284,9 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
         model_response = SignedChatCompletion(
             **response.model_dump(),
             signature="",
+            sources=sources,
         )
+
         if model_response.usage is None:
             raise HTTPException(
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
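Taken together, the handler and router changes above make web search strictly additive: when the search succeeds, a single system message carrying the snippets is prepended, and on any failure (or when no user message exists) the original messages pass through untouched. A minimal sketch of calling the handler directly, assuming a network path to DuckDuckGo (the message text is illustrative):

```python
import asyncio

from nilai_api.handlers.web_search import handle_web_search
from nilai_common import Message


async def demo() -> None:
    messages = [
        Message(role="system", content="You are a helpful assistant"),
        Message(role="user", content="Who won Roland Garros in 2024?"),
    ]
    enhanced = await handle_web_search(messages)
    # On success: one extra system message up front, plus per-result sources.
    # On failure: the original messages and an empty source list.
    print(len(enhanced.messages), [s.source for s in enhanced.sources])


asyncio.run(demo())
```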
diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py
index 69ada60b..18bcf74c 100644
--- a/packages/nilai-common/src/nilai_common/__init__.py
+++ b/packages/nilai-common/src/nilai_common/__init__.py
@@ -10,10 +10,13 @@
     Nonce,
     AMDAttestationToken,
     NVAttestationToken,
+    Source,
+    EnhancedMessages,
+    WebSearchContext,
 )
-from openai.types.completion_usage import CompletionUsage as Usage
 from nilai_common.config import SETTINGS
 from nilai_common.discovery import ModelServiceDiscovery
+from openai.types.completion_usage import CompletionUsage as Usage
 
 __all__ = [
     "Message",
@@ -30,4 +33,7 @@
     "AMDAttestationToken",
     "NVAttestationToken",
     "SETTINGS",
+    "Source",
+    "EnhancedMessages",
+    "WebSearchContext",
 ]
diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py
index f092e2d5..d18126dd 100644
--- a/packages/nilai-common/src/nilai_common/api_model.py
+++ b/packages/nilai-common/src/nilai_common/api_model.py
@@ -1,5 +1,6 @@
 import uuid
-from typing import Annotated, List, Optional, Literal, Iterable
+
+from typing import Annotated, Iterable, List, Literal, Optional
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice as OpenaAIChoice
@@ -15,6 +16,23 @@
 class Choice(OpenaAIChoice):
     pass
 
+class Source(BaseModel):
+    source: str
+    content: str
+
+
+class EnhancedMessages(BaseModel):
+    messages: List[Message]
+    sources: List[Source]
+
+
+class WebSearchContext(BaseModel):
+    """Prompt and sources obtained from a web search."""
+
+    prompt: str
+    sources: List[Source]
+
+
 class ChatRequest(BaseModel):
     model: str
     messages: List[Message] = Field(..., min_length=1)
@@ -24,10 +42,17 @@
     stream: Optional[bool] = False
     tools: Optional[Iterable[ChatCompletionToolParam]] = None
     nilrag: Optional[dict] = {}
+    web_search: Optional[bool] = Field(
+        default=False,
+        description="Enable web search to enhance context with current information",
+    )
 
 
 class SignedChatCompletion(ChatCompletion):
     signature: str
+    sources: Optional[List[Source]] = Field(
+        default=None, description="Sources used for web search when enabled"
+    )
 
 
 class ModelMetadata(BaseModel):
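On the wire the feature is a single boolean on `ChatRequest`, and attribution comes back on `SignedChatCompletion.sources`. A small sketch of both ends of the contract (the model name and message are illustrative):

```python
from nilai_common.api_model import ChatRequest, SignedChatCompletion

request = ChatRequest(
    model="meta-llama/Llama-3.2-1B-Instruct",
    messages=[{"role": "user", "content": "What's the latest news about AI?"}],
    web_search=True,
)


def print_sources(completion: SignedChatCompletion) -> None:
    # `sources` stays None when web_search was off or nothing usable came back.
    for source in completion.sources or []:
        print(f"{source.source}: {source.content[:80]}")
```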
diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py
index 48ad1884..35ff8bdb 100644
--- a/tests/e2e/test_openai.py
+++ b/tests/e2e/test_openai.py
@@ -719,3 +719,55 @@ def test_model_streaming_request_high_token(client):
     assert chunk_count > 0, (
         "Should receive at least one chunk for high token streaming request"
     )
+
+
+@pytest.mark.parametrize(
+    "model",
+    test_models,
+)
+def test_web_search_roland_garros_2024(client, model):
+    """Test web_search using a query that requires up-to-date information (Roland Garros 2024 winner)."""
+    max_retries = 3
+    last_exception = None
+
+    for attempt in range(max_retries):
+        try:
+            print(f"\nAttempt {attempt + 1}/{max_retries}...")
+
+            response = client.chat.completions.create(
+                model=model,
+                messages=[
+                    {
+                        "role": "system",
+                        "content": "You are a helpful assistant that provides accurate and up-to-date information.",
+                    },
+                    {
+                        "role": "user",
+                        "content": "Who won the Roland Garros Open in 2024? Just reply with the winner's name.",
+                    },
+                ],
+                extra_body={"web_search": True},
+                temperature=0.2,
+                max_tokens=150,
+            )
+
+            assert isinstance(response, ChatCompletion)
+            assert response.model == model
+            assert len(response.choices) > 0
+
+            content = response.choices[0].message.content
+            assert content, "Response content is empty."
+
+            keywords = ["carlos", "alcaraz", "iga", "świątek", "swiatek"]
+            assert any(k in content.lower() for k in keywords)
+
+            print(f"Success on attempt {attempt + 1}")
+            return
+        except AssertionError as e:
+            print(f"Assertion failed on attempt {attempt + 1}: {e}")
+            last_exception = e
+            if attempt < max_retries - 1:
+                print("Retrying...")
+            else:
+                print("All retries failed.")
+                raise last_exception
diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py
index f3138302..4289194d 100644
--- a/tests/unit/nilai_api/routers/test_private.py
+++ b/tests/unit/nilai_api/routers/test_private.py
@@ -170,12 +170,16 @@ def test_get_models(mock_user, mock_user_manager, mock_state, client):
 
 
 def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, client):
     mocker.patch("openai.api_key", new="test-api-key")
 
-    # Mock the response from the OpenAI API
+    from openai.types.chat import ChatCompletion
     data = RESPONSE.model_dump()
+    data.pop("signature")
+    data.pop("sources", None)
+    response_data = ChatCompletion(**data)
+
     mocker.patch(
         "openai._base_client.SyncAPIClient._request", return_value=response_data
     )
 
 
diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py
new file mode 100644
index 00000000..799cf551
--- /dev/null
+++ b/tests/unit/nilai_api/test_web_search.py
@@ -0,0 +1,156 @@
+import pytest
+from unittest.mock import patch, MagicMock
+from nilai_api.handlers.web_search import (
+    perform_web_search_sync,
+    get_web_search_context,
+    enhance_messages_with_web_search,
+    handle_web_search,
+)
+from nilai_common import Message
+from nilai_common.api_model import WebSearchContext, EnhancedMessages
+
+
+def test_perform_web_search_sync_success():
+    """Test successful web search with mock response"""
+    mock_search_results = [
+        {
+            "title": "Latest AI Developments",
+            "body": "OpenAI announces GPT-5 with improved capabilities and better performance across various tasks.",
+            "href": "https://example.com/ai1",
+        },
+        {
+            "title": "AI Breakthrough in Robotics",
+            "body": "New neural network architecture improves robot learning efficiency by 40% in recent studies.",
+            "href": "https://example.com/ai2",
+        },
+    ]
+
+    with patch("nilai_api.handlers.web_search.DDGS") as mock_ddgs:
+        mock_instance = MagicMock()
+        mock_ddgs.return_value.__enter__.return_value = mock_instance
+        mock_instance.text.return_value = mock_search_results
+        mock_instance.news.return_value = []
+
+        ctx = perform_web_search_sync("AI developments")
+
+        assert len(ctx.sources) == 2
+        assert "GPT-5" in ctx.prompt
+        assert "40%" in ctx.prompt
+        assert ctx.sources[0].source == "https://example.com/ai1"
+        assert ctx.sources[1].source == "https://example.com/ai2"
+
+
+def test_perform_web_search_sync_no_results():
+    """Test web search with no results"""
+    with patch("nilai_api.handlers.web_search.DDGS") as mock_ddgs:
+        mock_instance = MagicMock()
+        mock_ddgs.return_value.__enter__.return_value = mock_instance
+        mock_instance.text.return_value = []
+        mock_instance.news.return_value = []
+
+        with pytest.raises(Exception):
+            _ = perform_web_search_sync("nonexistent query")
+
+
+def test_perform_web_search_sync_no_fallback_to_news():
+    """Test that an empty text search raises even when news results exist (no fallback to news)"""
+    mock_news_results = [
+        {
+            "title": "Breaking AI News",
+            "body": "Major breakthrough in artificial intelligence research announced today.",
+            "href": "https://example.com/news1",
+        }
+    ]
+
+    with patch("nilai_api.handlers.web_search.DDGS") as mock_ddgs:
+        mock_instance = MagicMock()
+        mock_ddgs.return_value.__enter__.return_value = mock_instance
+        mock_instance.text.return_value = []
+        mock_instance.news.return_value = mock_news_results
+
+        with pytest.raises(Exception):
+            _ = perform_web_search_sync("AI news")
+
+
+@pytest.mark.asyncio
+async def test_enhance_messages_with_web_search():
+    """Test message enhancement with web search results"""
+    original_messages = [
+        Message(role="system", content="You are a helpful assistant"),
+        Message(role="user", content="What is the latest AI news?"),
+    ]
+
+    with patch("nilai_api.handlers.web_search.perform_web_search_sync") as mock_search:
+        mock_search.return_value = WebSearchContext(
+            prompt="Latest AI Developments: OpenAI announces GPT-5\nAI Breakthrough: New neural network improves efficiency by 40%",
+            sources=[],
+        )
+
+        enhanced = await enhance_messages_with_web_search(original_messages, "AI news")
+
+        assert len(enhanced.messages) == 3
+        assert enhanced.messages[0].role == "system"
+        assert "Latest AI Developments" in str(enhanced.messages[0].content)
+        assert enhanced.sources == []
+
+
+@pytest.mark.asyncio
+async def test_handle_web_search():
+    """Test web search handler with user messages"""
+    messages = [
+        Message(role="user", content="Tell me about current events"),
+    ]
+
+    with patch(
+        "nilai_api.handlers.web_search.enhance_messages_with_web_search"
+    ) as mock_enhance:
+        mock_enhance.return_value = EnhancedMessages(
+            messages=[Message(role="system", content="Enhanced context")] + messages,
+            sources=[],
+        )
+        enhanced = await handle_web_search(messages)
+
+        mock_enhance.assert_called_once_with(messages, "Tell me about current events")
+        assert len(enhanced.messages) == len(messages) + 1
+        assert enhanced.sources == []
+
+
+@pytest.mark.asyncio
+async def test_handle_web_search_no_user_message():
+    """Test web search handler with no user message"""
+    messages = [
+        Message(role="assistant", content="Hello! How can I help you?"),
+    ]
+
+    enhanced = await handle_web_search(messages)
+
+    assert enhanced.messages == messages
+    assert enhanced.sources == []
+
+
+@pytest.mark.asyncio
+async def test_handle_web_search_exception_handling():
+    """Test web search handler exception handling"""
+    messages = [
+        Message(role="system", content="You are a helpful assistant"),
+        Message(role="user", content="What's the weather like?"),
+    ]
+
+    with patch(
+        "nilai_api.handlers.web_search.enhance_messages_with_web_search"
+    ) as mock_enhance:
+        mock_enhance.side_effect = Exception("Search service unavailable")
+
+        enhanced = await handle_web_search(messages)
+
+        assert enhanced.messages == messages
+        assert enhanced.sources == []
+
+
+@pytest.mark.asyncio
+async def test_get_web_search_context_async_wrapper():
+    with patch("nilai_api.handlers.web_search.perform_web_search_sync") as mock_sync:
+        mock_sync.return_value = WebSearchContext(prompt="info", sources=[])
+        ctx = await get_web_search_context("query")
+        assert ctx.prompt == "info"
+        mock_sync.assert_called_once()
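End to end, a client opts in through `extra_body`, exactly as the e2e test above does; the OpenAI SDK forwards unknown fields to the server, where they land on `ChatRequest.web_search`. A minimal sketch against a running deployment (the base URL and API key are placeholders):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="<your-api-key>")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",
    messages=[{"role": "user", "content": "Who won the Roland Garros Open in 2024?"}],
    extra_body={"web_search": True},  # forwarded to ChatRequest.web_search
    temperature=0.2,
)
print(response.choices[0].message.content)
```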