From 88dd7671acbc814893e05a1abe2d17c030907434 Mon Sep 17 00:00:00 2001 From: blefo Date: Wed, 16 Jul 2025 14:00:55 +0200 Subject: [PATCH 1/3] docs: update README docs: format PostgreSQL setup instructions in README for clarity --- README.md | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 271f8213..3f28151e 100644 --- a/README.md +++ b/README.md @@ -108,17 +108,19 @@ up -d docker run -d --name redis \ -p 6379:6379 \ redis:latest - -# Start PostgreSQL -docker run -d --name postgres \ - -e POSTGRES_USER=${POSTGRES_USER} \ - -e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} \ - -e POSTGRES_DB=${POSTGRES_DB} \ - -p 5432:5432 \ - --network frontend_net \ - --volume postgres_data:/var/lib/postgresql/data \ - postgres:16 -``` + ``` + +2. **Start PostgreSQL** + ```shell + docker run -d --name postgres \ + -e POSTGRES_USER=${POSTGRES_USER} \ + -e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} \ + -e POSTGRES_DB=${POSTGRES_DB} \ + -p 5432:5432 \ + --network frontend_net \ + --volume postgres_data:/var/lib/postgresql/data \ + postgres:16 + ``` 2. **Run API Server** ```shell @@ -191,11 +193,12 @@ To configure vLLM for **local execution on macOS**, execute the following steps: ```shell # Clone vLLM repository (root folder) git clone https://github.com/vllm-project/vllm.git +cd vllm git checkout v0.7.3 # We use v0.7.3 # Build vLLM OpenAI (vllm folder) -cd vllm docker build -f Dockerfile.arm -t vllm/vllm-openai . --shm-size=4g -# Build nilai attestation container + +# Build nilai attestation container (root folder) docker build -t nillion/nilai-attestation:latest -f docker/attestation.Dockerfile . # Build vLLM docker container (root folder) docker build -t nillion/nilai-vllm:latest -f docker/vllm.Dockerfile . 
From 827f63d3d651d416ce0540f1601016e913ba7f5d Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 21 Jul 2025 11:54:25 +0200 Subject: [PATCH 2/3] feat: web_search parameter in the chat completion endpoint The web_search parameter can be set to true or false in the chat endpoint. If true, the first three search results are added to the context. The user can access the sources in the `sources` field of the response object. --- nilai-api/pyproject.toml | 1 + .../src/nilai_api/handlers/web_search.py | 96 +++++++++++ nilai-api/src/nilai_api/routers/private.py | 40 +++-- .../nilai-common/src/nilai_common/__init__.py | 8 +- .../src/nilai_common/api_model.py | 45 ++++- tests/e2e/test_openai.py | 31 ++++ tests/unit/nilai_api/routers/test_private.py | 6 +- tests/unit/nilai_api/test_web_search.py | 156 ++++++++++++++++++ 8 files changed, 370 insertions(+), 13 deletions(-) create mode 100644 nilai-api/src/nilai_api/handlers/web_search.py create mode 100644 tests/unit/nilai_api/test_web_search.py diff --git a/nilai-api/pyproject.toml b/nilai-api/pyproject.toml index 079934aa..b91b3e54 100644 --- a/nilai-api/pyproject.toml +++ b/nilai-api/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "accelerate>=1.1.1", "alembic>=1.14.1", "cryptography>=43.0.1", + "duckduckgo-search>=8.1.1", "fastapi[standard]>=0.115.5", "gunicorn>=23.0.0", "nilai-common", diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py new file mode 100644 index 00000000..f93f4dc8 --- /dev/null +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -0,0 +1,96 @@ +import asyncio +import logging +from typing import List + +from duckduckgo_search import DDGS +from fastapi import HTTPException, status + +from nilai_common.api_model import Source +from nilai_common import Message +from nilai_common.api_model import EnhancedMessages, WebSearchContext + +logger = logging.getLogger(__name__) + + +def perform_web_search_sync(query: str) -> WebSearchContext: + """Synchronously query DuckDuckGo and 
build a contextual prompt. + + The function sends *query* to DuckDuckGo, extracts the first three text results, + formats them in a single prompt, and returns that prompt together with the + metadata (URL and snippet) of every result. + """ + if not query or not query.strip(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Web search requested with an empty query", + ) + + try: + with DDGS() as ddgs: + raw_results = list(ddgs.text(query, max_results=3, region="us-en")) + + if not raw_results: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Web search failed, service currently unavailable", + ) + + snippets: List[str] = [] + sources: List[Source] = [] + + for result in raw_results: + if result.get("title") and result.get("body"): + title = result["title"] + body = result["body"][:500] + snippets.append(f"{title}: {body}") + sources.append(Source(source=result["href"], content=body)) + + prompt = ( + "You have access to the following current information from web search:\n" + + "\n".join(snippets) + ) + + return WebSearchContext(prompt=prompt, sources=sources) + + except HTTPException: + raise + except Exception as exc: + logger.error("Error performing web search: %s", exc) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Web search failed, service currently unavailable", + ) from exc + + +async def get_web_search_context(query: str) -> WebSearchContext: + """Non-blocking wrapper around *perform_web_search_sync*.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, perform_web_search_sync, query) + + +async def enhance_messages_with_web_search( + messages: List[Message], query: str +) -> EnhancedMessages: + ctx = await get_web_search_context(query) + enhanced = [Message(role="system", content=ctx.prompt)] + messages + return EnhancedMessages(messages=enhanced, sources=ctx.sources) + + +async def handle_web_search(req_messages: List[Message]) 
-> EnhancedMessages: + """Handle web search for the given messages. + + Only the last user message is used as the query. + """ + + user_query = "" + for message in reversed(req_messages): + if message.role == "user": + user_query = message.content + break + + if not user_query: + return EnhancedMessages(messages=req_messages, sources=[]) + try: + return await enhance_messages_with_web_search(req_messages, user_query) + except Exception: + return EnhancedMessages(messages=req_messages, sources=[]) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index dbc9fc42..22a47c25 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -5,6 +5,7 @@ from typing import AsyncGenerator, Optional, Union, List, Tuple from nilai_api.attestation import get_attestation_report from nilai_api.handlers.nilrag import handle_nilrag +from nilai_api.handlers.web_search import handle_web_search from fastapi import APIRouter, Body, Depends, HTTPException, status, Request from fastapi.responses import StreamingResponse @@ -23,8 +24,9 @@ Message, ModelMetadata, SignedChatCompletion, - Usage, Nonce, + Source, + Usage, ) from openai import AsyncOpenAI, OpenAI @@ -139,6 +141,7 @@ async def chat_completion( - Must include non-empty list of messages - Must specify a model - Supports multiple message formats (system, user, assistant) + - Optional web_search parameter to enhance context with current information ### Response Components - Model-generated text completion @@ -147,10 +150,18 @@ async def chat_completion( ### Processing Steps 1. Validate input request parameters - 2. Prepare messages for model processing - 3. Generate AI model response - 4. Track and update token usage - 5. Cryptographically sign the response + 2. If web_search is enabled, perform web search and enhance context + 3. Prepare messages for model processing + 4. Generate AI model response + 5. Track and update token usage + 6. 
Cryptographically sign the response + + ### Web Search Feature + When web_search=True, the system will: + - Extract the user's query from the last user message + - Perform a web search using DuckDuckGo API + - Enhance the conversation context with current information + - Add search results as a system message for better responses ### Potential HTTP Errors - **400 Bad Request**: @@ -161,13 +172,13 @@ async def chat_completion( ### Example ```python - # Generate a chat completion + # Generate a chat completion with web search request = ChatRequest( model="meta-llama/Llama-3.2-1B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello, who are you?"} - ] + {"role": "user", "content": "What's the latest news about AI?"} + ], ) response = await chat_completion(request, user) """ @@ -194,6 +205,13 @@ async def chat_completion( if req.nilrag: await handle_nilrag(req) + messages = req.messages + sources: Optional[List[Source]] = None + if req.web_search: + web_search_result = await handle_web_search(messages) + messages = web_search_result.messages + sources = web_search_result.sources + if req.stream: client = AsyncOpenAI(base_url=model_url, api_key="") @@ -202,7 +220,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: try: response = await client.chat.completions.create( model=req.model, - messages=req.messages, # type: ignore + messages=messages, # type: ignore stream=req.stream, # type: ignore top_p=req.top_p, temperature=req.temperature, @@ -255,7 +273,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: client = OpenAI(base_url=model_url, api_key="") response = client.chat.completions.create( model=req.model, - messages=req.messages, # type: ignore + messages=messages, # type: ignore stream=req.stream, top_p=req.top_p, temperature=req.temperature, @@ -266,7 +284,9 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: 
model_response = SignedChatCompletion( **response.model_dump(), signature="", + sources=sources, ) + if model_response.usage is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py index 69ada60b..18bcf74c 100644 --- a/packages/nilai-common/src/nilai_common/__init__.py +++ b/packages/nilai-common/src/nilai_common/__init__.py @@ -10,10 +10,13 @@ Nonce, AMDAttestationToken, NVAttestationToken, + Source, + EnhancedMessages, + WebSearchContext, ) -from openai.types.completion_usage import CompletionUsage as Usage from nilai_common.config import SETTINGS from nilai_common.discovery import ModelServiceDiscovery +from openai.types.completion_usage import CompletionUsage as Usage __all__ = [ "Message", @@ -30,4 +33,7 @@ "AMDAttestationToken", "NVAttestationToken", "SETTINGS", + "Source", + "EnhancedMessages", + "WebSearchContext", ] diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index f092e2d5..7a5e0526 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -1,5 +1,6 @@ import uuid -from typing import Annotated, List, Optional, Literal, Iterable + +from typing import Annotated, Iterable, List, Literal, Optional from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice as OpenaAIChoice @@ -15,6 +16,23 @@ class Choice(OpenaAIChoice): pass +class Source(BaseModel): + source: str + content: str + + +class EnhancedMessages(BaseModel): + messages: List[Message] + sources: List[Source] + + +class WebSearchContext(BaseModel): + """Prompt and sources obtained from a web search.""" + + prompt: str + sources: List[Source] + + class ChatRequest(BaseModel): model: str messages: List[Message] = Field(..., min_length=1) @@ -24,10 +42,17 @@ class 
ChatRequest(BaseModel): stream: Optional[bool] = False tools: Optional[Iterable[ChatCompletionToolParam]] = None nilrag: Optional[dict] = {} + web_search: Optional[bool] = Field( + default=False, + description="Enable web search to enhance context with current information", + ) class SignedChatCompletion(ChatCompletion): signature: str + sources: Optional[List[Source]] = Field( + default=None, description="Sources used for web search when enabled" + ) class ModelMetadata(BaseModel): @@ -75,3 +100,21 @@ class AttestationReport(BaseModel): verifying_key: Annotated[str, Field(description="PEM encoded public key")] cpu_attestation: AMDAttestationToken gpu_attestation: NVAttestationToken + + +__all__ = [ + "Message", + "Choice", + "Source", + "ChatRequest", + "SignedChatCompletion", + "EnhancedMessages", + "ModelMetadata", + "ModelEndpoint", + "HealthCheckResponse", + "Nonce", + "AMDAttestationToken", + "NVAttestationToken", + "AttestationReport", + "WebSearchContext", +] diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index 48ad1884..893dd8bd 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -719,3 +719,34 @@ def test_model_streaming_request_high_token(client): assert chunk_count > 0, ( "Should receive at least one chunk for high token streaming request" ) + + +@pytest.mark.parametrize( + "model", + test_models, +) +def test_web_search_eurovision_2024(client, model): + """Test web_search using a query that requires up-to-date information (Eurovision 2024 winner).""" + response = client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that provides accurate and up-to-date information.", + }, + { + "role": "user", + "content": "Who won the Roland Garros 0pen in 2024?", + }, + ], + extra_body={"web_search": True}, + temperature=0.2, + max_tokens=150, + ) + assert isinstance(response, ChatCompletion) + assert response.model == model + assert len(response.choices) > 
0 + content = response.choices[0].message.content + assert content + keywords = ["carlos", "alcaraz", "iga", "świątek", "swiatek"] + assert any(k in content.lower() for k in keywords) diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index f3138302..4289194d 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -170,12 +170,16 @@ def test_get_models(mock_user, mock_user_manager, mock_state, client): def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, client): mocker.patch("openai.api_key", new="test-api-key") - # Mock the response from the OpenAI API + from openai.types.chat import ChatCompletion data = RESPONSE.model_dump() + data.pop("signature") + data.pop("sources", None) + response_data = ChatCompletion(**data) + mocker.patch( "openai._base_client.SyncAPIClient._request", return_value=response_data ) diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py new file mode 100644 index 00000000..799cf551 --- /dev/null +++ b/tests/unit/nilai_api/test_web_search.py @@ -0,0 +1,156 @@ +import pytest +from unittest.mock import patch, MagicMock +from nilai_api.handlers.web_search import ( + perform_web_search_sync, + get_web_search_context, + enhance_messages_with_web_search, + handle_web_search, +) +from nilai_common import Message +from nilai_common.api_model import WebSearchContext, EnhancedMessages + + +def test_perform_web_search_sync_success(): + """Test successful web search with mock response""" + mock_search_results = [ + { + "title": "Latest AI Developments", + "body": "OpenAI announces GPT-5 with improved capabilities and better performance across various tasks.", + "href": "https://example.com/ai1", + }, + { + "title": "AI Breakthrough in Robotics", + "body": "New neural network architecture improves robot learning efficiency by 40% in recent studies.", + "href": 
"https://example.com/ai2", + }, + ] + + with patch("nilai_api.handlers.web_search.DDGS") as mock_ddgs: + mock_instance = MagicMock() + mock_ddgs.return_value.__enter__.return_value = mock_instance + mock_instance.text.return_value = mock_search_results + mock_instance.news.return_value = [] + + ctx = perform_web_search_sync("AI developments") + + assert len(ctx.sources) == 2 + assert "GPT-5" in ctx.prompt + assert "40%" in ctx.prompt + assert ctx.sources[0].source == "https://example.com/ai1" + assert ctx.sources[1].source == "https://example.com/ai2" + + +def test_perform_web_search_sync_no_results(): + """Test web search with no results""" + with patch("nilai_api.handlers.web_search.DDGS") as mock_ddgs: + mock_instance = MagicMock() + mock_ddgs.return_value.__enter__.return_value = mock_instance + mock_instance.text.return_value = [] + mock_instance.news.return_value = [] + + with pytest.raises(Exception): + _ = perform_web_search_sync("nonexistent query") + + +def test_perform_web_search_sync_fallback_to_news(): + """Test web search fallback to news when text search returns no results""" + mock_news_results = [ + { + "title": "Breaking AI News", + "body": "Major breakthrough in artificial intelligence research announced today.", + "href": "https://example.com/news1", + } + ] + + with patch("nilai_api.handlers.web_search.DDGS") as mock_ddgs: + mock_instance = MagicMock() + mock_ddgs.return_value.__enter__.return_value = mock_instance + mock_instance.text.return_value = [] + mock_instance.news.return_value = mock_news_results + + with pytest.raises(Exception): + _ = perform_web_search_sync("AI news") + + +@pytest.mark.asyncio +async def test_enhance_messages_with_web_search(): + """Test message enhancement with web search results""" + original_messages = [ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="What is the latest AI news?"), + ] + + with patch("nilai_api.handlers.web_search.perform_web_search_sync") as 
mock_search: + mock_search.return_value = WebSearchContext( + prompt="Latest AI Developments: OpenAI announces GPT-5\nAI Breakthrough: New neural network improves efficiency by 40%", + sources=[], + ) + + enhanced = await enhance_messages_with_web_search(original_messages, "AI news") + + assert len(enhanced.messages) == 3 + assert enhanced.messages[0].role == "system" + assert "Latest AI Developments" in str(enhanced.messages[0].content) + assert enhanced.sources == [] + + +@pytest.mark.asyncio +async def test_handle_web_search(): + """Test web search handler with user messages""" + messages = [ + Message(role="user", content="Tell me about current events"), + ] + + with patch( + "nilai_api.handlers.web_search.enhance_messages_with_web_search" + ) as mock_enhance: + mock_enhance.return_value = EnhancedMessages( + messages=[Message(role="system", content="Enhanced context")] + messages, + sources=[], + ) + enhanced = await handle_web_search(messages) + + mock_enhance.assert_called_once_with(messages, "Tell me about current events") + assert len(enhanced.messages) == len(messages) + 1 + assert enhanced.sources == [] + + +@pytest.mark.asyncio +async def test_handle_web_search_no_user_message(): + """Test web search handler with no user message""" + messages = [ + Message(role="assistant", content="Hello! 
How can I help you?"), + ] + + enhanced = await handle_web_search(messages) + + assert enhanced.messages == messages + assert enhanced.sources == [] + + +@pytest.mark.asyncio +async def test_handle_web_search_exception_handling(): + """Test web search handler exception handling""" + messages = [ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="What's the weather like?"), + ] + + with patch( + "nilai_api.handlers.web_search.enhance_messages_with_web_search" + ) as mock_enhance: + mock_enhance.side_effect = Exception("Search service unavailable") + + enhanced = await handle_web_search(messages) + + assert enhanced.messages == messages + assert enhanced.sources == [] + + +@pytest.mark.asyncio +async def test_get_web_search_context_async_wrapper(): + with patch("nilai_api.handlers.web_search.perform_web_search_sync") as mock_sync: + mock_sync.return_value = WebSearchContext(prompt="info", sources=[]) + ctx = await get_web_search_context("query") + assert ctx.prompt == "info" + mock_sync.assert_called_once() From 270c328f158dde713b94adcab7abfb77ffae3f0f Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 22 Jul 2025 11:55:31 +0200 Subject: [PATCH 3/3] fix: update test_web_search_roland_garros_2024 Add retry logic to reduce failed tests due to source quality --- .../src/nilai_common/api_model.py | 18 ----- tests/e2e/test_openai.py | 71 ++++++++++++------- 2 files changed, 46 insertions(+), 43 deletions(-) diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 7a5e0526..d18126dd 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -100,21 +100,3 @@ class AttestationReport(BaseModel): verifying_key: Annotated[str, Field(description="PEM encoded public key")] cpu_attestation: AMDAttestationToken gpu_attestation: NVAttestationToken - - -__all__ = [ - "Message", - "Choice", - 
"Source", - "ChatRequest", - "SignedChatCompletion", - "EnhancedMessages", - "ModelMetadata", - "ModelEndpoint", - "HealthCheckResponse", - "Nonce", - "AMDAttestationToken", - "NVAttestationToken", - "AttestationReport", - "WebSearchContext", -] diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index 893dd8bd..35ff8bdb 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -725,28 +725,49 @@ def test_model_streaming_request_high_token(client): "model", test_models, ) -def test_web_search_eurovision_2024(client, model): - """Test web_search using a query that requires up-to-date information (Eurovision 2024 winner).""" - response = client.chat.completions.create( - model=model, - messages=[ - { - "role": "system", - "content": "You are a helpful assistant that provides accurate and up-to-date information.", - }, - { - "role": "user", - "content": "Who won the Roland Garros 0pen in 2024?", - }, - ], - extra_body={"web_search": True}, - temperature=0.2, - max_tokens=150, - ) - assert isinstance(response, ChatCompletion) - assert response.model == model - assert len(response.choices) > 0 - content = response.choices[0].message.content - assert content - keywords = ["carlos", "alcaraz", "iga", "świątek", "swiatek"] - assert any(k in content.lower() for k in keywords) +def test_web_search_roland_garros_2024(client, model): + """Test web_search using a query that requires up-to-date information (Roland Garros 2024 winner).""" + max_retries = 3 + last_exception = None + + for attempt in range(max_retries): + try: + print(f"\nAttempt {attempt + 1}/{max_retries}...") + + response = client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that provides accurate and up-to-date information.", + }, + { + "role": "user", + "content": "Who won the Roland Garros Open in 2024? 
Just reply with the winner's name.", + }, + ], + extra_body={"web_search": True}, + temperature=0.2, + max_tokens=150, + ) + + assert isinstance(response, ChatCompletion) + assert response.model == model + assert len(response.choices) > 0 + + content = response.choices[0].message.content + assert content, "Response content is empty." + + keywords = ["carlos", "alcaraz", "iga", "świątek", "swiatek"] + assert any(k in content.lower() for k in keywords) + + print(f"Success on attempt {attempt + 1}") + return + except AssertionError as e: + print(f"Assertion failed on attempt {attempt + 1}: {e}") + last_exception = e + if attempt < max_retries - 1: + print("Retrying...") + else: + print("All retries failed.") + raise last_exception