diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py
index f1c880286951..b16d9af35be9 100644
--- a/tests/tokenization/test_mistral_tokenizer.py
+++ b/tests/tokenization/test_mistral_tokenizer.py
@@ -1,15 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
-from mistral_common.protocol.instruct.messages import UserMessage
+from mistral_common.protocol.instruct.messages import (AssistantMessage,
+                                                       ToolMessage,
+                                                       UserMessage)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
-from mistral_common.protocol.instruct.tool_calls import Function, Tool
+from mistral_common.protocol.instruct.tool_calls import (Function,
+                                                         FunctionCall, Tool,
+                                                         ToolCall)
 
 from vllm.transformers_utils.tokenizers.mistral import (
     make_mistral_chat_completion_request)
 
-# yapf: enable
 
 @pytest.mark.parametrize(
     "openai_request,expected_mistral_request",
     [(
@@ -78,6 +81,107 @@
 )
 def test_make_mistral_chat_completion_request(openai_request,
                                               expected_mistral_request):
-    assert (make_mistral_chat_completion_request(
-        openai_request["messages"],
-        openai_request["tools"]) == expected_mistral_request)
+    actual_request = make_mistral_chat_completion_request(
+        openai_request["messages"], openai_request["tools"])
+    assert actual_request == expected_mistral_request
+
+
+# Tool use with list content and reasoning_content
+@pytest.mark.parametrize("openai_request,expected_mistral_request", [(
+    {
+        "messages": [
+            {
+                "role": "user",
+                "content": "What's the weather in Paris?",
+            },
+            {
+                "role":
+                "assistant",
+                "reasoning_content":
+                None,
+                "content":
+                None,
+                "tool_calls": [{
+                    "id": "call123",
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": '{"city": "Paris"}',
+                    },
+                }],
+            },
+            {
+                "role": "tool",
+                "content": [{
+                    "type": "text",
+                    "text": "Rainy"
+                }],
+                "name": "get_weather",
+                "tool_call_id": "call123",
+            },
+        ],
+        "tools": [{
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Gets the current weather in a city.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "city": {
+                            "type": "string",
+                            "description": "The city name"
+                        }
+                    },
+                    "required": ["city"],
+                },
+            },
+        }],
+    },
+    ChatCompletionRequest(
+        messages=[
+            UserMessage(content="What's the weather in Paris?"),
+            AssistantMessage(
+                content=None,
+                tool_calls=[
+                    ToolCall(
+                        id="call123",
+                        function=FunctionCall(
+                            name="get_weather",
+                            arguments='{"city": "Paris"}',
+                        ),
+                    )
+                ],
+            ),
+            ToolMessage(
+                content="Rainy",
+                tool_call_id="call123",
+                name="get_weather",
+            ),
+        ],
+        tools=[
+            Tool(
+                type="function",
+                function=Function(
+                    name="get_weather",
+                    description="Gets the current weather in a city.",
+                    parameters={
+                        "type": "object",
+                        "properties": {
+                            "city": {
+                                "type": "string",
+                                "description": "The city name"
+                            }
+                        },
+                        "required": ["city"],
+                    },
+                ),
+            )
+        ],
+    ),
+)])
+def test_make_mistral_chat_completion_request_list_content(
+        openai_request, expected_mistral_request):
+    actual_request = make_mistral_chat_completion_request(
+        openai_request["messages"], openai_request["tools"])
+    assert actual_request == expected_mistral_request
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 05de6a603655..23b6f67f09df 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -156,7 +156,11 @@ def make_mistral_chat_completion_request(
     #
     # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
     for message in messages:
-        if message.get("role") == "assistant":
+        # Remove reasoning_content as unsupported by Mistral
+        _ = message.pop("reasoning_content", None)  # type: ignore
+
+        # Convert list text content to string
+        if message.get("role") in ("assistant", "tool"):
             content = message.get("content")
             if isinstance(content, list):
                 content = "\n".join(chunk.get("text") for chunk in content)
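
A quick usage sketch for reviewers, mirroring the test fixture above, showing what the new normalization produces end to end. It assumes this patch is applied and `mistral_common` is installed; the `get_weather`/`call123` names are purely illustrative:

```python
from vllm.transformers_utils.tokenizers.mistral import (
    make_mistral_chat_completion_request)

# Illustrative OpenAI-style payload covering the two cases this patch
# handles: an assistant message carrying reasoning_content, and a tool
# message whose content is a list of text chunks.
messages = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {
        "role": "assistant",
        "reasoning_content": None,  # popped before validation
        "content": None,
        "tool_calls": [{
            "id": "call123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Paris"}',
            },
        }],
    },
    {
        "role": "tool",
        # List-of-text content is joined with "\n" into a plain string
        "content": [{"type": "text", "text": "Rainy"}],
        "name": "get_weather",
        "tool_call_id": "call123",
    },
]
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Gets the current weather in a city.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "The city name"}
            },
            "required": ["city"],
        },
    },
}]

request = make_mistral_chat_completion_request(messages, tools)
print(type(request.messages[2]).__name__)  # ToolMessage
print(request.messages[2].content)         # "Rainy" (str, not a chunk list)
```

One thing worth flagging: `message.pop("reasoning_content", None)` mutates the caller's message dicts in place, so a payload reused after this call no longer carries `reasoning_content`. That may be acceptable here, but copying each message before popping would avoid the side effect.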