29 changes: 16 additions & 13 deletions README.md
@@ -108,17 +108,19 @@ up -d
docker run -d --name redis \
-p 6379:6379 \
redis:latest

# Start PostgreSQL
docker run -d --name postgres \
-e POSTGRES_USER=${POSTGRES_USER} \
-e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} \
-e POSTGRES_DB=${POSTGRES_DB} \
-p 5432:5432 \
--network frontend_net \
--volume postgres_data:/var/lib/postgresql/data \
postgres:16
```
```

2. **Start PostgreSQL**
```shell
docker run -d --name postgres \
-e POSTGRES_USER=${POSTGRES_USER} \
-e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} \
-e POSTGRES_DB=${POSTGRES_DB} \
-p 5432:5432 \
--network frontend_net \
--volume postgres_data:/var/lib/postgresql/data \
postgres:16
```

3. **Run API Server**
```shell
@@ -191,11 +193,12 @@ To configure vLLM for **local execution on macOS**, execute the following steps:
```shell
# Clone vLLM repository (root folder)
git clone https://github.com/vllm-project/vllm.git
cd vllm
git checkout v0.7.3 # We use v0.7.3
# Build vLLM OpenAI (vllm folder)
docker build -f Dockerfile.arm -t vllm/vllm-openai . --shm-size=4g
# Build nilai attestation container

# Build nilai attestation container (root folder)
docker build -t nillion/nilai-attestation:latest -f docker/attestation.Dockerfile .
# Build vLLM docker container (root folder)
docker build -t nillion/nilai-vllm:latest -f docker/vllm.Dockerfile .
1 change: 1 addition & 0 deletions nilai-api/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
"accelerate>=1.1.1",
"alembic>=1.14.1",
"cryptography>=43.0.1",
"duckduckgo-search>=8.1.1",
"fastapi[standard]>=0.115.5",
"gunicorn>=23.0.0",
"nilai-common",
96 changes: 96 additions & 0 deletions nilai-api/src/nilai_api/handlers/web_search.py
@@ -0,0 +1,96 @@
import asyncio
import logging
from typing import List

from duckduckgo_search import DDGS
from fastapi import HTTPException, status

from nilai_common.api_model import Source
from nilai_common import Message
from nilai_common.api_model import EnhancedMessages, WebSearchContext

logger = logging.getLogger(__name__)


def perform_web_search_sync(query: str) -> WebSearchContext:
"""Synchronously query DuckDuckGo and build a contextual prompt.

The function sends *query* to DuckDuckGo, extracts up to three text results,
formats them into a single prompt, and returns that prompt together with the
metadata (URL and snippet) of every result.
"""
if not query or not query.strip():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Web search requested with an empty query",
)

try:
with DDGS() as ddgs:
raw_results = list(ddgs.text(query, max_results=3, region="us-en"))

if not raw_results:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Web search failed, service currently unavailable",
)

snippets: List[str] = []
sources: List[Source] = []

for result in raw_results:
if result.get("title") and result.get("body"):
title = result["title"]
body = result["body"][:500]
snippets.append(f"{title}: {body}")
sources.append(Source(source=result["href"], content=body))

prompt = (
"You have access to the following current information from web search:\n"
+ "\n".join(snippets)
)

return WebSearchContext(prompt=prompt, sources=sources)

except HTTPException:
raise
except Exception as exc:
logger.error("Error performing web search: %s", exc)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Web search failed, service currently unavailable",
) from exc


async def get_web_search_context(query: str) -> WebSearchContext:
"""Non-blocking wrapper around *perform_web_search_sync*."""
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, perform_web_search_sync, query)


async def enhance_messages_with_web_search(
messages: List[Message], query: str
) -> EnhancedMessages:
ctx = await get_web_search_context(query)
enhanced = [Message(role="system", content=ctx.prompt)] + messages
return EnhancedMessages(messages=enhanced, sources=ctx.sources)


async def handle_web_search(req_messages: List[Message]) -> EnhancedMessages:
"""Handle web search for the given messages.

Only the last user message is used as the query.
"""

user_query = ""
for message in reversed(req_messages):
if message.role == "user":
user_query = message.content
break

if not user_query:
return EnhancedMessages(messages=req_messages, sources=[])
try:
return await enhance_messages_with_web_search(req_messages, user_query)
except Exception:
logger.warning("Web search failed; continuing without web context", exc_info=True)
return EnhancedMessages(messages=req_messages, sources=[])
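
For reviewers who want to exercise the handler in isolation, here is a minimal driver; a sketch assuming the `nilai-api` and `nilai-common` packages are installed and DuckDuckGo is reachable (on any failure, `handle_web_search` falls back to the original messages with empty sources):

```python
# Hypothetical standalone driver; not part of this PR.
import asyncio

from nilai_common import Message
from nilai_api.handlers.web_search import handle_web_search


async def main() -> None:
    messages = [
        Message(role="system", content="You are a helpful assistant."),
        Message(role="user", content="Who won the Roland Garros Open in 2024?"),
    ]
    enhanced = await handle_web_search(messages)
    # The search context is prepended as one extra system message;
    # the original conversation follows it unchanged.
    print(enhanced.messages[0].content)
    for source in enhanced.sources:
        print(source.source)


asyncio.run(main())
```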
40 changes: 30 additions & 10 deletions nilai-api/src/nilai_api/routers/private.py
@@ -5,6 +5,7 @@
from typing import AsyncGenerator, Optional, Union, List, Tuple
from nilai_api.attestation import get_attestation_report
from nilai_api.handlers.nilrag import handle_nilrag
from nilai_api.handlers.web_search import handle_web_search

from fastapi import APIRouter, Body, Depends, HTTPException, status, Request
from fastapi.responses import StreamingResponse
@@ -23,8 +24,9 @@
Message,
ModelMetadata,
SignedChatCompletion,
Usage,
Nonce,
Source,
Usage,
)
from openai import AsyncOpenAI, OpenAI

@@ -139,6 +141,7 @@ async def chat_completion(
- Must include non-empty list of messages
- Must specify a model
- Supports multiple message formats (system, user, assistant)
- Optional web_search parameter to enhance context with current information

### Response Components
- Model-generated text completion
@@ -147,10 +150,18 @@

### Processing Steps
1. Validate input request parameters
2. Prepare messages for model processing
3. Generate AI model response
4. Track and update token usage
5. Cryptographically sign the response
2. If web_search is enabled, perform web search and enhance context
3. Prepare messages for model processing
4. Generate AI model response
5. Track and update token usage
6. Cryptographically sign the response

### Web Search Feature
When web_search=True, the system will:
- Extract the user's query from the last user message
- Perform a web search using DuckDuckGo API
- Enhance the conversation context with current information
- Add search results as a system message for better responses

### Potential HTTP Errors
- **400 Bad Request**:
@@ -161,13 +172,13 @@

### Example
```python
# Generate a chat completion
# Generate a chat completion with web search
request = ChatRequest(
model="meta-llama/Llama-3.2-1B-Instruct",
messages=[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello, who are you?"}
]
{"role": "user", "content": "What's the latest news about AI?"}
],
web_search=True,
)
response = await chat_completion(request, user)
"""
@@ -194,6 +205,13 @@ async def chat_completion(
if req.nilrag:
await handle_nilrag(req)

messages = req.messages
sources: Optional[List[Source]] = None
if req.web_search:
web_search_result = await handle_web_search(messages)
messages = web_search_result.messages
sources = web_search_result.sources

if req.stream:
client = AsyncOpenAI(base_url=model_url, api_key="<not-needed>")

@@ -202,7 +220,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
try:
response = await client.chat.completions.create(
model=req.model,
messages=req.messages, # type: ignore
messages=messages, # type: ignore
stream=req.stream, # type: ignore
top_p=req.top_p,
temperature=req.temperature,
@@ -255,7 +273,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
client = OpenAI(base_url=model_url, api_key="<not-needed>")
response = client.chat.completions.create(
model=req.model,
messages=req.messages, # type: ignore
messages=messages, # type: ignore
stream=req.stream,
top_p=req.top_p,
temperature=req.temperature,
@@ -266,7 +284,9 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
model_response = SignedChatCompletion(
**response.model_dump(),
signature="",
sources=sources,
)

if model_response.usage is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
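Client-side, the flag travels in the request body, and the signed completion carries the sources back at the top level. A sketch using plain HTTP, where the base URL, route, and bearer token are placeholders for an actual deployment:

```python
# Hypothetical client call; URL, route, and API key are placeholders.
import httpx

resp = httpx.post(
    "http://localhost:8080/v1/chat/completions",
    headers={"Authorization": "Bearer <api-key>"},
    json={
        "model": "meta-llama/Llama-3.2-1B-Instruct",
        "messages": [
            {"role": "user", "content": "What's the latest news about AI?"}
        ],
        "web_search": True,
    },
    timeout=60.0,
)
resp.raise_for_status()
body = resp.json()
print(body["choices"][0]["message"]["content"])
for source in body.get("sources") or []:  # None when web_search is off
    print(source["source"], "->", source["content"][:80])
```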
8 changes: 7 additions & 1 deletion packages/nilai-common/src/nilai_common/__init__.py
@@ -10,10 +10,13 @@
Nonce,
AMDAttestationToken,
NVAttestationToken,
Source,
EnhancedMessages,
WebSearchContext,
)
from openai.types.completion_usage import CompletionUsage as Usage
from nilai_common.config import SETTINGS
from nilai_common.discovery import ModelServiceDiscovery
from openai.types.completion_usage import CompletionUsage as Usage

__all__ = [
"Message",
@@ -30,4 +33,7 @@
"AMDAttestationToken",
"NVAttestationToken",
"SETTINGS",
"Source",
"EnhancedMessages",
"WebSearchContext",
]
27 changes: 26 additions & 1 deletion packages/nilai-common/src/nilai_common/api_model.py
@@ -1,5 +1,6 @@
import uuid
from typing import Annotated, List, Optional, Literal, Iterable

from typing import Annotated, Iterable, List, Literal, Optional

from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice as OpenaAIChoice
@@ -15,6 +16,23 @@ class Choice(OpenaAIChoice):
pass


class Source(BaseModel):
source: str
content: str


class EnhancedMessages(BaseModel):
messages: List[Message]
sources: List[Source]


class WebSearchContext(BaseModel):
"""Prompt and sources obtained from a web search."""

prompt: str
sources: List[Source]


class ChatRequest(BaseModel):
model: str
messages: List[Message] = Field(..., min_length=1)
@@ -24,10 +42,17 @@ class ChatRequest(BaseModel):
stream: Optional[bool] = False
tools: Optional[Iterable[ChatCompletionToolParam]] = None
nilrag: Optional[dict] = {}
web_search: Optional[bool] = Field(
default=False,
description="Enable web search to enhance context with current information",
)


class SignedChatCompletion(ChatCompletion):
signature: str
sources: Optional[List[Source]] = Field(
default=None, description="Sources used for web search when enabled"
)


class ModelMetadata(BaseModel):
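The new models are plain pydantic, so the request/response contract can be checked without a running server. A small illustrative sketch, assuming `Message` coerces dicts the way pydantic models normally do:

```python
# Illustrative only; mirrors the definitions above.
from nilai_common.api_model import ChatRequest, EnhancedMessages, Source

req = ChatRequest(
    model="meta-llama/Llama-3.2-1B-Instruct",
    messages=[{"role": "user", "content": "What's new in AI?"}],
    web_search=True,
)
assert req.web_search is True  # defaults to False when omitted

enhanced = EnhancedMessages(
    messages=req.messages,
    sources=[Source(source="https://example.com", content="snippet")],
)
assert enhanced.sources[0].source == "https://example.com"
```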
52 changes: 52 additions & 0 deletions tests/e2e/test_openai.py
@@ -719,3 +719,55 @@ def test_model_streaming_request_high_token(client):
assert chunk_count > 0, (
"Should receive at least one chunk for high token streaming request"
)


@pytest.mark.parametrize(
"model",
test_models,
)
def test_web_search_roland_garros_2024(client, model):
"""Test web_search using a query that requires up-to-date information (Roland Garros 2024 winner)."""
max_retries = 3
last_exception = None

for attempt in range(max_retries):
try:
print(f"\nAttempt {attempt + 1}/{max_retries}...")

response = client.chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": "You are a helpful assistant that provides accurate and up-to-date information.",
},
{
"role": "user",
"content": "Who won the Roland Garros Open in 2024? Just reply with the winner's name.",
},
],
extra_body={"web_search": True},
temperature=0.2,
max_tokens=150,
)

assert isinstance(response, ChatCompletion)
assert response.model == model
assert len(response.choices) > 0

content = response.choices[0].message.content
assert content, "Response content is empty."

keywords = ["carlos", "alcaraz", "iga", "świątek", "swiatek"]
assert any(k in content.lower() for k in keywords)

print(f"Success on attempt {attempt + 1}")
return
except AssertionError as e:
print(f"Assertion failed on attempt {attempt + 1}: {e}")
last_exception = e
if attempt < max_retries - 1:
print("Retrying...")
else:
print("All retries failed.")
raise last_exception