Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add mistral function calling format to all models loaded with "mistral" format #8515

Merged
138 changes: 138 additions & 0 deletions examples/offline_chat_with_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# ruff: noqa
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example file of how to use Mistral models with function calling

import json
import random
import string

from vllm import LLM
from vllm.sampling_params import SamplingParams

# This script is an offline demo for function calling with Mistral models.
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Mistral-7B-Instruct-v0.3",
# "messages": [
# {
# "role": "user",
# "content": [
# {"type" : "text", "text": "Describe this image in detail please."},
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
# {"type" : "text", "text": "and this one as well. Answer in French."},
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
# ]
# }
# ]
# }'
# ```
#
# Usage:
# python offline_chat_with_tools.py

# Checkpoint used by this demo.
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# or switch to "mistralai/Mistral-Nemo-Instruct-2407"
# or "mistralai/Mistral-Large-Instruct-2407"
# or any other mistral model with function calling ability

# Replies are capped at 8192 new tokens; temperature 0.0 is presumably chosen
# so that the demo output is reproducible.
sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)

# Tokenizer, config and weights are all requested in the "mistral" format.
llm = LLM(
    model=model_name,
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
)


def generate_random_id(length: int = 9) -> str:
    """Return a random identifier made of ASCII letters and digits.

    The script uses this to fabricate the ``tool_call_id`` of the tool
    message. It relies on the ``random`` module, so it is not suitable for
    anything security-sensitive.

    Args:
        length: Number of characters to generate. Zero or a negative value
            yields an empty string.

    Returns:
        A string of ``length`` characters drawn from ``[A-Za-z0-9]``.
    """
    characters = string.ascii_letters + string.digits
    return ''.join(random.choice(characters) for _ in range(length))


# simulate an API that can be called
def get_current_weather(city: str, state: str, unit: str) -> str:
    """Return a canned weather report for the given location.

    This stands in for a real weather service: the temperature and the
    conditions are hard-coded, and the arguments are only echoed back in the
    returned text.

    Args:
        city: City name, e.g. ``"Dallas"``.
        state: State the city is in, e.g. ``"TX"``.
        unit: Temperature unit to mention in the report.

    Returns:
        A human-readable weather report.
    """
    return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
            "partly cloudly, with highs in the 90's.")


# Maps a tool name to the Python callable that implements it. The name is
# kept as spelled because the script looks it up again further down.
tool_funtions = {"get_current_weather": get_current_weather}

# Tool schema handed to the model: one function, `get_current_weather`,
# taking three required string arguments.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": ("The city to find the weather for, "
                                        "e.g. 'San Francisco'"),
                    },
                    "state": {
                        "type": "string",
                        "description": ("the two-letter abbreviation for the "
                                        "state that the city is in, e.g. 'CA' "
                                        "which would mean 'California'"),
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    },
]

# Conversation so far: a single user question.
messages = [
    {
        "role": "user",
        "content": ("Can you tell me what the temperate will be in Dallas, "
                    "in fahrenheit?"),
    },
]

# First round: the model sees the tool schema; its reply is parsed as JSON
# tool calls further down.
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
first_completion = outputs[0].outputs[0]
output = first_completion.text.strip()

# append the assistant message
messages.append({"role": "assistant", "content": output})

# let's now actually parse and execute the model's output simulating an API
# call by using the above defined function
tool_calls = json.loads(output)
tool_answers = [
    tool_funtions[tool_call['name']](**tool_call['arguments'])
    for tool_call in tool_calls
]

# append the answer as a tool message and let the LLM give you an answer
messages.append({
    "role": "tool",
    "content": "\n\n".join(tool_answers),
    "tool_call_id": generate_random_id(),
})

# Second round: the model now sees the tool result in the conversation.
outputs = llm.chat(messages, sampling_params, tools=tools)

