|
4 | 4 | from unittest.mock import MagicMock, patch |
5 | 5 |
|
6 | 6 | import pytest |
7 | | -from openai_harmony import StreamState |
| 7 | +from openai_harmony import Author, Message, Role, StreamState, TextContent |
8 | 8 |
|
9 | 9 | from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext |
10 | 10 | from vllm.outputs import CompletionOutput, RequestOutput |
@@ -312,9 +312,9 @@ async def test_negative_tool_tokens_edge_case(): |
312 | 312 | @pytest.mark.asyncio |
313 | 313 | async def test_streaming_multi_turn_token_counting(mock_parser): |
314 | 314 | """Test token counting for streaming multi-turn conversations. |
315 | | - |
316 | | - This test focuses on how StreamingHarmonyContext counts tokens in a |
317 | | - multi-turn conversation with streaming (token-by-token) outputs and |
| 315 | +
|
| 316 | + This test focuses on how StreamingHarmonyContext counts tokens in a |
| 317 | + multi-turn conversation with streaming (token-by-token) outputs and |
318 | 318 | message boundaries. |
319 | 319 | """ |
320 | 320 | # Create a streaming context |
@@ -423,3 +423,78 @@ async def test_streaming_multi_turn_token_counting(mock_parser): |
423 | 423 | additional_tool_tokens = 13 - 8 - 3 # = 2 |
424 | 424 | assert context.num_tool_output_tokens == expected_tool_tokens \ |
425 | 425 | + additional_tool_tokens |
| 426 | + |
| 427 | + |
@pytest.mark.asyncio
async def test_streaming_message_synchronization(mock_parser):
    """Test message synchronization logic from lines 413-417 in context.py.

    This test verifies that when parser.messages contains more messages than
    the context's _messages (minus initial messages), the context properly
    extends its message list with the new parser messages.
    """

    # Create a streaming context seeded with a single initial (user) message
    # so that num_init_messages is non-zero and participates in the condition
    # under test.
    initial_messages = [
        Message(
            author=Author(role=Role.USER, name="user"),
            content=[TextContent(text="Hello")],
            recipient=Role.ASSISTANT,
        )
    ]
    context = StreamingHarmonyContext(messages=initial_messages,
                                      available_tools=[])

    # Verify initial state: only the seed message, counted as an init message.
    assert len(context._messages) == 1
    assert context.num_init_messages == 1

    # Mock parser to have more messages than the context:
    # simulate the parser having produced one new assistant message.
    mock_parser.messages = [
        Message(
            author=Author(role=Role.ASSISTANT, name="assistant"),
            content=[TextContent(text="Response 1")],
            recipient=Role.USER,
        ),
    ]

    # Appending an output should trigger the message synchronization logic.
    context.append_output(
        create_mock_request_output(prompt_token_ids=[1, 2, 3],
                                   output_token_ids=[101],
                                   finished=False))

    # Verify that messages were synchronized (seed + 1 parser message).
    assert len(context._messages) == 2

    # Verify the new message was appended with its content intact.
    assert context._messages[1].content[0].text == "Response 1"

    # Test the specific condition from line 413-414:
    # len(self._messages) - self.num_init_messages < len(self.parser.messages)
    messages_minus_init = len(context._messages) - context.num_init_messages
    parser_messages_count = len(mock_parser.messages)

    # After synchronization, they should be equal (no longer less than).
    assert messages_minus_init == parser_messages_count

    # Edge case: the parser produces one more message.
    # NOTE(review): the text label "Response 4" skips 2-3 — presumably
    # arbitrary test data; only equality with the stored message matters.
    mock_parser.messages.append(
        Message(
            author=Author(role=Role.ASSISTANT, name="assistant"),
            content=[TextContent(text="Response 4")],
            recipient=Role.USER,
        ))

    # Create another output to trigger synchronization again.
    mock_output2 = create_mock_request_output(prompt_token_ids=[1, 2, 3],
                                              output_token_ids=[102],
                                              finished=True)

    context.append_output(mock_output2)

    # Verify the newly appended message (third overall) was synchronized and
    # that num_init_messages is still 1 (init count never grows).
    assert len(context._messages) == 3
    assert context.num_init_messages == 1
    assert context._messages[2].content[0].text == "Response 4"
0 commit comments