From e883333de9ce9991ca81a029cee8a1269a6a88f5 Mon Sep 17 00:00:00 2001
From: Zhicheng Zhang
Date: Wed, 17 Jul 2024 11:27:25 +0800
Subject: [PATCH] Feat/parallel tool calls (#532)

---
 apps/agentfabric/requirements.txt          |   7 +-
 modelscope_agent/agent.py                  |  29 ++--
 modelscope_agent/agents/role_play.py       | 153 ++++++------------
 .../utils/function_call_with_raw_prompt.py | 153 ++++++++++++++++++
 modelscope_agent_servers/README.md         | 135 ++++++++++++++++
 .../assistant_server/api.py                |  11 +-
 .../assistant_server/models.py             |   3 +
 .../assistant_server/utils.py              |  41 ++---
 tests/test_agent.py                        |   8 +-
 9 files changed, 380 insertions(+), 160 deletions(-)
 create mode 100644 modelscope_agent/llm/utils/function_call_with_raw_prompt.py

diff --git a/apps/agentfabric/requirements.txt b/apps/agentfabric/requirements.txt
index dad23eaef..026adff23 100644
--- a/apps/agentfabric/requirements.txt
+++ b/apps/agentfabric/requirements.txt
@@ -1,11 +1,8 @@
-dashscope
-faiss-cpu
 gradio==4.36.1
-langchain
 markdown-cjk-spacing
 mdx_truly_sane_lists
-modelscope-agent==0.6.4
-modelscope_studio
+modelscope-agent>=0.6.4
+modelscope_studio>=0.4.0
 pymdown-extensions
 python-slugify
 unstructured
diff --git a/modelscope_agent/agent.py b/modelscope_agent/agent.py
index 0e18f24c9..c86cf325e 100644
--- a/modelscope_agent/agent.py
+++ b/modelscope_agent/agent.py
@@ -121,11 +121,14 @@ def _call_llm(self,
             stream=self.stream,
             **kwargs)
 
-    def _call_tool(self, tool_name: str, tool_args: str, **kwargs):
+    def _call_tool(self, tool_list: list, **kwargs):
         """
         Use when calling tools in bot()
 
         """
+        # for compatibility with version < 0.6.6, only the first tool in tool_list is executed
+        tool_name = tool_list[0]['name']
+        tool_args = tool_list[0]['arguments']
         self.callback_manager.on_tool_start(tool_name, tool_args)
         try:
             result = self.function_map[tool_name].call(tool_args, **kwargs)
@@ -213,7 +216,7 @@ def _register_tool(self,
             tool_class_with_tenant[tenant_id] = self.function_map[tool_name]
 
     def _detect_tool(self, message: Union[str,
-                                          dict]) -> Tuple[bool, str, str, str]:
+                                          dict]) -> Tuple[bool, list, str]:
         """
         A built-in tool call detection for func_call format
 
@@ -225,26 +228,26 @@ def _detect_tool(self, message: Union[str,
 
         Returns:
             - bool: need to call tool or not
-            - str: tool name
-            - str: tool args
+            - list: list of detected tool calls
             - str: text replies except for tool calls
         """
-        func_name = None
-        func_args = None
+
+        func_calls = []
         assert isinstance(message, dict)
+        # deprecated: the legacy single `function_call` field
         if 'function_call' in message and message['function_call']:
             func_call = message['function_call']
-            func_name = func_call.get('name', '')
-            func_args = func_call.get('arguments', '')
-        # Compatible with OpenAI API
+            func_calls.append(func_call)
+
+        # Follow the OpenAI API and allow multiple func_calls
         if 'tool_calls' in message and message['tool_calls']:
-            func_call = message['tool_calls'][0]['function']
-            func_name = func_call.get('name', '')
-            func_args = func_call.get('arguments', '')
+            for item in message['tool_calls']:
+                func_call = item['function']
+                func_calls.append(func_call)
         text = message.get('content', '')
 
-        return (func_name is not None), func_name, func_args, text
+        return (len(func_calls) > 0), func_calls, text
 
     def _parse_image_url(self, image_url: List[Union[str, Dict]],
                          messages: List[Dict]) -> List[Dict]:
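Illustration (not part of the patch): with this change, `_detect_tool` folds both the legacy `function_call` field and OpenAI-style `tool_calls` into a single list. A minimal sketch of the new return shape, reusing the `amap_weather` tool from the README examples further down:

```Python
# Hypothetical assistant message in the OpenAI tool_calls format.
message = {
    'role': 'assistant',
    'content': '',
    'tool_calls': [{
        'type': 'function',
        'function': {
            'name': 'amap_weather',
            'arguments': '{"location": "北京"}'
        }
    }, {
        'type': 'function',
        'function': {
            'name': 'amap_weather',
            'arguments': '{"location": "上海"}'
        }
    }]
}

# Inside an Agent subclass: use_tool, func_calls, text = self._detect_tool(message)
# use_tool   -> True
# func_calls -> [{'name': 'amap_weather', 'arguments': '{"location": "北京"}'},
#                {'name': 'amap_weather', 'arguments': '{"location": "上海"}'}]
# text       -> ''
```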
diff --git a/modelscope_agent/agents/role_play.py b/modelscope_agent/agents/role_play.py
index 5a13da4ca..4437d5b64 100644
--- a/modelscope_agent/agents/role_play.py
+++ b/modelscope_agent/agents/role_play.py
@@ -5,6 +5,10 @@
 from modelscope_agent import Agent
 from modelscope_agent.agent_env_util import AgentEnvMixin
 from modelscope_agent.llm.base import BaseChatModel
+from modelscope_agent.llm.utils.function_call_with_raw_prompt import (
+    DEFAULT_EXEC_TEMPLATE, SPECIAL_PREFIX_TEMPLATE_TOOL,
+    SPECIAL_PREFIX_TEMPLATE_TOOL_FOR_CHAT, TOOL_TEMPLATE,
+    convert_tools_to_prompt, detect_multi_tool)
 from modelscope_agent.tools.base import BaseTool
 from modelscope_agent.utils.base64_utils import encode_files_to_base64
 from modelscope_agent.utils.logger import agent_logger as logger
@@ -19,23 +23,6 @@
 """
 
-TOOL_TEMPLATE_ZH = """
-# 工具
-
-## 你拥有如下工具:
-
-{tool_descs}
-
-## 当你需要调用工具时,请在你的回复中穿插如下的工具调用命令,可以根据需求调用零次或多次:
-
-工具调用
-Action: 工具的名称,必须是[{tool_names}]之一
-Action Input: 工具的输入
-Observation: 工具返回的结果
-Answer: 根据Observation总结本次工具调用返回的结果,如果结果中出现url,请使用如下格式展示出来:![图片](url)
-
-"""
-
 PROMPT_TEMPLATE_ZH = """
 # 指令
 
@@ -52,24 +39,6 @@
 """
 
-TOOL_TEMPLATE_EN = """
-# Tools
-
-## You have the following tools:
-
-{tool_descs}
-
-## When you need to call a tool, please intersperse the following tool command in your reply. %s
-
-Tool Invocation
-Action: The name of the tool, must be one of [{tool_names}]
-Action Input: Tool input
-Observation: Tool returns result
-Answer: Summarize the results of this tool call based on Observation. If the result contains url, %s
-
-""" % ('You can call zero or more times according to your needs:',
-       'please display it in the following format:![Image](URL)')
-
 PROMPT_TEMPLATE_EN = """
 #Instructions
 
@@ -80,11 +49,6 @@
 
 KNOWLEDGE_TEMPLATE = {'zh': KNOWLEDGE_TEMPLATE_ZH, 'en': KNOWLEDGE_TEMPLATE_EN}
 
-TOOL_TEMPLATE = {
-    'zh': TOOL_TEMPLATE_ZH,
-    'en': TOOL_TEMPLATE_EN,
-}
-
 PROMPT_TEMPLATE = {
     'zh': PROMPT_TEMPLATE_ZH,
     'en': PROMPT_TEMPLATE_EN,
 }
@@ -105,16 +69,6 @@
     'en': 'You are playing as {role_name}',
 }
 
-SPECIAL_PREFIX_TEMPLATE_TOOL = {
-    'zh': '。你可以使用工具:[{tool_names}]',
-    'en': '. you can use tools: [{tool_names}]',
-}
-
-SPECIAL_PREFIX_TEMPLATE_TOOL_FOR_CHAT = {
-    'zh': '。你必须使用工具中的一个或多个:[{tool_names}]',
-    'en': '. you must use one or more tools: [{tool_names}]',
-}
-
 SPECIAL_PREFIX_TEMPLATE_KNOWLEDGE = {
     'zh': '。请查看前面的知识库',
     'en': '. Please read the knowledge base at the beginning',
 }
@@ -125,13 +79,6 @@
     'zh': '[上传文件“{file_names}”]',
     'en': '[Upload file "{file_names}"]',
 }
 
-DEFAULT_EXEC_TEMPLATE = """\nObservation: {exec_result}\nAnswer:"""
-
-ACTION_TOKEN = 'Action:'
-ARGS_TOKEN = 'Action Input:'
-OBSERVATION_TOKEN = 'Observation:'
-ANSWER_TOKEN = 'Answer:'
-
 
 class RolePlay(Agent, AgentEnvMixin):
 
@@ -147,6 +94,33 @@ def __init__(self,
                          description, instruction, **kwargs)
         AgentEnvMixin.__init__(self, **kwargs)
 
+    def _prepare_tool_system(self,
+                             tools: Optional[List] = None,
+                             parallel_tool_calls: bool = False,
+                             lang='zh'):
+        # prepare tool descriptions and names, picking the parallel variant
+        # of the tool template when parallel_tool_calls is enabled
+        tool_desc_template = TOOL_TEMPLATE[
+            lang + ('_parallel' if parallel_tool_calls else '')]
+
+        if tools is not None:
+            tool_descs = BaseTool.parser_function(tools)
+            tool_name_list = []
+            for tool in tools:
+                func_info = tool.get('function', {})
+                if func_info == {}:
+                    continue
+                if 'name' in func_info:
+                    tool_name_list.append(func_info['name'])
+            tool_names = ','.join(tool_name_list)
+        else:
+            tool_descs = '\n\n'.join(tool.function_plain_text
+                                     for tool in self.function_map.values())
+            tool_names = ','.join(tool.name
+                                  for tool in self.function_map.values())
+        tool_system = tool_desc_template.format(
+            tool_descs=tool_descs, tool_names=tool_names)
+        return tool_names, tool_system
+
     def _run(self,
              user_request,
              history: Optional[List[Dict]] = None,
@@ -158,24 +132,11 @@ def _run(self,
         chat_mode = kwargs.pop('chat_mode', False)
         tools = kwargs.get('tools', None)
         tool_choice = kwargs.get('tool_choice', 'auto')
+        parallel_tool_calls = kwargs.get('parallel_tool_calls',
+                                         True if chat_mode else False)
 
-        if tools is not None:
-            self.tool_descs = BaseTool.parser_function(tools)
-            tool_name_list = []
-            for tool in tools:
-                func_info = tool.get('function', {})
-                if func_info == {}:
-                    continue
-                if 'name' in func_info:
-                    tool_name_list.append(func_info['name'])
-            self.tool_names = ','.join(tool_name_list)
-        else:
-            self.tool_descs = '\n\n'.join(
-                tool.function_plain_text
-                for tool in self.function_map.values())
-            self.tool_names = ','.join(tool.name
-                                       for tool in self.function_map.values())
-
+        tool_names, tool_system = self._prepare_tool_system(
+            tools, parallel_tool_calls, lang)
         self.system_prompt = ''
         self.query_prefix = ''
         self.query_prefix_dict = {'role': '', 'tool': '', 'knowledge': ''}
@@ -197,11 +158,10 @@ def _run(self,
                 'knowledge'] = SPECIAL_PREFIX_TEMPLATE_KNOWLEDGE[lang]
 
         # concat tools information
-        if self.tool_descs and not self.llm.support_function_calling():
-            self.system_prompt += TOOL_TEMPLATE[lang].format(
-                tool_descs=self.tool_descs, tool_names=self.tool_names)
+        if tool_system and not self.llm.support_function_calling():
+            self.system_prompt += tool_system
             self.query_prefix_dict['tool'] = SPECIAL_PREFIX_TEMPLATE_TOOL[
-                lang].format(tool_names=self.tool_names)
+                lang].format(tool_names=tool_names)
 
         # concat instruction
         if isinstance(self.instruction, dict):
@@ -242,7 +202,7 @@ def _run(self,
         # concat the new messages
         if chat_mode and tool_choice == 'required':
             required_prefix = SPECIAL_PREFIX_TEMPLATE_TOOL_FOR_CHAT[
-                lang].format(tool_names=self.tool_names)
+                lang].format(tool_names=tool_names)
             messages.append({
                 'role': 'user',
                 'content': required_prefix + user_request
@@ -295,25 +255,25 @@ def _run(self,
                 llm_result += s
                 yield s
 
+        use_tool = False
+        tool_list = []
         if isinstance(llm_result, str):
-            use_tool, action, action_input, output = self._detect_tool(
-                llm_result)
+            use_tool, tool_list, output = detect_multi_tool(llm_result)
         elif isinstance(llm_result, dict):
-            use_tool, action, action_input, output = super()._detect_tool(
-                llm_result)
+            use_tool, tool_list, output = super()._detect_tool(llm_result)
         else:
             assert 'llm_result must be an instance of dict or str'
 
         if chat_mode:
             if use_tool and tool_choice != 'none':
-                return f'Action: {action}\nAction Input: {action_input}\nResult: {output}'
+                return convert_tools_to_prompt(tool_list)
             else:
                 return f'Result: {output}'
 
         # yield output
         if use_tool:
             if self.llm.support_function_calling():
-                yield f'Action: {action}\nAction Input: {action_input}'
+                yield convert_tools_to_prompt(tool_list)
 
             if self.use_tool_api:
                 # convert all files with base64, for the tool instance usage in case.
@@ -321,7 +281,8 @@ def _run(self,
                 kwargs['base64_files'] = encoded_files
                 kwargs['use_tool_api'] = True
 
-            observation = self._call_tool(action, action_input, **kwargs)
+            # currently only one observation is executed, even for parallel tool calls
+            observation = self._call_tool(tool_list, **kwargs)
             format_observation = DEFAULT_EXEC_TEMPLATE.format(
                 exec_result=observation)
             yield format_observation
@@ -360,26 +321,6 @@ def _limit_observation_length(self, observation):
         limited_observation = str(observation)[:int(reasonable_length)]
         return DEFAULT_EXEC_TEMPLATE.format(exec_result=limited_observation)
 
-    def _detect_tool(self, message: Union[str,
-                                          dict]) -> Tuple[bool, str, str, str]:
-        assert isinstance(message, str)
-        text = message
-        func_name, func_args = None, None
-        i = text.rfind(ACTION_TOKEN)
-        j = text.rfind(ARGS_TOKEN)
-        k = text.rfind(OBSERVATION_TOKEN)
-        if 0 <= i < j:  # If the text has `Action` and `Action input`,
-            if k < j:  # but does not contain `Observation`,
-                # then it is likely that `Observation` is ommited by the LLM,
-                # because the output text may have discarded the stop word.
-                text = text.rstrip() + OBSERVATION_TOKEN  # Add it back.
-                k = text.rfind(OBSERVATION_TOKEN)
-            func_name = text[i + len(ACTION_TOKEN):j].strip()
-            func_args = text[j + len(ARGS_TOKEN):k].strip()
-            text = text[:k]  # Discard '\nObservation:'.
-
-        return (func_name is not None), func_name, func_args, text
-
     def _parse_role_config(self, config: dict, lang: str = 'zh') -> str:
         """
         Parsing role config dict to str.
diff --git a/modelscope_agent/llm/utils/function_call_with_raw_prompt.py b/modelscope_agent/llm/utils/function_call_with_raw_prompt.py
new file mode 100644
index 000000000..d06c46b5d
--- /dev/null
+++ b/modelscope_agent/llm/utils/function_call_with_raw_prompt.py
@@ -0,0 +1,153 @@
+import re
+from typing import List, Tuple, Union
+
+DEFAULT_EXEC_TEMPLATE = """\nObservation: {exec_result}\nAnswer:"""
+
+ACTION_TOKEN = 'Action:'
+ARGS_TOKEN = 'Action Input:'
+OBSERVATION_TOKEN = 'Observation:'
+ANSWER_TOKEN = 'Answer:'
+
+TOOL_TEMPLATE_ZH = """
+# 工具
+
+## 你拥有如下工具:
+
+{tool_descs}
+
+## 当你需要调用工具时,请在你的回复中穿插如下的工具调用命令,可以根据需求调用零次或多次:
+
+工具调用
+Action: 工具的名称,必须是[{tool_names}]之一
+Action Input: 工具的输入
+Observation: 工具返回的结果
+Answer: 根据Observation总结本次工具调用返回的结果,如果结果中出现url,请使用如下格式展示出来:![图片](url)
+
+"""
+
+TOOL_TEMPLATE_ZH_PARALLEL = """
+# 工具
+
+## 你拥有如下工具:
+
+{tool_descs}
+
+## 当你需要调用工具时,请在你的回复中穿插如下的工具调用命令,可以根据需求调用零次或多次:
+
+工具调用
+Action: 工具1的名称,必须是[{tool_names}]之一
+Action Input: 工具1的输入
+Action: 工具2的名称,必须是[{tool_names}]之一
+Action Input: 工具2的输入
+...
+Action: 工具N的名称,必须是[{tool_names}]之一
+Action Input: 工具N的输入
+Observation: 工具1返回的结果
+Observation: 工具2返回的结果
+...
+Observation: 工具N返回的结果
+
+Answer: 根据Observation总结本次工具调用返回的结果,如果结果中出现url,请使用如下格式展示出来:![图片](url)
+
+"""
+
+TOOL_TEMPLATE_EN = """
+# Tools
+
+## You have the following tools:
+
+{tool_descs}
+
+## When you need to call a tool, please intersperse the following tool command in your reply. %s
+
+Tool Invocation
+Action: The name of the tool, must be one of [{tool_names}]
+Action Input: Tool input
+Observation: Tool returns result
+Answer: Summarize the results of this tool call based on Observation. If the result contains url, %s
+
+""" % ('You can call zero or more times according to your needs:',
+       'please display it in the following format:![Image](URL)')
+
+TOOL_TEMPLATE_EN_PARALLEL = """
+# Tools
+
+## You have the following tools:
+
+{tool_descs}
+
+## When you need to call a tool, please intersperse the following tool command in your reply. %s
+
+Tool Invocation
+Action: The name of tool 1, must be one of [{tool_names}]
+Action Input: Tool input of tool 1
+Action: The name of tool 2, must be one of [{tool_names}]
+Action Input: Tool input of tool 2
+...
+Action: The name of tool N, must be one of [{tool_names}]
+Action Input: Tool input of tool N
+Observation: Tool 1 returns result
+Observation: Tool 2 returns result
+...
+Observation: Tool N returns result
+Answer: Summarize the results of this tool call based on Observation. If the result contains url, %s
+
+""" % ('You can call zero or more times according to your needs:',
+       'please display it in the following format:![Image](URL)')
+
+TOOL_TEMPLATE = {
+    'zh': TOOL_TEMPLATE_ZH,
+    'en': TOOL_TEMPLATE_EN,
+    'zh_parallel': TOOL_TEMPLATE_ZH_PARALLEL,
+    'en_parallel': TOOL_TEMPLATE_EN_PARALLEL,
+}
+
+SPECIAL_PREFIX_TEMPLATE_TOOL = {
+    'zh': '。你可以使用工具:[{tool_names}]',
+    'en': '. you can use tools: [{tool_names}]',
+}
+
+SPECIAL_PREFIX_TEMPLATE_TOOL_FOR_CHAT = {
+    'zh': '。你必须使用工具中的一个或多个:[{tool_names}]',
+    'en': '. you must use one or more tools: [{tool_names}]',
+}
+
+
+def detect_multi_tool(message: Union[str, dict]) -> Tuple[bool, list, str]:
+    """
+    Parse text like 'Action: xxx\nAction Input: yyy\n\nAction: ppp\nAction Input: qqq'
+    into a list of tool calls:
+    [{'name': 'xxx', 'arguments': 'yyy'}, {'name': 'ppp', 'arguments': 'qqq'}]
+    Args:
+        message: str message only for now
+
+    Returns:
+        whether the message contains tool calls, the list of detected tool
+        calls (each with name and arguments), and the text of the message
+    """
+
+    assert isinstance(message, str)
+    text = message
+    # find every `Action:` / `Action Input:` pair in the message
+    match_result = re.findall(r'Action: (.+)\nAction Input: (.+)', text)
+
+    tools = []
+    for item in match_result:
+        func_name, func_args = item
+        tool_info = {'name': func_name, 'arguments': func_args}
+        tools.append(tool_info)
+
+    return (len(tools) > 0), tools, text
+
+
+def convert_tools_to_prompt(tool_list: List) -> str:
+    """
+    Convert a tool list back into text like
+    'Action: xxx\nAction Input: yyyy\n\nAction: ppp\nAction Input: qqq'
+    Args:
+        tool_list: list of tools
+
+    Returns:
+        string representation of the tool calls
+    """
+    return '\n\n'.join([
+        f'Action: {tool["name"]}\nAction Input: {tool["arguments"]}'
+        for tool in tool_list
+    ])
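Illustration (not part of the patch): the two helpers above are designed to round-trip. A minimal sketch, with the `amap_weather` tool name borrowed from the README examples:

```Python
from modelscope_agent.llm.utils.function_call_with_raw_prompt import (
    convert_tools_to_prompt, detect_multi_tool)

raw = ('Action: amap_weather\nAction Input: {"location": "北京"}\n'
       'Action: amap_weather\nAction Input: {"location": "上海"}\n')

has_tools, tool_list, text = detect_multi_tool(raw)
# has_tools -> True
# tool_list -> [{'name': 'amap_weather', 'arguments': '{"location": "北京"}'},
#               {'name': 'amap_weather', 'arguments': '{"location": "上海"}'}]

# Convert the parsed calls back into the raw-prompt form:
print(convert_tools_to_prompt(tool_list))
# Action: amap_weather
# Action Input: {"location": "北京"}
#
# Action: amap_weather
# Action Input: {"location": "上海"}
```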
diff --git a/modelscope_agent_servers/README.md b/modelscope_agent_servers/README.md
index 4af4ebb92..536317aa0 100644
--- a/modelscope_agent_servers/README.md
+++ b/modelscope_agent_servers/README.md
@@ -145,6 +145,141 @@ With above examples, the output should be like this:
   "object":"chat.completion",
   "usage":{"prompt_tokens":267,"completion_tokens":15,"total_tokens":282}}
 ```
+We also support `parallel_tool_calls`, which is enabled by default. In a parallel tool calling scenario, the request and output look like this:
+
+```Shell
+curl -X POST 'http://localhost:31512/v1/chat/completions' \
+-H 'Content-Type: application/json' \
+-H "Authorization: Bearer $DASHSCOPE_API_KEY" \
+-d '{
+    "tools": [{
+        "type": "function",
+        "function": {
+            "name": "amap_weather",
+            "description": "amap weather tool",
+            "parameters": [{
+                "name": "location",
+                "type": "string",
+                "description": "城市/区具体名称,如`北京市海淀区`请描述为`海淀区`",
+                "required": true
+            }]
+        }
+    }],
+    "tool_choice": "auto",
+    "model": "qwen-max",
+    "messages": [
+        {"content": "请同时调用工具查找北京和上海的天气", "role": "user"}
+    ]
+}'
+
+```
+
+With the above request, the output should look like this:
+```Python
+{
+  "request_id": "chatcmpl_058fd645-4a7a-41ca-a1db-29c9330814d6",
+  "message": "",
+  "output": null,
+  "id": "chatcmpl_058fd645-4a7a-41ca-a1db-29c9330814d6",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "Action: amap_weather\nAction Input: {\"location\": \"北京\"}\nAction: amap_weather\nAction Input: {\"location\": \"上海\"}\n\n",
+        "tool_calls": [
+          {
+            "type": "function",
+            "function": {
+              "name": "amap_weather",
+              "arguments": "{\"location\": \"北京\"}"
+            }
+          },
+          {
+            "type": "function",
+            "function": {
+              "name": "amap_weather",
+              "arguments": "{\"location\": \"上海\"}"
+            }
+          }
+        ]
+      },
+      "finish_reason": "tool_calls"
+    }
+  ],
+  "created": 1721123488,
+  "model": "Qwen2-7B-Instruct",
+  "system_fingerprint": "chatcmpl_058fd645-4a7a-41ca-a1db-29c9330814d6",
+  "object": "chat.completion",
+  "usage": {
+    "prompt_tokens": 333,
+    "completion_tokens": 33,
+    "total_tokens": 366
+  }
+}
+```
+If you set `parallel_tool_calls` to false instead, you get at most one tool call back:
+
+```shell
+curl -v -X POST 'http://localhost:31512/v1/chat/completions' -H 'Content-Type: application/json' -d '{
+    "tools": [{
+        "type": "function",
+        "function": {
+            "name": "amap_weather",
+            "description": "amap weather tool",
+            "parameters": [{
+                "name": "location",
+                "type": "string",
+                "description": "城市/区具体名称,如`北京市海淀区`请描述为`海淀区`",
+                "required": true
+            }]
+        }
+    }],
+    "model": "Qwen2-7B-Instruct",
+    "messages": [
+        {"content": "请同时调用工具查找北京和上海的天气", "role": "user"}
+    ],
+    "parallel_tool_calls": false
+}'
+```
+As the result shows, only one tool call is generated.
+
+```json
+{
+  "request_id": "chatcmpl_57255f07-3d86-4b64-82e8-2a99d5f763cb",
+  "message": "",
+  "output": null,
+  "id": "chatcmpl_57255f07-3d86-4b64-82e8-2a99d5f763cb",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "Action: amap_weather\nAction Input: {\"location\": \"北京\"}\n",
+        "tool_calls": [
+          {
+            "type": "function",
+            "function": {
+              "name": "amap_weather",
+              "arguments": "{\"location\": \"北京\"}"
+            }
+          }
+        ]
+      },
+      "finish_reason": "tool_calls"
+    }
+  ],
+  "created": 1721127977,
+  "model": "Qwen2-7B-Instruct",
+  "system_fingerprint": "chatcmpl_57255f07-3d86-4b64-82e8-2a99d5f763cb",
+  "object": "chat.completion",
+  "usage": {
+    "prompt_tokens": 246,
+    "completion_tokens": 18,
+    "total_tokens": 264
+  }
+}
+```
+
 
 #### Chat with vllm
 
diff --git a/modelscope_agent_servers/assistant_server/api.py b/modelscope_agent_servers/assistant_server/api.py
index 98aac08d2..6d4b04ec8 100644
--- a/modelscope_agent_servers/assistant_server/api.py
+++ b/modelscope_agent_servers/assistant_server/api.py
@@ -5,11 +5,13 @@
 from fastapi import FastAPI, File, Form, Header, UploadFile
 from fastapi.responses import StreamingResponse
 from modelscope_agent.agents.role_play import RolePlay
+from modelscope_agent.llm.utils.function_call_with_raw_prompt import \
+    detect_multi_tool
 from modelscope_agent.rag.knowledge import BaseKnowledge
 from modelscope_agent_servers.assistant_server.models import (
     AgentRequest, ChatCompletionRequest, ChatCompletionResponse, ToolResponse)
 from modelscope_agent_servers.assistant_server.utils import (
-    choice_wrapper, parse_messages, parse_tool_result, stream_choice_wrapper)
+    choice_wrapper, parse_messages, stream_choice_wrapper)
 from modelscope_agent_servers.service_utils import (create_error_msg,
                                                     create_success_msg)
 
@@ -140,8 +142,10 @@ async def chat_completion(chat_request: ChatCompletionRequest,
     # tool related config
     tools = chat_request.tools
     tool_choice = None
+    parallel_tool_calls = True
     if tools:
         tool_choice = chat_request.tool_choice
+        parallel_tool_calls = chat_request.parallel_tool_calls
 
     # parse message
     query, history, image_url = parse_messages(chat_request.messages)
@@ -155,6 +159,7 @@ async def chat_completion(chat_request: ChatCompletionRequest,
         tools=tools,
         tool_choice=tool_choice,
         chat_mode=True,
+        parallel_tool_calls=parallel_tool_calls,
         # **kwargs)
     )
 
@@ -173,8 +178,8 @@ async def chat_completion(chat_request: ChatCompletionRequest,
 
     del agent
 
-    action, action_input = parse_tool_result(llm_result)
-    choices = choice_wrapper(llm_result, action, action_input)
+    has_action, tool_list, _ = detect_multi_tool(llm_result)
+    choices = choice_wrapper(llm_result, tool_list)
 
     chat_response = ChatCompletionResponse(
         choices=choices,
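Illustration (not part of the patch): the `parallel_tool_calls` flag threaded through `chat_completion` above comes straight from the request model added below, so the second curl example can equivalently be sent from Python. A sketch assuming the server from the README is running locally and `requests` is installed:

```Python
import requests

payload = {
    'model': 'Qwen2-7B-Instruct',
    'messages': [{
        'role': 'user',
        'content': '请同时调用工具查找北京和上海的天气'
    }],
    'tools': [{
        'type': 'function',
        'function': {
            'name': 'amap_weather',
            'description': 'amap weather tool',
            'parameters': [{
                'name': 'location',
                'type': 'string',
                'description': '城市/区具体名称,如`北京市海淀区`请描述为`海淀区`',
                'required': True
            }]
        }
    }],
    # Omit this field to keep the default (True); set False to force a single tool call.
    'parallel_tool_calls': False,
}
resp = requests.post(
    'http://localhost:31512/v1/chat/completions', json=payload)
print(resp.json()['choices'][0]['message'].get('tool_calls', []))
```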
diff --git a/modelscope_agent_servers/assistant_server/models.py b/modelscope_agent_servers/assistant_server/models.py
index 960a6659b..2da8bb719 100644
--- a/modelscope_agent_servers/assistant_server/models.py
+++ b/modelscope_agent_servers/assistant_server/models.py
@@ -85,6 +85,9 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     tools: Optional[List[Dict]] = Field(None, title='Tools config')
     tool_choice: Optional[str] = Field('auto', title='tool usage choice')
+    parallel_tool_calls: Optional[bool] = Field(
+        True,
+        title='Whether to enable parallel function calling during tool use.')
     stream: Optional[bool] = Field(False, title='Stream output')
     user: str = Field('default_user', title='User name')
 
diff --git a/modelscope_agent_servers/assistant_server/utils.py b/modelscope_agent_servers/assistant_server/utils.py
index 54d6b99ef..9573ab711 100644
--- a/modelscope_agent_servers/assistant_server/utils.py
+++ b/modelscope_agent_servers/assistant_server/utils.py
@@ -1,3 +1,4 @@
+import re
 from typing import List
 
 import json
@@ -6,25 +7,6 @@
                      ChatCompletionResponseStreamChoice, ChatMessage,
                      DeltaMessage)
 
-
-def parse_tool_result(llm_result: str):
-    """
-    Args:
-        llm_result: the result from the model
-
-    Returns: dict
-
-    """
-    try:
-        import re
-        import json
-        result = re.search(r'Action: (.+)\nAction Input: (.+)', llm_result)
-        action = result.group(1)
-        action_input = json.loads(result.group(2))
-        return action, action_input
-    except Exception:
-        return None, None
-
-
 def parse_messages(messages: List[ChatMessage]):
     """
     Args:
@@ -75,9 +57,7 @@ def stream_choice_wrapper(response, model, request_id, llm):
     yield 'data: [DONE]\n\n'
 
 
-def choice_wrapper(response: str,
-                   tool_name: str = None,
-                   tool_inputs: dict = None):
+def choice_wrapper(response: str, tool_list: list = []):
     """
     output should be in the format of openai choices
     "choices": [
@@ -102,7 +82,8 @@ def choice_wrapper(response: str,
     ],
 
     Args:
-        response: the chatresponse object
+        response: the chat response object
+        tool_list: the tool list parsed from the output of the llm
 
     Returns: dict
 
@@ -115,14 +96,12 @@ def choice_wrapper(response: str,
             'content': response,
         }
     }
-    if tool_name is not None:
-        choice['message']['tool_calls'] = [{
-            'type': 'function',
-            'function': {
-                'name': tool_name,
-                'arguments': json.dumps(tool_inputs, ensure_ascii=False)
-            }
-        }]
+    if len(tool_list) > 0:
+        tool_calls = []
+        for item in tool_list:
+            tool_dict = {'type': 'function', 'function': item}
+            tool_calls.append(tool_dict)
+        choice['message']['tool_calls'] = tool_calls
         choice['finish_reason'] = 'tool_calls'
     else:
         choice['finish_reason'] = 'stop'
diff --git a/tests/test_agent.py b/tests/test_agent.py
index aa71f6b97..f493ae1aa 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -59,8 +59,12 @@ def test_agent_run(tester_agent):
 
 def test_agent_call_tool(tester_agent):
     # Mocking a simple response from the tool for testing purposes
-    response = tester_agent._call_tool('mock_tool', 'tool response')
-    assert response == 'tool response'
+    tool_list = [{
+        'name': 'mock_tool',
+        'arguments': '{\"test\": \"tool response\"}'
+    }]
+    response = tester_agent._call_tool(tool_list)
+    assert response == '{\"test\": \"tool response\"}'
 
 
 def test_agent_parse_image_url(tester_agent):
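Illustration (not part of the patch): end to end, the server-side raw-prompt flow now composes `detect_multi_tool` with `choice_wrapper`, as `chat_completion` does in api.py. A minimal sketch:

```Python
from modelscope_agent.llm.utils.function_call_with_raw_prompt import \
    detect_multi_tool
from modelscope_agent_servers.assistant_server.utils import choice_wrapper

# A raw completion from a model without native function calling support.
llm_result = ('Action: amap_weather\nAction Input: {"location": "北京"}\n'
              'Action: amap_weather\nAction Input: {"location": "上海"}\n')

# Detect the embedded tool calls, then wrap them as OpenAI-style choices.
has_action, tool_list, _ = detect_multi_tool(llm_result)
choices = choice_wrapper(llm_result, tool_list)
# The wrapped message carries one tool_calls entry per detected call, e.g.
# {'type': 'function',
#  'function': {'name': 'amap_weather', 'arguments': '{"location": "北京"}'}},
# and its finish_reason is set to 'tool_calls'.
```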