Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nilai-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies = [
"accelerate>=1.1.1",
"alembic>=1.14.1",
"cryptography>=43.0.1",
"duckduckgo-search>=8.1.1",
"ddgs>=9.4.1",
"fastapi[standard]>=0.115.5",
"gunicorn>=23.0.0",
"nilai-common",
Expand Down
94 changes: 78 additions & 16 deletions nilai-api/src/nilai_api/handlers/web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
import logging
from typing import List

from duckduckgo_search import DDGS
from ddgs import DDGS
from fastapi import HTTPException, status

from nilai_common.api_model import Source
from nilai_common import Message
from nilai_common.api_model import EnhancedMessages, WebSearchContext
from nilai_common.api_model import WebSearchEnhancedMessages, WebSearchContext

logger = logging.getLogger(__name__)


def perform_web_search_sync(query: str) -> WebSearchContext:
"""Synchronously query DuckDuckGo and build a contextual prompt.
"""Synchronously query Brave and build a contextual prompt.

The function sends *query* to DuckDuckGo, extracts the first three text results,
The function sends *query* to Brave, extracts the first three text results,
formats them in a single prompt, and returns that prompt together with the
metadata (URL and snippet) of every result.
"""
Expand All @@ -27,7 +27,9 @@ def perform_web_search_sync(query: str) -> WebSearchContext:

try:
with DDGS() as ddgs:
raw_results = list(ddgs.text(query, max_results=3, region="us-en"))
raw_results = list(
ddgs.text(query=query, max_results=3, region="us-en", backend="brave")
)

if not raw_results:
raise HTTPException(
Expand All @@ -43,7 +45,11 @@ def perform_web_search_sync(query: str) -> WebSearchContext:
title = result["title"]
body = result["body"][:500]
snippets.append(f"{title}: {body}")
sources.append(Source(source=result["href"], content=body))
sources.append(
Source(
source=result.get("href", result.get("url", "")), content=body
)
)

prompt = (
"You have access to the following current information from web search:\n"
Expand All @@ -70,27 +76,83 @@ async def get_web_search_context(query: str) -> WebSearchContext:

async def enhance_messages_with_web_search(
messages: List[Message], query: str
) -> EnhancedMessages:
) -> WebSearchEnhancedMessages:
"""Enhance a list of messages with web search context.

Args:
messages: List of conversation messages to enhance
query: Search query to retrieve web search results for

Returns:
WebSearchEnhancedMessages containing the original messages with web search
context prepended as a system message, along with source information
"""
ctx = await get_web_search_context(query)
enhanced = [Message(role="system", content=ctx.prompt)] + messages
return EnhancedMessages(messages=enhanced, sources=ctx.sources)
return WebSearchEnhancedMessages(messages=enhanced, sources=ctx.sources)


async def handle_web_search(req_messages: List[Message]) -> EnhancedMessages:
"""Handle web search for the given messages.
async def generate_search_query_from_llm(
user_message: str, model_name: str, client
) -> str:
"""Generate a web search query from a user message using an LLM.

Only the last user message is used as the query.
"""
Args:
user_message: The user's input message to convert into a search query
model_name: The name of the LLM model to use for query generation
client: The LLM client instance for making API calls

Returns:
A concise web search query string optimized for information retrieval

Raises:
Exception: If the LLM API call fails or returns an invalid response
"""
system_prompt = """You are given a user question. Generate a concise web search query that would help retrieve information to answer the question. If you cannot improve the user's question, simply repeat it as the search query. Do not answer the query. The query must be at least 10 words long. Output only the search query.\n\nExample:\nUser: Who won the Roland Garros Open in 2024? Just reply with the winner's name.\nSearch query: Roland Garros 2024 winner"""
messages = [
Message(role="system", content=system_prompt),
Message(role="user", content=user_message),
]
req = {
"model": model_name,
"messages": [m.model_dump() for m in messages],
"max_tokens": 150,
}
response = await client.chat.completions.create(**req)
logger.info(
f"For {[m.model_dump() for m in messages]}, Generated search query: {response.choices[0].message.content.strip()}"
)
return response.choices[0].message.content.strip()


async def handle_web_search(
req_messages: List[Message], model_name: str, client
) -> WebSearchEnhancedMessages:
"""Handle web search enhancement for a conversation.

Extracts the most recent user message, generates an optimized search query
using an LLM, and enhances the conversation with web search results.

Args:
req_messages: List of conversation messages to process
model_name: Name of the LLM model to use for query generation
client: LLM client instance for making API calls

Returns:
WebSearchEnhancedMessages with web search context added, or original
messages if no user query is found or search fails
"""
user_query = ""
for message in reversed(req_messages):
if message.role == "user":
user_query = message.content
break

if not user_query:
return EnhancedMessages(messages=req_messages, sources=[])
return WebSearchEnhancedMessages(messages=req_messages, sources=[])
try:
return await enhance_messages_with_web_search(req_messages, user_query)
concise_query = await generate_search_query_from_llm(
user_query, model_name, client
)
return await enhance_messages_with_web_search(req_messages, concise_query)
except Exception:
return EnhancedMessages(messages=req_messages, sources=[])
return WebSearchEnhancedMessages(messages=req_messages, sources=[])
13 changes: 6 additions & 7 deletions nilai-api/src/nilai_api/routers/private.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
Source,
Usage,
)
from openai import AsyncOpenAI, OpenAI
from openai import AsyncOpenAI

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -159,7 +159,7 @@ async def chat_completion(
### Web Search Feature
When web_search=True, the system will:
- Extract the user's query from the last user message
- Perform a web search using DuckDuckGo API
- Perform a web search using Brave API
- Enhance the conversation context with current information
- Add search results as a system message for better responses

Expand Down Expand Up @@ -202,19 +202,19 @@ async def chat_completion(
f"Chat completion request for model {model_name} from user {auth_info.user.userid} on url: {model_url}"
)

client = AsyncOpenAI(base_url=model_url, api_key="<not-needed>")

if req.nilrag:
await handle_nilrag(req)

messages = req.messages
sources: Optional[List[Source]] = None
if req.web_search:
web_search_result = await handle_web_search(messages)
web_search_result = await handle_web_search(messages, model_name, client)
messages = web_search_result.messages
sources = web_search_result.sources

if req.stream:
client = AsyncOpenAI(base_url=model_url, api_key="<not-needed>")

# Forwarding Streamed Responses
async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
try:
Expand Down Expand Up @@ -270,8 +270,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
chat_completion_stream_generator(),
media_type="text/event-stream", # Ensure client interprets as Server-Sent Events
)
client = OpenAI(base_url=model_url, api_key="<not-needed>")
response = client.chat.completions.create(
response = await client.chat.completions.create(
model=req.model,
messages=messages, # type: ignore
stream=req.stream,
Expand Down
4 changes: 2 additions & 2 deletions packages/nilai-common/src/nilai_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
AMDAttestationToken,
NVAttestationToken,
Source,
EnhancedMessages,
WebSearchEnhancedMessages,
WebSearchContext,
)
from nilai_common.config import SETTINGS
Expand All @@ -34,6 +34,6 @@
"NVAttestationToken",
"SETTINGS",
"Source",
"EnhancedMessages",
"WebSearchEnhancedMessages",
"WebSearchContext",
]
2 changes: 1 addition & 1 deletion packages/nilai-common/src/nilai_common/api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Source(BaseModel):
content: str


class EnhancedMessages(BaseModel):
class WebSearchEnhancedMessages(BaseModel):
messages: List[Message]
sources: List[Source]

Expand Down
15 changes: 7 additions & 8 deletions tests/e2e/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,9 +725,9 @@ def test_model_streaming_request_high_token(client):
"model",
test_models,
)
def test_web_search_roland_garros_2024(client, model):
"""Test web_search using a query that requires up-to-date information (Roland Garros 2024 winner)."""
max_retries = 3
def test_web_search(client, model):
"""Test web_search checking that the sources field is not None."""
max_retries = 10
last_exception = None

for attempt in range(max_retries):
Expand Down Expand Up @@ -755,11 +755,10 @@ def test_web_search_roland_garros_2024(client, model):
assert response.model == model
assert len(response.choices) > 0

content = response.choices[0].message.content
assert content, "Response content is empty."

keywords = ["carlos", "alcaraz", "iga", "świątek", "swiatek"]
assert any(k in content.lower() for k in keywords)
sources = getattr(response, "sources", None)
assert sources is not None, "Sources field should not be None"
assert isinstance(sources, list), "Sources should be a list"
assert len(sources) > 0, "Sources should not be empty"

print(f"Success on attempt {attempt + 1}")
return
Expand Down
17 changes: 8 additions & 9 deletions tests/unit/nilai_api/routers/test_private.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,22 @@ def test_get_models(mock_user, mock_user_manager, mock_state, client):

def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, client):
mocker.patch("openai.api_key", new="test-api-key")

from openai.types.chat import ChatCompletion

data = RESPONSE.model_dump()

data.pop("signature")
data.pop("sources", None)

response_data = ChatCompletion(**data)

# Patch nilai_api.routers.private.AsyncOpenAI to return a mock instance with chat.completions.create as an AsyncMock
mock_chat_completions = MagicMock()
mock_chat_completions.create = mocker.AsyncMock(return_value=response_data)
mock_chat = MagicMock()
mock_chat.completions = mock_chat_completions
mock_async_openai_instance = MagicMock()
mock_async_openai_instance.chat = mock_chat
mocker.patch(
"openai._base_client.SyncAPIClient._request", return_value=response_data
"nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance
)

# Mock client.post behavior
response = client.post(
"/v1/chat/completions",
json={
Expand All @@ -196,8 +197,6 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien
},
headers={"Authorization": "Bearer test-api-key"},
)

# Assertions
assert response.status_code == 200
assert "usage" in response.json()
assert response.json()["usage"] == {
Expand Down
46 changes: 28 additions & 18 deletions tests/unit/nilai_api/test_web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
handle_web_search,
)
from nilai_common import Message
from nilai_common.api_model import WebSearchContext, EnhancedMessages
from nilai_common.api_model import WebSearchContext, WebSearchEnhancedMessages


def test_perform_web_search_sync_success():
Expand Down Expand Up @@ -100,16 +100,22 @@ async def test_handle_web_search():
messages = [
Message(role="user", content="Tell me about current events"),
]

with patch(
"nilai_api.handlers.web_search.enhance_messages_with_web_search"
) as mock_enhance:
mock_enhance.return_value = EnhancedMessages(
with (
patch(
"nilai_api.handlers.web_search.enhance_messages_with_web_search"
) as mock_enhance,
patch(
"nilai_api.handlers.web_search.generate_search_query_from_llm"
) as mock_generate_query,
):
mock_enhance.return_value = WebSearchEnhancedMessages(
messages=[Message(role="system", content="Enhanced context")] + messages,
sources=[],
)
enhanced = await handle_web_search(messages)

mock_generate_query.return_value = "Tell me about current events"
dummy_client = MagicMock()
enhanced = await handle_web_search(messages, "dummy-model", dummy_client)
mock_generate_query.assert_called_once()
mock_enhance.assert_called_once_with(messages, "Tell me about current events")
assert len(enhanced.messages) == len(messages) + 1
assert enhanced.sources == []
Expand All @@ -121,9 +127,8 @@ async def test_handle_web_search_no_user_message():
messages = [
Message(role="assistant", content="Hello! How can I help you?"),
]

enhanced = await handle_web_search(messages)

dummy_client = MagicMock()
enhanced = await handle_web_search(messages, "dummy-model", dummy_client)
assert enhanced.messages == messages
assert enhanced.sources == []

Expand All @@ -135,14 +140,19 @@ async def test_handle_web_search_exception_handling():
Message(role="system", content="You are a helpful assistant"),
Message(role="user", content="What's the weather like?"),
]

with patch(
"nilai_api.handlers.web_search.enhance_messages_with_web_search"
) as mock_enhance:
with (
patch(
"nilai_api.handlers.web_search.enhance_messages_with_web_search"
) as mock_enhance,
patch(
"nilai_api.handlers.web_search.generate_search_query_from_llm"
) as mock_generate_query,
):
mock_enhance.side_effect = Exception("Search service unavailable")

enhanced = await handle_web_search(messages)

mock_generate_query.return_value = "What's the weather like?"
dummy_client = MagicMock()
enhanced = await handle_web_search(messages, "dummy-model", dummy_client)
mock_generate_query.assert_called_once()
assert enhanced.messages == messages
assert enhanced.sources == []

Expand Down
Loading