final_completion = outputs[0].outputs[0]
print(final_completion.text.strip())
# yields
#   'The weather in Dallas, TX is 85 degrees fahrenheit. '
#   'It is partly cloudly, with highs in the 90's.'
69 changes: 67 additions & 2 deletions tests/models/decoder_only/language/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,59 @@
"""
import pytest

from vllm import SamplingParams

from ...utils import check_logprobs_close

# Checkpoints exercised by the tests in this module. Order matters: index 0
# is the v0.1 model, which the function-calling test skips via `MODELS[1:]`
# because it cannot do function calling. Each checkpoint is listed once.
MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "mistralai/Mistral-Nemo-Instruct-2407",
]

# Sampling settings passed to `chat` in the function-calling test:
# max_tokens=512, temperature=0.0 and logprobs=5.
SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)

# for function calling
# Tool schema offered to the model: a single `get_current_weather` function
# with three required string arguments.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": ("The city to find the weather for, "
                                        "e.g. 'San Francisco'"),
                    },
                    "state": {
                        "type": "string",
                        "description": ("the two-letter abbreviation for the "
                                        "state that the city is in, e.g. 'CA' "
                                        "which would mean 'California'"),
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    },
]
# Single-turn conversation sent to the model in the function-calling test.
MSGS = [
    {
        "role": "user",
        "content": ("Can you tell me what the temperate will be in Dallas, "
                    "in fahrenheit?"),
    },
]

# Exact text the completion is compared against, after stripping whitespace.
EXPECTED_FUNC_CALL = (
    '[{"name": "get_current_weather", '
    '"arguments": {"city": "Dallas", "state": "TX", '
    '"unit": "fahrenheit"}}]')


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
Expand Down Expand Up @@ -81,3 +127,22 @@ def test_mistral_format(
name_0="hf",
name_1="mistral",
)


@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model", MODELS[1:])  # v1 can't do func calling
def test_mistral_function_calling(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Check that the model answers the weather question with a tool call.

    The model is loaded with the "mistral" tokenizer, config and load
    formats, then asked the question in ``MSGS`` with ``TOOLS`` available.
    The stripped text of the first completion must equal
    ``EXPECTED_FUNC_CALL`` exactly.

    Args:
        vllm_runner: Fixture used as a context manager that yields the
            loaded model wrapper.
        model: Name of the checkpoint under test.
        dtype: Data type passed through to the runner.
    """
    with vllm_runner(model,
                     dtype=dtype,
                     tokenizer_mode="mistral",
                     config_format="mistral",
                     load_format="mistral") as vllm_model:
        outputs = vllm_model.model.chat(MSGS,
                                        tools=TOOLS,
                                        sampling_params=SAMPLING_PARAMS)

        assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
6 changes: 5 additions & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
overload)

from tqdm import tqdm

Expand Down Expand Up @@ -357,6 +358,7 @@ def chat(
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = True,
tools: Optional[List[Dict[str, Any]]] = None,
) -> List[RequestOutput]:
"""
Generate responses for a chat conversation.
Expand Down Expand Up @@ -401,13 +403,15 @@ def chat(
messages=messages,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tools=tools,
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tools=tools,
)

inputs: PromptInputs
Expand Down
9 changes: 5 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ async def create_chat_completion(
]

prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
Expand Down Expand Up @@ -159,10 +160,10 @@ async def create_chat_completion(
return self.create_error_response(
"tool_choice = \"required\" is not supported!")

# "auto" tools requires --enable-auto-tool-choice
# and --tool-call-parser
if request.tool_choice == "auto" and not (
if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
self.enable_auto_tools and self.tool_parser is not None):
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response(
"\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set")
Expand Down
10 changes: 6 additions & 4 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,20 @@ def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],
tools: Optional[Dict[str, Any]] = None,
**kwargs) -> List[int]:
assert tools is None, "`tools` are not yet supported."

request = ChatCompletionRequest(
messages=messages) # type: ignore[type-var]
request = ChatCompletionRequest(messages=messages,
tools=tools) # type: ignore[type-var]
encoded = self.mistral.encode_chat_completion(request)

# encode-decode to get clean prompt
return encoded.tokens

def convert_tokens_to_string(self, tokens: List[str]) -> str:
    """Turn a list of token strings back into text.

    When the tokenizer is a ``Tekkenizer``, the tokens are concatenated and
    any token listed in the tokenizer's ``_all_special_tokens`` is dropped.
    Any other tokenizer decodes the tokens with its own ``decode`` method.

    Args:
        tokens: Token strings to convert.

    Returns:
        The text for ``tokens``.
    """
    if isinstance(self.tokenizer, Tekkenizer):
        # Look the special tokens up once instead of once per token.
        special_tokens = self.tokenizer._all_special_tokens
        return "".join(t for t in tokens if t not in special_tokens)
    else:
        return self.tokenizer.decode(tokens)  # type: ignore[arg-type]

Expand Down
Loading