Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 38 additions & 28 deletions nilai-api/src/nilai_api/routers/private.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Fast API and serving
import asyncio
import json
import logging
import time
import uuid
Expand Down Expand Up @@ -306,45 +306,47 @@ async def chat_completion(
logger.info(f"[chat] web_search messages: {messages}")

if req.stream:
# Forwarding Streamed Responses

async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
t_call = time.monotonic()
prompt_token_usage = 0
completion_token_usage = 0

try:
logger.info(f"[chat] stream start request_id={request_id}")
t_call = time.monotonic()
current_messages = messages

request_kwargs = {
"model": req.model,
"messages": current_messages, # type: ignore
"stream": True, # type: ignore
"messages": messages,
"stream": True,
"top_p": req.top_p,
"temperature": req.temperature,
"max_tokens": req.max_tokens,
"extra_body": {
"stream_options": {
"include_usage": True,
"continuous_usage_stats": True,
"continuous_usage_stats": False,
}
},
}
if req.tools:
request_kwargs["tools"] = req.tools # type: ignore
request_kwargs["tools"] = req.tools

response = await client.chat.completions.create(**request_kwargs)

response = await client.chat.completions.create(**request_kwargs) # type: ignore
prompt_token_usage: int = 0
completion_token_usage: int = 0
async for chunk in response:
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
await asyncio.sleep(0)

prompt_token_usage = (
chunk.usage.prompt_tokens if chunk.usage else prompt_token_usage
)
completion_token_usage = (
chunk.usage.completion_tokens
if chunk.usage
else completion_token_usage
)
if chunk.usage is not None:
prompt_token_usage = chunk.usage.prompt_tokens
completion_token_usage = chunk.usage.completion_tokens

payload = chunk.model_dump(exclude_unset=True)

if chunk.usage is not None and sources:
payload["sources"] = [
s.model_dump(mode="json") for s in sources
]

yield f"data: {json.dumps(payload)}\n\n"

await UserManager.update_token_usage(
auth_info.user.userid,
Expand All @@ -359,18 +361,26 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
web_search_calls=len(sources) if sources else 0,
)
logger.info(
f"[chat] stream done request_id={request_id} prompt_tokens={prompt_token_usage} completion_tokens={completion_token_usage} duration_ms={(time.monotonic() - t_call) * 1000:.0f} total_ms={(time.monotonic() - t_start) * 1000:.0f}"
"[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d "
"duration_ms=%.0f total_ms=%.0f",
request_id,
prompt_token_usage,
completion_token_usage,
(time.monotonic() - t_call) * 1000,
(time.monotonic() - t_start) * 1000,
)

except Exception as e:
logger.error(f"[chat] stream error request_id={request_id} error={e}")
return
logger.error(
"[chat] stream error request_id=%s error=%s", request_id, e
)
yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n"

# Return the streaming response
return StreamingResponse(
chat_completion_stream_generator(),
media_type="text/event-stream", # Ensure client interprets as Server-Sent Events
media_type="text/event-stream",
)

current_messages = messages
request_kwargs = {
"model": req.model,
Expand Down
4 changes: 1 addition & 3 deletions tests/e2e/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,6 @@ def test_streaming_chat_completion(client, model):
if chunk.usage:
had_usage = True
print(f"Model {model} usage: {chunk.usage}")

# Limit processing to avoid long tests
if chunk_count >= 20:
break
assert had_usage, f"No usage data received for {model} streaming request"
assert chunk_count > 0, f"No chunks received for {model} streaming request"
Expand Down Expand Up @@ -459,6 +456,7 @@ def test_function_calling_with_streaming(client, model):
if chunk.usage:
had_usage = True
print(f"Model {model} usage: {chunk.usage}")
break

assert had_tool_call, f"No tool calls received for {model} streaming request"
assert had_usage, f"No usage data received for {model} streaming request"
Expand Down
106 changes: 105 additions & 1 deletion tests/unit/nilai_api/routers/test_private.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi.testclient import TestClient

from nilai_api.db.users import RateLimits, UserModel
from nilai_common import AttestationReport
from nilai_common import AttestationReport, Source

from nilai_api.state import state
from ... import model_endpoint, model_metadata, response as RESPONSE
Expand Down Expand Up @@ -221,3 +222,106 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien
"completion_tokens_details": None,
"prompt_tokens_details": None,
}


def test_chat_completion_stream_includes_sources(
mock_user, mock_state, mock_user_manager, mocker, client
):
source = Source(source="https://example.com", content="Example result")

mock_web_search_result = MagicMock()
mock_web_search_result.messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me something new."},
]
mock_web_search_result.sources = [source]

mocker.patch(
"nilai_api.routers.private.handle_web_search",
new=AsyncMock(return_value=mock_web_search_result),
)

class MockChunk:
def __init__(self, data, usage=None):
self._data = data
self.usage = usage

def model_dump(self, exclude_unset=True):
return self._data

class MockUsage:
def __init__(self, prompt_tokens: int, completion_tokens: int):
self.prompt_tokens = prompt_tokens
self.completion_tokens = completion_tokens

first_chunk = MockChunk(
data={
"id": "stream-1",
"object": "chat.completion.chunk",
"model": "meta-llama/Llama-3.2-1B-Instruct",
"created": 0,
"choices": [{"delta": {"content": "Hello"}, "index": 0}],
}
)

final_chunk = MockChunk(
data={
"id": "stream-1",
"object": "chat.completion.chunk",
"model": "meta-llama/Llama-3.2-1B-Instruct",
"created": 0,
"choices": [
{"delta": {}, "finish_reason": "stop", "index": 0},
],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12,
},
},
usage=MockUsage(prompt_tokens=5, completion_tokens=7),
)

async def chunk_generator():
yield first_chunk
yield final_chunk

mock_chat_completions = MagicMock()
mock_chat_completions.create = AsyncMock(return_value=chunk_generator())
mock_chat = MagicMock()
mock_chat.completions = mock_chat_completions
mock_async_openai_instance = MagicMock()
mock_async_openai_instance.chat = mock_chat

mocker.patch(
"nilai_api.routers.private.AsyncOpenAI",
return_value=mock_async_openai_instance,
)

payload = {
"model": "meta-llama/Llama-3.2-1B-Instruct",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me something new."},
],
"stream": True,
"web_search": True,
}

headers = {"Authorization": "Bearer test-api-key"}

with client.stream(
"POST", "/v1/chat/completions", json=payload, headers=headers
) as response:
assert response.status_code == 200
data_lines = [
line for line in response.iter_lines() if line and line.startswith("data: ")
]

assert data_lines, "Expected SSE data from stream response"
first_payload = json.loads(data_lines[0][len("data: ") :])
assert "sources" not in first_payload
final_payload = json.loads(data_lines[-1][len("data: ") :])
assert "sources" in final_payload
assert len(final_payload["sources"]) == 1
assert final_payload["sources"][0]["source"] == "https://example.com"
Loading