Commit 314cfad

[Frontend] Generate valid tool call IDs when using tokenizer-mode=mistral (#12332)
1 parent 985b4a2 commit 314cfad

File tree

8 files changed (+149, -8 lines changed)

tests/mistral_tool_use/__init__.py

Whitespace-only changes.

tests/mistral_tool_use/conftest.py

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download

from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform

from .utils import ARGS, CONFIGS, ServerConfig


# for each server config, download the model and return the config
@pytest.fixture(scope="session", params=CONFIGS.keys())
def server_config(request):
    config = CONFIGS[request.param]

    if current_platform.is_rocm() and not config.get("supports_rocm", True):
        pytest.skip("The {} model can't be tested on the ROCm platform".format(
            config["model"]))

    # download model and tokenizer using transformers
    snapshot_download(config["model"])
    yield CONFIGS[request.param]


# run this for each server config
@pytest.fixture(scope="session")
def server(request, server_config: ServerConfig):
    model = server_config["model"]
    args_for_model = server_config["arguments"]
    with RemoteOpenAIServer(model, ARGS + args_for_model,
                            max_wait_seconds=480) as server:
        yield server


@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
    async with server.get_async_client() as async_client:
        yield async_client
Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0

import openai
import pytest

from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL


# test: a tool_choice with mistral-tokenizer results in an ID of length 9
@pytest.mark.asyncio
async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_completion_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL],
        tool_choice=WEATHER_TOOL,
        logprobs=False)

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"
    assert choice.message.tool_calls is None \
        or len(choice.message.tool_calls) == 1
    assert len(choice.message.tool_calls[0].id) == 9  # length of 9 for mistral

tests/mistral_tool_use/utils.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional

from typing_extensions import TypedDict


class ServerConfig(TypedDict, total=False):
    model: str
    arguments: List[str]
    system_prompt: Optional[str]
    supports_parallel: Optional[bool]
    supports_rocm: Optional[bool]


ARGS: List[str] = ["--max-model-len", "1024"]

CONFIGS: Dict[str, ServerConfig] = {
    "mistral": {
        "model":
        "mistralai/Mistral-7B-Instruct-v0.3",
        "arguments": [
            "--tokenizer-mode", "mistral",
            "--ignore-patterns=\"consolidated.safetensors\""
        ],
        "system_prompt":
        "You are a helpful assistant with access to tools. If a tool"
        " that you have would be helpful to answer a user query, "
        "call the tool. Otherwise, answer the user's query directly "
        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
        "to the user's question - just respond to it normally."
    },
}
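
Putting the two pieces above together, this mirrors how the conftest fixture builds the server arguments (base ARGS plus the per-model "arguments" list); the values are copied from the config shown above:

# Combined argument list handed to RemoteOpenAIServer in conftest.py above.
ARGS = ["--max-model-len", "1024"]
mistral_arguments = [
    "--tokenizer-mode", "mistral",
    "--ignore-patterns=\"consolidated.safetensors\""
]
print(ARGS + mistral_arguments)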

vllm/entrypoints/openai/serving_chat.py

Lines changed: 11 additions & 5 deletions

@@ -28,12 +28,15 @@
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall)
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
+from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
+                                                truncate_tool_call_ids)
 
 logger = init_logger(__name__)
 
@@ -150,11 +153,12 @@ async def create_chat_completion(
                 return self.create_error_response(
                     "tool_choice = \"required\" is not supported!")
 
-            # because of issues with pydantic we need to potentially
-            # re-serialize the tool_calls field of the request
-            # for more info: see comment in `maybe_serialize_tool_calls`
             if isinstance(tokenizer, MistralTokenizer):
+                # because of issues with pydantic we need to potentially
+                # re-serialize the tool_calls field of the request
+                # for more info: see comment in `maybe_serialize_tool_calls`
                 maybe_serialize_tool_calls(request)
+                truncate_tool_call_ids(request)
 
             if (request.tool_choice == "auto" and
                     not (self.enable_auto_tools and tool_parser is not None)
@@ -745,11 +749,13 @@ async def chat_completion_full_generator(
             elif request.tool_choice and type(
                     request.tool_choice) is ChatCompletionNamedToolChoiceParam:
 
+                tool_call_class = MistralToolCall if isinstance(
+                    tokenizer, MistralTokenizer) else ToolCall
                 message = ChatMessage(
                     role=role,
                     content="",
                     tool_calls=[
-                        ToolCall(function=FunctionCall(
+                        tool_call_class(function=FunctionCall(
                             name=request.tool_choice.function.name,
                             arguments=output.text))
                     ])
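
As a rough illustration of the named-tool-choice change above (the classes below are simplified stand-ins, not vLLM's actual ToolCall/MistralToolCall, and the default ID format shown is only an assumption for contrast): the tool-call class is now picked from the tokenizer type, so Mistral requests get a 9-character alphanumeric call ID instead of a longer default-style ID:

from random import choices
from string import ascii_letters, digits
from uuid import uuid4


class DefaultToolCall:
    """Stand-in for a default tool call; this ID format is purely illustrative."""

    def __init__(self) -> None:
        self.id = f"chatcmpl-tool-{uuid4().hex}"


class MistralStyleToolCall:
    """Stand-in for MistralToolCall: a random 9-character alphanumeric ID."""

    def __init__(self) -> None:
        self.id = "".join(choices(ascii_letters + digits, k=9))


def tool_call_class_for(is_mistral_tokenizer: bool) -> type:
    # Mirrors the hunk above: MistralToolCall when the tokenizer is a
    # MistralTokenizer, otherwise the default ToolCall.
    return MistralStyleToolCall if is_mistral_tokenizer else DefaultToolCall


print(len(tool_call_class_for(True)().id))   # 9
print(len(tool_call_class_for(False)().id))  # 46 with this illustrative format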

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):
 
     @staticmethod
     def generate_random_id():
-        # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
+        # Mistral Tool Call Ids must be alphanumeric with a length of 9.
         # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
         return "".join(choices(ALPHANUMERIC, k=9))
 
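
The linked mistral-common validator is what drives the comment fix above: tool call IDs are expected to be exactly 9 alphanumeric characters. A minimal standalone sketch of that constraint (ALPHANUMERIC is redefined locally, and the regex is assumed to match the linked validator's rule):

import re
from random import choices
from string import ascii_letters, digits

# Assumed to mirror the parser module's ALPHANUMERIC constant.
ALPHANUMERIC = ascii_letters + digits


def generate_random_id() -> str:
    # Same idea as MistralToolCall.generate_random_id: 9 random alphanumeric chars.
    return "".join(choices(ALPHANUMERIC, k=9))


tool_call_id = generate_random_id()
# Check the ID against the 9-character alphanumeric rule described above.
assert re.fullmatch(r"[a-zA-Z0-9]{9}", tool_call_id) is not None
print(tool_call_id)
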
vllm/transformers_utils/tokenizers/__init__.py

Lines changed: 5 additions & 2 deletions

@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from .mistral import MistralTokenizer, maybe_serialize_tool_calls
+from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
+                      truncate_tool_call_ids)
 
-__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
+__all__ = [
+    "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
+]

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 30 additions & 0 deletions

@@ -68,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
             request.messages[i]["tool_calls"] = validated_tool_calls
 
 
+def truncate_tool_call_ids(request: "ChatCompletionRequest"):
+    """Truncates tool call IDs for Mistral's ID requirements."""
+    for i, message in enumerate(request.messages):
+        if message.get("role") == 'assistant':
+            tool_calls = message.get("tool_calls", [])
+            for tool_call in tool_calls:
+                if len(tool_call["id"]) > 9:
+                    logger.warning(
+                        "Truncating tool call ID: %s to %s",
+                        tool_call["id"],
+                        tool_call["id"][-9:],
+                    )
+                    tool_call["id"] = tool_call["id"][-9:]
+
+            request.messages[i]["tool_calls"] = tool_calls
+
+        elif message.get("role") in {"tool_results", "tool"}:
+            if "tool_call_id" in message:
+                tool_call_id = message["tool_call_id"]
+
+                if len(tool_call_id) > 9:
+                    logger.warning(
+                        "Truncating tool_call_id: %s to %s",
+                        tool_call_id,
+                        tool_call_id[-9:],
+                    )
+                    tool_call_id = tool_call_id[-9:]
+                request.messages[i]["tool_call_id"] = tool_call_id
+
+
 def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
     repo_cache = os.path.join(
         huggingface_hub.constants.HF_HUB_CACHE,
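
A small worked example of the truncation rule added above (the incoming IDs are made up; the rule is simply "keep the last 9 characters" whenever an ID is longer than 9):

# Hypothetical incoming tool call IDs; only the last-9-characters rule mirrors
# truncate_tool_call_ids above.
for tool_call_id in ["call_9f27c1b4a8d34e55", "abc123XYZ"]:
    truncated = tool_call_id[-9:] if len(tool_call_id) > 9 else tool_call_id
    print(f"{tool_call_id!r} -> {truncated!r} (len={len(truncated)})")
# 'call_9f27c1b4a8d34e55' -> '4a8d34e55' (len=9)
# 'abc123XYZ' -> 'abc123XYZ' (len=9)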
