Skip to content

Commit

Permalink
qwen support tool message
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 committed Dec 21, 2023
1 parent 5cfe4c4 commit 0a06b3a
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 4 deletions.
59 changes: 58 additions & 1 deletion xinference/core/tests/test_restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ def test_restful_api_for_qwen_tool_calls(setup, model_format, quantization):
},
],
"tools": tools,
"max_tokens": 200,
"max_tokens": 2048,
"temperature": 0,
}
response = requests.post(url, json=payload)
Expand All @@ -768,6 +768,63 @@ def test_restful_api_for_qwen_tool_calls(setup, model_format, quantization):
arg = json.loads(arguments)
assert arg == {"search_query": "周杰伦"}

# Check tool message.
payload = {
"model": model_uid_res,
"messages": [
{
"role": "user",
"content": "谁是周杰伦?",
},
completion["choices"][0]["message"],
{
"role": "tool",
"content": "Jay Chou is a Taiwanese singer, songwriter, record producer, rapper, actor, television personality, and businessman.",
},
],
"tools": tools,
"max_tokens": 2048,
"temperature": 0,
}
response = requests.post(url, json=payload)
completion2 = response.json()
assert "stop" == completion2["choices"][0]["finish_reason"]
assert "周杰伦" in completion2["choices"][0]["message"]["content"]
assert "歌手" in completion2["choices"][0]["message"]["content"]

# Check continue tool call.
payload = {
"model": model_uid_res,
"messages": [
{
"role": "user",
"content": "谁是周杰伦?",
},
completion["choices"][0]["message"],
{
"role": "tool",
"content": "Jay Chou is a Taiwanese singer, songwriter, record producer, rapper, actor, television personality, and businessman.",
},
completion2["choices"][0]["message"],
{"role": "user", "content": "画一个他的卡通形象出来"},
],
"tools": tools,
"max_tokens": 2048,
"temperature": 0,
}
response = requests.post(url, json=payload)
completion3 = response.json()
assert "tool_calls" == completion3["choices"][0]["finish_reason"]
assert (
"image_gen"
== completion3["choices"][0]["message"]["tool_calls"][0]["function"]["name"]
)
arguments = completion3["choices"][0]["message"]["tool_calls"][0]["function"][
"arguments"
]
arg = json.loads(arguments)
assert arg == {"prompt": "Jay Chou"}

_check_invalid_tool_calls(endpoint, model_uid_res)


Expand Down
30 changes: 27 additions & 3 deletions xinference/model/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,20 +214,44 @@ def get_prompt(
tools_name_text.append(name)
tools_text_string = "\n\n".join(tools_text)
tools_name_text_string = ", ".join(tools_name_text)
tool_system = "\n\n" + react_instruction.format(
tool_system = react_instruction.format(
tools_text=tools_text_string,
tools_name_text=tools_name_text_string,
)
new_query = tool_system + f"\n\nQuestion: {prompt}"
chat_history[-2]["content"] = new_query
else:
tool_system = ""

ret = f"<|im_start|>system\n{prompt_style.system_prompt}<|im_end|>"
for message in chat_history:
role = message["role"]
content = message["content"]

ret += prompt_style.intra_message_sep
if tools:
if role == "user":
if tool_system:
content = tool_system + f"\n\nQuestion: {content}"
tool_system = ""
else:
content = f"Question: {content}"
elif role == "assistant":
tool_calls = message.get("tool_calls")
if tool_calls:
func_call = tool_calls[0]["function"]
f_name, f_args = (
func_call["name"],
func_call["arguments"],
)
content = f"Thought: I can use {f_name}.\nAction: {f_name}\nAction Input: {f_args}"
elif content:
content = f"Thought: I now know the final answer.\nFinal answer: {content}"
elif role == "tool":
role = "function"
content = f"Observation: {content}"
else:
raise Exception(f"Unsupported message role: {role}")
if content:
content = content.lstrip("\n").rstrip()
ret += f"<|im_start|>{role}\n{content}<|im_end|>"
else:
ret += f"<|im_start|>{role}\n"
Expand Down

0 comments on commit 0a06b3a

Please sign in to comment.