diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py
index 9e4d3aac51..9101c7bf21 100644
--- a/xinference/core/tests/test_restful_api.py
+++ b/xinference/core/tests/test_restful_api.py
@@ -378,6 +378,52 @@ def test_restful_api_for_embedding(setup):
     assert len(response_data) == 0
 
 
+def _check_invalid_tool_calls(endpoint, model_uid_res):
+    import openai
+
+    client = openai.Client(api_key="not empty", base_url=f"{endpoint}/v1")
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_exchange_rate",
+                "description": "Get the exchange rate between two currencies",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "base_currency": {
+                            "type": "string",
+                            "description": "The currency to convert from",
+                        },
+                        "target_currency": {
+                            "type": "string",
+                            "description": "The currency to convert to",
+                        },
+                    },
+                    "required": ["base_currency", "target_currency"],
+                },
+            },
+        }
+    ]
+
+    completion = client.chat.completions.create(
+        model=model_uid_res,
+        messages=[
+            {
+                "content": "Can you book a flight for me from New York to London?",
+                "role": "user",
+            }
+        ],
+        tools=tools,
+        max_tokens=200,
+        temperature=0.1,
+    )
+    assert "stop" == completion.choices[0].finish_reason
+    assert completion.choices[0].message.content
+    assert len(completion.choices[0].message.tool_calls) == 0
+
+
 @pytest.mark.parametrize(
     "model_format, quantization", [("ggmlv3", "q4_0"), ("pytorch", None)]
 )
@@ -499,6 +545,8 @@ def test_restful_api_for_tool_calls(setup, model_format, quantization):
         arg = json.loads(arguments)
         assert arg == {"symbol": "10111"}
 
+    _check_invalid_tool_calls(endpoint, model_uid_res)
+
 @pytest.mark.parametrize(
     "model_format, quantization", [("ggufv2", "Q4_K_S"), ("pytorch", None)]
 )
@@ -594,6 +642,8 @@ def test_restful_api_for_gorilla_openfunctions_tool_calls(
         arg = json.loads(arguments)
         assert arg == {"loc": 94704, "time": 10, "type": "plus"}
 
+    _check_invalid_tool_calls(endpoint, model_uid_res)
+
 
 @pytest.mark.parametrize(
     "model_format, quantization",
@@ -696,6 +746,8 @@ def test_restful_api_for_qwen_tool_calls(setup, model_format, quantization):
         arg = json.loads(arguments)
         assert arg == {"search_query": "周杰伦"}
 
+    _check_invalid_tool_calls(endpoint, model_uid_res)
+
 
 def test_restful_api_with_request_limits(setup):
     model_name = "gte-base"
diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py
index 9a0a8ccc33..e692ef8891 100644
--- a/xinference/model/llm/utils.py
+++ b/xinference/model/llm/utils.py
@@ -385,7 +385,7 @@ def _eval_gorilla_openfunctions_arguments(c, tools):
         arguments = c["choices"][0]["text"]
 
         def tool_call(n, **kwargs):
-            return n, kwargs
+            return None, n, kwargs
 
         try:
             return eval(
@@ -393,11 +393,13 @@ def tool_call(n, **kwargs):
             )
         except Exception as e:
             logger.error("Eval tool calls completion failed: %s", e)
-            return arguments, arguments
+            return arguments, None, None
 
     @staticmethod
     def _eval_chatglm3_arguments(c, tools):
-        return c[0]["name"], c[0]["parameters"]
+        if isinstance(c[0], str):
+            return c[0], None, None
+        return None, c[0]["name"], c[0]["parameters"]
 
     @staticmethod
     def _eval_qwen_chat_arguments(c, tools):
@@ -416,24 +418,44 @@ def _eval_qwen_chat_arguments(c, tools):
             if 0 <= i < j < k:
                 plugin_name = text[i + len("\nAction:") : j].strip()
                 plugin_args = text[j + len("\nAction Input:") : k].strip()
-                return plugin_name, json.loads(plugin_args)
+                return None, plugin_name, json.loads(plugin_args)
             logger.error("No ReAct response detected, please check your stop.")
         except Exception as e:
             logger.error("Eval tool calls completion failed: %s", e)
-            return text, text
+            return text, None, None
 
 
     @classmethod
     def _tool_calls_completion(cls, model_name, model_uid, c, tools):
         _id = str(uuid.uuid4())
         if model_name == "gorilla-openfunctions-v1":
-            func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
+            content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
         elif model_name == "chatglm3":
-            func, args = cls._eval_chatglm3_arguments(c, tools)
+            content, func, args = cls._eval_chatglm3_arguments(c, tools)
        elif model_name == "qwen-chat":
-            func, args = cls._eval_qwen_chat_arguments(c, tools)
+            content, func, args = cls._eval_qwen_chat_arguments(c, tools)
         else:
             raise Exception(f"Model {model_name} is not support tool calls.")
+        if content:
+            m = {"role": "assistant", "content": content, "tool_calls": []}
+            finish_reason = "stop"
+        else:
+            m = {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "id": f"call_{_id}",
+                        "type": "function",
+                        "function": {
+                            "name": func,
+                            "arguments": json.dumps(args),
+                        },
+                    }
+                ],
+            }
+            finish_reason = "tool_calls"
+
         return {
             "id": "chat" + f"cmpl-{_id}",
             "model": model_uid,
@@ -442,21 +464,8 @@ def _tool_calls_completion(cls, model_name, model_uid, c, tools):
         "choices": [
             {
                 "index": 0,
-                "message": {
-                    "role": "assistant",
-                    "content": None,
-                    "tool_calls": [
-                        {
-                            "id": f"call_{_id}",
-                            "type": "function",
-                            "function": {
-                                "name": func,
-                                "arguments": json.dumps(args),
-                            },
-                        }
-                    ],
-                },
-                "finish_reason": "tool_calls",
+                "message": m,
+                "finish_reason": finish_reason,
             }
         ],
         "usage": {