Skip to content

Commit

Permalink
ENH: Handle tool call failed (#767)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Dec 15, 2023
1 parent afaa882 commit 176db9c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 23 deletions.
52 changes: 52 additions & 0 deletions xinference/core/tests/test_restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,52 @@ def test_restful_api_for_embedding(setup):
assert len(response_data) == 0


def _check_invalid_tool_calls(endpoint, model_uid_res):
    """Assert that a prompt unrelated to the offered tool gets a plain reply.

    Sends one chat completion request through the OpenAI client to
    ``{endpoint}/v1`` for the model ``model_uid_res``. The only tool offered
    is ``get_exchange_rate``, while the user asks to book a flight, so the
    response is expected to finish with ``"stop"``, carry non-empty text
    content and contain an empty ``tool_calls`` list.
    """
    import openai

    # The single tool exposed to the model; it has nothing to do with flights.
    exchange_rate_tool = {
        "type": "function",
        "function": {
            "name": "get_exchange_rate",
            "description": "Get the exchange rate between two currencies",
            "parameters": {
                "type": "object",
                "properties": {
                    "base_currency": {
                        "type": "string",
                        "description": "The currency to convert from",
                    },
                    "target_currency": {
                        "type": "string",
                        "description": "The currency to convert to",
                    },
                },
                "required": ["base_currency", "target_currency"],
            },
        },
    }
    user_message = {
        "content": "Can you book a flight for me from New York to London?",
        "role": "user",
    }

    client = openai.Client(api_key="not empty", base_url=f"{endpoint}/v1")
    completion = client.chat.completions.create(
        model=model_uid_res,
        messages=[user_message],
        tools=[exchange_rate_tool],
        max_tokens=200,
        temperature=0.1,
    )

    # Only the first choice is inspected.
    first_choice = completion.choices[0]
    assert "stop" == first_choice.finish_reason
    assert first_choice.message.content
    assert len(first_choice.message.tool_calls) == 0


@pytest.mark.parametrize(
"model_format, quantization", [("ggmlv3", "q4_0"), ("pytorch", None)]
)
Expand Down Expand Up @@ -499,6 +545,8 @@ def test_restful_api_for_tool_calls(setup, model_format, quantization):
arg = json.loads(arguments)
assert arg == {"symbol": "10111"}

_check_invalid_tool_calls(endpoint, model_uid_res)


@pytest.mark.parametrize(
"model_format, quantization", [("ggufv2", "Q4_K_S"), ("pytorch", None)]
Expand Down Expand Up @@ -594,6 +642,8 @@ def test_restful_api_for_gorilla_openfunctions_tool_calls(
arg = json.loads(arguments)
assert arg == {"loc": 94704, "time": 10, "type": "plus"}

_check_invalid_tool_calls(endpoint, model_uid_res)


@pytest.mark.parametrize(
"model_format, quantization",
Expand Down Expand Up @@ -696,6 +746,8 @@ def test_restful_api_for_qwen_tool_calls(setup, model_format, quantization):
arg = json.loads(arguments)
assert arg == {"search_query": "周杰伦"}

_check_invalid_tool_calls(endpoint, model_uid_res)


def test_restful_api_with_request_limits(setup):
model_name = "gte-base"
Expand Down
55 changes: 32 additions & 23 deletions xinference/model/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,19 +385,21 @@ def _eval_gorilla_openfunctions_arguments(c, tools):
arguments = c["choices"][0]["text"]

def tool_call(n, **kwargs):
return n, kwargs
return None, n, kwargs

try:
return eval(
arguments, {n: functools.partial(tool_call, n) for n in tool_names}
)
except Exception as e:
logger.error("Eval tool calls completion failed: %s", e)
return arguments, arguments
return arguments, None, None

@staticmethod
def _eval_chatglm3_arguments(c, tools):
    """Turn a chatglm3 completion into a ``(content, func, args)`` triple.

    Only the first element of ``c`` is inspected:

    * if it is a string, no tool call was produced, so the text is
      returned as ``content`` and ``func``/``args`` are ``None``;
    * otherwise it is indexed by ``"name"`` and ``"parameters"``, which are
      returned as ``func`` and ``args`` with ``content`` set to ``None``.

    ``tools`` is accepted for signature parity with the other evaluators
    and is not used here.
    """
    first = c[0]
    # The string check must come first: indexing a str with "name" would
    # raise TypeError instead of reporting the plain-text answer.
    if isinstance(first, str):
        return first, None, None
    return None, first["name"], first["parameters"]

@staticmethod
def _eval_qwen_chat_arguments(c, tools):
Expand All @@ -416,24 +418,44 @@ def _eval_qwen_chat_arguments(c, tools):
if 0 <= i < j < k:
plugin_name = text[i + len("\nAction:") : j].strip()
plugin_args = text[j + len("\nAction Input:") : k].strip()
return plugin_name, json.loads(plugin_args)
return None, plugin_name, json.loads(plugin_args)
logger.error("No ReAct response detected, please check your stop.")
except Exception as e:
logger.error("Eval tool calls completion failed: %s", e)
return text, text
return text, None, None

@classmethod
def _tool_calls_completion(cls, model_name, model_uid, c, tools):
_id = str(uuid.uuid4())
if model_name == "gorilla-openfunctions-v1":
func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
elif model_name == "chatglm3":
func, args = cls._eval_chatglm3_arguments(c, tools)
content, func, args = cls._eval_chatglm3_arguments(c, tools)
elif model_name == "qwen-chat":
func, args = cls._eval_qwen_chat_arguments(c, tools)
content, func, args = cls._eval_qwen_chat_arguments(c, tools)
else:
raise Exception(f"Model {model_name} is not support tool calls.")

if content:
m = {"role": "assistant", "content": content, "tool_calls": []}
finish_reason = "stop"
else:
m = {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": f"call_{_id}",
"type": "function",
"function": {
"name": func,
"arguments": json.dumps(args),
},
}
],
}
finish_reason = "tool_calls"

return {
"id": "chat" + f"cmpl-{_id}",
"model": model_uid,
Expand All @@ -442,21 +464,8 @@ def _tool_calls_completion(cls, model_name, model_uid, c, tools):
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": f"call_{_id}",
"type": "function",
"function": {
"name": func,
"arguments": json.dumps(args),
},
}
],
},
"finish_reason": "tool_calls",
"message": m,
"finish_reason": finish_reason,
}
],
"usage": {
Expand Down

0 comments on commit 176db9c

Please sign in to comment.