diff --git a/chat_api/chats/chats_services.py b/chat_api/chats/chats_services.py index c0c3e1e..2ab408a 100644 --- a/chat_api/chats/chats_services.py +++ b/chat_api/chats/chats_services.py @@ -33,6 +33,7 @@ def merge_token_items(chat_list: list) -> list: if done_item: merged_data.append(done_item) + print("merge data >>>>>>>>>>>>>", merged_data) return merged_data @@ -46,6 +47,12 @@ async def get_chat_stream(token: str, chat_request: ChatRequest): async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, read=120.0)) as client: async with client.stream("POST", url, json=chat_request_payload) as response: + if chat_request.thread_id is None: + thread_request = ThreadCreateRequest(email=chat_request.email, device_type=chat_request.device_type, application_name=chat_request.application) + thread = create_thread(thread_request=thread_request) + yield ( + f"data: {json.dumps({'thread_id': str(thread.id)})}\n\n" + ).encode("utf-8") # Stream the response chunks async for line in response.aiter_lines(): frame = sse_frame_from_line(line, on_json=chat_list.append) @@ -53,6 +60,7 @@ async def get_chat_stream(token: str, chat_request: ChatRequest): yield frame if len(chat_list) > 0: + if chat_request.thread_id is None: thread_request = ThreadCreateRequest(email=email, device_type=chat_request.device_type, application_name=chat_request.application) thread = create_thread(thread_request=thread_request) @@ -63,9 +71,6 @@ async def get_chat_stream(token: str, chat_request: ChatRequest): response_payload = ChatResponsePayload(thread_id=thread_id, response=merged_chat_list, question=chat_request.query) save_chat(db_session, response_payload=response_payload) - yield ( - f"data: {json.dumps({'thread_id': str(thread_id)})}\n\n" - ).encode("utf-8") def sse_frame_from_line( line: str, diff --git a/tests/chats/test_chats_services.py b/tests/chats/test_chats_services.py index 7a7c90e..348ead2 100644 --- a/tests/chats/test_chats_services.py +++ b/tests/chats/test_chats_services.py @@ -30,6 +30,14 @@ def test_sse_frame_from_line_data_line_json() -> None: assert collected == [{"x": 1}] +def test_sse_frame_from_line_without_data_prefix() -> None: + """Test that lines without 'data:' prefix are still parsed as JSON.""" + collected = [] + frame = sse_frame_from_line('{"type": "token", "data": "test"}', on_json=collected.append) + assert frame == b'data: {"type": "token", "data": "test"}\n\n' + assert collected == [{"type": "token", "data": "test"}] + + def test_merge_token_items_merges_all_tokens() -> None: chat_list = [ {"type": "token", "data": "Hello"}, @@ -71,6 +79,23 @@ def test_merge_token_items_handles_no_tokens() -> None: assert result[1] == {"type": "done", "data": {}} +def test_merge_token_items_done_item_always_last() -> None: + """Test that done item is always placed at the end after merged tokens.""" + chat_list = [ + {"type": "token", "data": "First"}, + {"type": "done", "data": {}}, + {"type": "token", "data": " Second"}, + {"type": "other", "data": "something"}, + {"type": "token", "data": " Third"}, + ] + result = merge_token_items(chat_list) + # Should have: other item, merged tokens, done item + assert len(result) == 3 + assert result[0] == {"type": "other", "data": "something"} + assert result[1] == {"data": "First Second Third", "type": "token"} + assert result[2] == {"type": "done", "data": {}} + + @patch("chat_api.chats.chats_services.save_chat") @patch("chat_api.chats.chats_services.SessionLocal") @patch("chat_api.chats.chats_services.create_thread") @@ -117,8 +142,15 @@ async def _collect(): chunks = asyncio.run(_collect()) - # Final chunk should include the thread_id that was created + # Thread ID should be yielded when creating a new thread + assert any(b"thread_id" in c for c in chunks) assert any(str(thread_id).encode("utf-8") in c for c in chunks) + + # Verify the first chunk contains the thread_id + first_chunk = chunks[0].decode("utf-8") + assert "thread_id" in first_chunk + assert str(thread_id) in first_chunk + mock_create_thread.assert_called_once() mock_sessionlocal.assert_called_once() mock_save_chat.assert_called_once() @@ -171,7 +203,13 @@ async def _collect(): # Should not create a new thread mock_create_thread.assert_not_called() mock_save_chat.assert_called_once() - assert any(existing_thread_id.encode("utf-8") in c for c in chunks) + # Thread ID should NOT be in chunks when using existing thread (only yielded for new threads) + assert not any(b"thread_id" in c for c in chunks) + + # Verify save_chat was called with the existing thread_id + call_args = mock_save_chat.call_args + response_payload = call_args[1]["response_payload"] + assert str(response_payload.thread_id) == existing_thread_id @patch("chat_api.chats.chats_services.save_chat") @@ -232,8 +270,12 @@ async def _collect(): # Check that the response has merged tokens assert len(response_payload.response) == 3 # search_results, merged token, done assert response_payload.response[0]["type"] == "search_results" - assert response_payload.response[1]["type"] == "token" assert response_payload.response[1]["data"] == "Hello world!" # Merged + assert response_payload.response[1]["type"] == "token" assert response_payload.response[2]["type"] == "done" + + # Verify thread_id and question are correctly set + assert response_payload.thread_id == thread_id + assert response_payload.question == "hi"