diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py
index 88b3795abe73..c3a82eda2d2a 100644
--- a/tests/entrypoints/openai/test_response_api_with_harmony.py
+++ b/tests/entrypoints/openai/test_response_api_with_harmony.py
@@ -341,6 +341,8 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
         events.append(event)
 
     assert len(events) > 0
+    response_completed_event = events[-1]
+    assert len(response_completed_event.response.output) > 0
 
     if background:
         starting_after = 5
diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py
index 5e6a4c85ff79..2afe9758c2ad 100644
--- a/tests/entrypoints/test_context.py
+++ b/tests/entrypoints/test_context.py
@@ -4,7 +4,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from openai_harmony import StreamState
+from openai_harmony import Author, Message, Role, StreamState, TextContent
 
 from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -312,9 +312,9 @@ async def test_negative_tool_tokens_edge_case():
 
 @pytest.mark.asyncio
 async def test_streaming_multi_turn_token_counting(mock_parser):
     """Test token counting for streaming multi-turn conversations.
- 
-    This test focuses on how StreamingHarmonyContext counts tokens in a 
-    multi-turn conversation with streaming (token-by-token) outputs and 
+
+    This test focuses on how StreamingHarmonyContext counts tokens in a
+    multi-turn conversation with streaming (token-by-token) outputs and
     message boundaries.
     """
@@ -423,3 +423,78 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
     additional_tool_tokens = 13 - 8 - 3  # = 2
     assert context.num_tool_output_tokens == expected_tool_tokens \
         + additional_tool_tokens
+
+
+@pytest.mark.asyncio
+async def test_streaming_message_synchronization(mock_parser):
+    """Test message synchronization logic from lines 413-417 in context.py.
+
+    This test verifies that when parser.messages contains more messages than
+    the context's _messages (minus initial messages), the context properly
+    extends its message list with the new parser messages.
+ """ + + # Create a streaming context with some initial messages + initial_messages = [ + Message( + author=Author(role=Role.USER, name="user"), + content=[TextContent(text="Hello")], + recipient=Role.ASSISTANT, + ) + ] + context = StreamingHarmonyContext(messages=initial_messages, + available_tools=[]) + + # Verify initial state + assert len(context._messages) == 1 + assert context.num_init_messages == 1 + + # Mock parser to have more messages than context + # Simulate parser having processed 3 new messages + mock_parser.messages = [ + Message( + author=Author(role=Role.ASSISTANT, name="assistant"), + content=[TextContent(text="Response 1")], + recipient=Role.USER, + ), + ] + + # This should trigger the message synchronization logic + context.append_output( + create_mock_request_output(prompt_token_ids=[1, 2, 3], + output_token_ids=[101], + finished=False)) + + # Verify that messages were synchronized + assert len(context._messages) == 2 + + # Verify the new messages were added correctly + assert context._messages[1].content[0].text == "Response 1" + + # Test the specific condition from line 413-414: + # len(self._messages) - self.num_init_messages < len(self.parser.messages) + messages_minus_init = len(context._messages) - context.num_init_messages + parser_messages_count = len(mock_parser.messages) + + # After synchronization, they should be equal (no longer less than) + assert messages_minus_init == parser_messages_count + + # Test edge case: add one more parser message + mock_parser.messages.append( + Message( + author=Author(role=Role.ASSISTANT, name="assistant"), + content=[TextContent(text="Response 4")], + recipient=Role.USER, + )) + + # Create another output to trigger synchronization again + mock_output2 = create_mock_request_output(prompt_token_ids=[1, 2, 3], + output_token_ids=[102], + finished=True) + + context.append_output(mock_output2) + + # Verify the fourth message was added, num_init_messages is still 1 + assert len(context._messages) == 3 + assert context.num_init_messages == 1 + assert context._messages[2].content[0].text == "Response 4" diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index 6658f91595e5..8619452f2445 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -151,6 +151,9 @@ def append_output(self, output: Union[RequestOutput, self._update_decode_token_usage(output) # Move current turn to previous turn for next turn's calculations self.previous_turn = self.current_turn.copy() + # append_output is called only once before tool calling + # in non-streaming case + # so we can append all the parser messages to _messages output_msgs = self.parser.messages # The responses finish reason is set in the last message self.finish_reason = output.outputs[0].finish_reason @@ -387,7 +390,7 @@ def __init__(self, *args, **kwargs): @property def messages(self) -> list: - return self.parser.messages + return self._messages def append_output(self, output: Union[RequestOutput, list[Message]]) -> None: @@ -412,6 +415,11 @@ def append_output(self, output: Union[RequestOutput, # Check if the current token is part of reasoning content self._update_num_reasoning_tokens() self.last_tok = tok + if len(self._messages) - self.num_init_messages < len( + self.parser.messages): + self._messages.extend( + self.parser.messages[len(self._messages) - + self.num_init_messages:]) else: # Handle the case of tool output in direct message format assert len(output) == 1, "Tool output should be a single message" @@ -424,6 +432,7 @@ def 
             for tok in toks:
                 self.parser.process(tok)
             self.last_tok = toks[-1]
+            # TODO: add tool_output messages to self._messages
 
     def is_expecting_start(self) -> bool:
         return self.parser.state == StreamState.EXPECT_START
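
Note: the synchronization rule exercised by test_streaming_message_synchronization reduces to a small list-slicing invariant. Below is a minimal standalone sketch of that rule, assuming plain lists as stand-ins for the real parser and context state; sync_messages is a hypothetical helper for illustration, not part of vllm.

    # Hypothetical stand-in for StreamingHarmonyContext's sync step; the
    # names mirror the diff, but this is not the real vllm implementation.
    def sync_messages(context_messages: list, num_init_messages: int,
                      parser_messages: list) -> list:
        # The context starts with num_init_messages seed messages that the
        # parser never sees, so the count of parser-produced messages already
        # copied into the context is len(context_messages) - num_init_messages.
        already_copied = len(context_messages) - num_init_messages
        if already_copied < len(parser_messages):
            # Copy only the tail of parser messages not yet in the context.
            context_messages.extend(parser_messages[already_copied:])
        return context_messages

    # Mirrors the test: one initial message, then the parser emits two replies.
    msgs = sync_messages(["Hello"], 1, ["Response 1"])
    assert msgs == ["Hello", "Response 1"]
    msgs = sync_messages(msgs, 1, ["Response 1", "Response 2"])
    assert msgs == ["Hello", "Response 1", "Response 2"]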