diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 70245829..a3b598d9 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -1,5 +1,5 @@ # Fast API and serving -import asyncio +import json import logging import time import uuid @@ -306,45 +306,47 @@ async def chat_completion( logger.info(f"[chat] web_search messages: {messages}") if req.stream: - # Forwarding Streamed Responses + async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: + t_call = time.monotonic() + prompt_token_usage = 0 + completion_token_usage = 0 + try: logger.info(f"[chat] stream start request_id={request_id}") - t_call = time.monotonic() - current_messages = messages + request_kwargs = { "model": req.model, - "messages": current_messages, # type: ignore - "stream": True, # type: ignore + "messages": messages, + "stream": True, "top_p": req.top_p, "temperature": req.temperature, "max_tokens": req.max_tokens, "extra_body": { "stream_options": { "include_usage": True, - "continuous_usage_stats": True, + "continuous_usage_stats": False, } }, } if req.tools: - request_kwargs["tools"] = req.tools # type: ignore + request_kwargs["tools"] = req.tools + + response = await client.chat.completions.create(**request_kwargs) - response = await client.chat.completions.create(**request_kwargs) # type: ignore - prompt_token_usage: int = 0 - completion_token_usage: int = 0 async for chunk in response: - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - await asyncio.sleep(0) - - prompt_token_usage = ( - chunk.usage.prompt_tokens if chunk.usage else prompt_token_usage - ) - completion_token_usage = ( - chunk.usage.completion_tokens - if chunk.usage - else completion_token_usage - ) + if chunk.usage is not None: + prompt_token_usage = chunk.usage.prompt_tokens + completion_token_usage = chunk.usage.completion_tokens + + payload = chunk.model_dump(exclude_unset=True) + + if chunk.usage is not None and sources: + payload["sources"] = [ + s.model_dump(mode="json") for s in sources + ] + + yield f"data: {json.dumps(payload)}\n\n" await UserManager.update_token_usage( auth_info.user.userid, @@ -359,18 +361,26 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: web_search_calls=len(sources) if sources else 0, ) logger.info( - f"[chat] stream done request_id={request_id} prompt_tokens={prompt_token_usage} completion_tokens={completion_token_usage} duration_ms={(time.monotonic() - t_call) * 1000:.0f} total_ms={(time.monotonic() - t_start) * 1000:.0f}" + "[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d " + "duration_ms=%.0f total_ms=%.0f", + request_id, + prompt_token_usage, + completion_token_usage, + (time.monotonic() - t_call) * 1000, + (time.monotonic() - t_start) * 1000, ) except Exception as e: - logger.error(f"[chat] stream error request_id={request_id} error={e}") - return + logger.error( + "[chat] stream error request_id=%s error=%s", request_id, e + ) + yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" - # Return the streaming response return StreamingResponse( chat_completion_stream_generator(), - media_type="text/event-stream", # Ensure client interprets as Server-Sent Events + media_type="text/event-stream", ) + current_messages = messages request_kwargs = { "model": req.model, diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index bce60198..987365d4 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -251,9 +251,6 @@ def test_streaming_chat_completion(client, model): if chunk.usage: had_usage = True print(f"Model {model} usage: {chunk.usage}") - - # Limit processing to avoid long tests - if chunk_count >= 20: break assert had_usage, f"No usage data received for {model} streaming request" assert chunk_count > 0, f"No chunks received for {model} streaming request" @@ -459,6 +456,7 @@ def test_function_calling_with_streaming(client, model): if chunk.usage: had_usage = True print(f"Model {model} usage: {chunk.usage}") + break assert had_tool_call, f"No tool calls received for {model} streaming request" assert had_usage, f"No usage data received for {model} streaming request" diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index ca71ca22..13d9d781 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -1,11 +1,12 @@ import asyncio +import json from unittest.mock import AsyncMock, MagicMock import pytest from fastapi.testclient import TestClient from nilai_api.db.users import RateLimits, UserModel -from nilai_common import AttestationReport +from nilai_common import AttestationReport, Source from nilai_api.state import state from ... import model_endpoint, model_metadata, response as RESPONSE @@ -221,3 +222,106 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien "completion_tokens_details": None, "prompt_tokens_details": None, } + + +def test_chat_completion_stream_includes_sources( + mock_user, mock_state, mock_user_manager, mocker, client +): + source = Source(source="https://example.com", content="Example result") + + mock_web_search_result = MagicMock() + mock_web_search_result.messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me something new."}, + ] + mock_web_search_result.sources = [source] + + mocker.patch( + "nilai_api.routers.private.handle_web_search", + new=AsyncMock(return_value=mock_web_search_result), + ) + + class MockChunk: + def __init__(self, data, usage=None): + self._data = data + self.usage = usage + + def model_dump(self, exclude_unset=True): + return self._data + + class MockUsage: + def __init__(self, prompt_tokens: int, completion_tokens: int): + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + + first_chunk = MockChunk( + data={ + "id": "stream-1", + "object": "chat.completion.chunk", + "model": "meta-llama/Llama-3.2-1B-Instruct", + "created": 0, + "choices": [{"delta": {"content": "Hello"}, "index": 0}], + } + ) + + final_chunk = MockChunk( + data={ + "id": "stream-1", + "object": "chat.completion.chunk", + "model": "meta-llama/Llama-3.2-1B-Instruct", + "created": 0, + "choices": [ + {"delta": {}, "finish_reason": "stop", "index": 0}, + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12, + }, + }, + usage=MockUsage(prompt_tokens=5, completion_tokens=7), + ) + + async def chunk_generator(): + yield first_chunk + yield final_chunk + + mock_chat_completions = MagicMock() + mock_chat_completions.create = AsyncMock(return_value=chunk_generator()) + mock_chat = MagicMock() + mock_chat.completions = mock_chat_completions + mock_async_openai_instance = MagicMock() + mock_async_openai_instance.chat = mock_chat + + mocker.patch( + "nilai_api.routers.private.AsyncOpenAI", + return_value=mock_async_openai_instance, + ) + + payload = { + "model": "meta-llama/Llama-3.2-1B-Instruct", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me something new."}, + ], + "stream": True, + "web_search": True, + } + + headers = {"Authorization": "Bearer test-api-key"} + + with client.stream( + "POST", "/v1/chat/completions", json=payload, headers=headers + ) as response: + assert response.status_code == 200 + data_lines = [ + line for line in response.iter_lines() if line and line.startswith("data: ") + ] + + assert data_lines, "Expected SSE data from stream response" + first_payload = json.loads(data_lines[0][len("data: ") :]) + assert "sources" not in first_payload + final_payload = json.loads(data_lines[-1][len("data: ") :]) + assert "sources" in final_payload + assert len(final_payload["sources"]) == 1 + assert final_payload["sources"][0]["source"] == "https://example.com"