Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions chat_api/chats/chats_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def merge_token_items(chat_list: list) -> list:

if done_item:
merged_data.append(done_item)
print("merge data >>>>>>>>>>>>>", merged_data)

return merged_data

Expand All @@ -46,13 +47,20 @@ async def get_chat_stream(token: str, chat_request: ChatRequest):

async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, read=120.0)) as client:
async with client.stream("POST", url, json=chat_request_payload) as response:
if chat_request.thread_id is None:
thread_request = ThreadCreateRequest(email=chat_request.email, device_type=chat_request.device_type, application_name=chat_request.application)
thread = create_thread(thread_request=thread_request)
yield (
f"data: {json.dumps({'thread_id': str(thread.id)})}\n\n"
).encode("utf-8")
# Stream the response chunks
async for line in response.aiter_lines():
frame = sse_frame_from_line(line, on_json=chat_list.append)
if frame:
yield frame

if len(chat_list) > 0:

if chat_request.thread_id is None:
thread_request = ThreadCreateRequest(email=email, device_type=chat_request.device_type, application_name=chat_request.application)
thread = create_thread(thread_request=thread_request)
Expand All @@ -63,9 +71,6 @@ async def get_chat_stream(token: str, chat_request: ChatRequest):
response_payload = ChatResponsePayload(thread_id=thread_id, response=merged_chat_list, question=chat_request.query)
save_chat(db_session, response_payload=response_payload)

yield (
f"data: {json.dumps({'thread_id': str(thread_id)})}\n\n"
).encode("utf-8")

def sse_frame_from_line(
line: str,
Expand Down
48 changes: 45 additions & 3 deletions tests/chats/test_chats_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ def test_sse_frame_from_line_data_line_json() -> None:
assert collected == [{"x": 1}]


def test_sse_frame_from_line_without_data_prefix() -> None:
"""Test that lines without 'data:' prefix are still parsed as JSON."""
collected = []
frame = sse_frame_from_line('{"type": "token", "data": "test"}', on_json=collected.append)
assert frame == b'data: {"type": "token", "data": "test"}\n\n'
assert collected == [{"type": "token", "data": "test"}]


def test_merge_token_items_merges_all_tokens() -> None:
chat_list = [
{"type": "token", "data": "Hello"},
Expand Down Expand Up @@ -71,6 +79,23 @@ def test_merge_token_items_handles_no_tokens() -> None:
assert result[1] == {"type": "done", "data": {}}


def test_merge_token_items_done_item_always_last() -> None:
"""Test that done item is always placed at the end after merged tokens."""
chat_list = [
{"type": "token", "data": "First"},
{"type": "done", "data": {}},
{"type": "token", "data": " Second"},
{"type": "other", "data": "something"},
{"type": "token", "data": " Third"},
]
result = merge_token_items(chat_list)
# Should have: other item, merged tokens, done item
assert len(result) == 3
assert result[0] == {"type": "other", "data": "something"}
assert result[1] == {"data": "First Second Third", "type": "token"}
assert result[2] == {"type": "done", "data": {}}


@patch("chat_api.chats.chats_services.save_chat")
@patch("chat_api.chats.chats_services.SessionLocal")
@patch("chat_api.chats.chats_services.create_thread")
Expand Down Expand Up @@ -117,8 +142,15 @@ async def _collect():

chunks = asyncio.run(_collect())

# Final chunk should include the thread_id that was created
# Thread ID should be yielded when creating a new thread
assert any(b"thread_id" in c for c in chunks)
assert any(str(thread_id).encode("utf-8") in c for c in chunks)

# Verify the first chunk contains the thread_id
first_chunk = chunks[0].decode("utf-8")
assert "thread_id" in first_chunk
assert str(thread_id) in first_chunk

mock_create_thread.assert_called_once()
mock_sessionlocal.assert_called_once()
mock_save_chat.assert_called_once()
Expand Down Expand Up @@ -171,7 +203,13 @@ async def _collect():
# Should not create a new thread
mock_create_thread.assert_not_called()
mock_save_chat.assert_called_once()
assert any(existing_thread_id.encode("utf-8") in c for c in chunks)
# Thread ID should NOT be in chunks when using existing thread (only yielded for new threads)
assert not any(b"thread_id" in c for c in chunks)

# Verify save_chat was called with the existing thread_id
call_args = mock_save_chat.call_args
response_payload = call_args[1]["response_payload"]
assert str(response_payload.thread_id) == existing_thread_id


@patch("chat_api.chats.chats_services.save_chat")
Expand Down Expand Up @@ -232,8 +270,12 @@ async def _collect():
# Check that the response has merged tokens
assert len(response_payload.response) == 3 # search_results, merged token, done
assert response_payload.response[0]["type"] == "search_results"
assert response_payload.response[1]["type"] == "token"
assert response_payload.response[1]["data"] == "Hello world!" # Merged
assert response_payload.response[1]["type"] == "token"
assert response_payload.response[2]["type"] == "done"

# Verify thread_id and question are correctly set
assert response_payload.thread_id == thread_id
assert response_payload.question == "hi"


Loading