Skip to content

Commit

Permalink
Feat/parallel tool calls (#532)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhangpurdue authored Jul 17, 2024
1 parent 4cdcd26 commit e883333
Show file tree
Hide file tree
Showing 9 changed files with 380 additions and 160 deletions.
7 changes: 2 additions & 5 deletions apps/agentfabric/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
dashscope
faiss-cpu
gradio==4.36.1
langchain
markdown-cjk-spacing
mdx_truly_sane_lists
modelscope-agent==0.6.4
modelscope_studio
modelscope-agent>=0.6.4
modelscope_studio>=0.4.0
pymdown-extensions
python-slugify
unstructured
29 changes: 16 additions & 13 deletions modelscope_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,14 @@ def _call_llm(self,
stream=self.stream,
**kwargs)

def _call_tool(self, tool_name: str, tool_args: str, **kwargs):
def _call_tool(self, tool_list: list, **kwargs):
"""
Use when calling tools in bot()
"""
# version < 0.6.6 only one tool is in the tool_list
tool_name = tool_list[0]['name']
tool_args = tool_list[0]['arguments']
self.callback_manager.on_tool_start(tool_name, tool_args)
try:
result = self.function_map[tool_name].call(tool_args, **kwargs)
Expand Down Expand Up @@ -213,7 +216,7 @@ def _register_tool(self,
tool_class_with_tenant[tenant_id] = self.function_map[tool_name]

def _detect_tool(self, message: Union[str,
dict]) -> Tuple[bool, str, str, str]:
dict]) -> Tuple[bool, list, str]:
"""
A built-in tool call detection for func_call format
Expand All @@ -225,26 +228,26 @@ def _detect_tool(self, message: Union[str,
Returns:
- bool: need to call tool or not
- str: tool name
- str: tool args
- list: tool list
- str: text replies except for tool calls
"""
func_name = None
func_args = None

func_calls = []
assert isinstance(message, dict)
# deprecating
if 'function_call' in message and message['function_call']:
func_call = message['function_call']
func_name = func_call.get('name', '')
func_args = func_call.get('arguments', '')
# Compatible with OpenAI API
func_calls.append(func_call)

# Follow OpenAI API, allow multi func_calls
if 'tool_calls' in message and message['tool_calls']:
func_call = message['tool_calls'][0]['function']
func_name = func_call.get('name', '')
func_args = func_call.get('arguments', '')
for item in message['tool_calls']:
func_call = item['function']
func_calls.append(func_call)

text = message.get('content', '')

return (func_name is not None), func_name, func_args, text
return (len(func_calls) > 0), func_calls, text

def _parse_image_url(self, image_url: List[Union[str, Dict]],
messages: List[Dict]) -> List[Dict]:
Expand Down
153 changes: 47 additions & 106 deletions modelscope_agent/agents/role_play.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from modelscope_agent import Agent
from modelscope_agent.agent_env_util import AgentEnvMixin
from modelscope_agent.llm.base import BaseChatModel
from modelscope_agent.llm.utils.function_call_with_raw_prompt import (
DEFAULT_EXEC_TEMPLATE, SPECIAL_PREFIX_TEMPLATE_TOOL,
SPECIAL_PREFIX_TEMPLATE_TOOL_FOR_CHAT, TOOL_TEMPLATE,
convert_tools_to_prompt, detect_multi_tool)
from modelscope_agent.tools.base import BaseTool
from modelscope_agent.utils.base64_utils import encode_files_to_base64
from modelscope_agent.utils.logger import agent_logger as logger
Expand All @@ -19,23 +23,6 @@
"""

TOOL_TEMPLATE_ZH = """
# 工具
## 你拥有如下工具:
{tool_descs}
## 当你需要调用工具时,请在你的回复中穿插如下的工具调用命令,可以根据需求调用零次或多次:
工具调用
Action: 工具的名称,必须是[{tool_names}]之一
Action Input: 工具的输入
Observation: <result>工具返回的结果</result>
Answer: 根据Observation总结本次工具调用返回的结果,如果结果中出现url,请使用如下格式展示出来:![图片](url)
"""

PROMPT_TEMPLATE_ZH = """
# 指令
Expand All @@ -52,24 +39,6 @@
"""

TOOL_TEMPLATE_EN = """
# Tools
## You have the following tools:
{tool_descs}
## When you need to call a tool, please intersperse the following tool command in your reply. %s
Tool Invocation
Action: The name of the tool, must be one of [{tool_names}]
Action Input: Tool input
Observation: <result>Tool returns result</result>
Answer: Summarize the results of this tool call based on Observation. If the result contains url, %s
""" % ('You can call zero or more times according to your needs:',
'please display it in the following format:![Image](URL)')

PROMPT_TEMPLATE_EN = """
#Instructions
Expand All @@ -80,11 +49,6 @@

KNOWLEDGE_TEMPLATE = {'zh': KNOWLEDGE_TEMPLATE_ZH, 'en': KNOWLEDGE_TEMPLATE_EN}

TOOL_TEMPLATE = {
'zh': TOOL_TEMPLATE_ZH,
'en': TOOL_TEMPLATE_EN,
}

PROMPT_TEMPLATE = {
'zh': PROMPT_TEMPLATE_ZH,
'en': PROMPT_TEMPLATE_EN,
Expand All @@ -105,16 +69,6 @@
'en': 'You are playing as {role_name}',
}

SPECIAL_PREFIX_TEMPLATE_TOOL = {
'zh': '。你可以使用工具:[{tool_names}]',
'en': '. you can use tools: [{tool_names}]',
}

SPECIAL_PREFIX_TEMPLATE_TOOL_FOR_CHAT = {
'zh': '。你必须使用工具中的一个或多个:[{tool_names}]',
'en': '. you must use one or more tools: [{tool_names}]',
}

SPECIAL_PREFIX_TEMPLATE_KNOWLEDGE = {
'zh': '。请查看前面的知识库',
'en': '. Please read the knowledge base at the beginning',
Expand All @@ -125,13 +79,6 @@
'en': '[Upload file "{file_names}"]',
}

DEFAULT_EXEC_TEMPLATE = """\nObservation: <result>{exec_result}</result>\nAnswer:"""

ACTION_TOKEN = 'Action:'
ARGS_TOKEN = 'Action Input:'
OBSERVATION_TOKEN = 'Observation:'
ANSWER_TOKEN = 'Answer:'


class RolePlay(Agent, AgentEnvMixin):

Expand All @@ -147,6 +94,33 @@ def __init__(self,
description, instruction, **kwargs)
AgentEnvMixin.__init__(self, **kwargs)

def _prepare_tool_system(self,
                         tools: Optional[List] = None,
                         parallel_tool_calls: bool = False,
                         lang='zh'):
    """Build the tool-description system prompt and the tool-name list.

    Args:
        tools: Optional list of OpenAI-style tool schemas, each a dict
            expected to carry a ``'function'`` entry with a ``'name'``
            key. When ``None``, falls back to the tools registered in
            ``self.function_map``.
        parallel_tool_calls: When True, selects the template variant
            that tells the model it may emit several tool calls in one
            reply.
        lang: Template language key, e.g. ``'zh'`` or ``'en'``.

    Returns:
        Tuple of (comma-joined tool names, formatted tool system prompt).
    """
    # The parallel variant is keyed with a '_parallel' suffix,
    # e.g. 'zh_parallel' — TOOL_TEMPLATE must provide those keys.
    template_key = lang + ('_parallel' if parallel_tool_calls else '')
    tool_desc_template = TOOL_TEMPLATE[template_key]

    if tools is not None:
        tool_descs = BaseTool.parser_function(tools)
        # Skip schema entries that lack a 'function' dict or a 'name'.
        tool_names = ','.join(
            tool['function']['name'] for tool in tools
            if 'name' in tool.get('function', {}))
    else:
        # No explicit schema supplied: describe every registered tool.
        tool_descs = '\n\n'.join(tool.function_plain_text
                                 for tool in self.function_map.values())
        tool_names = ','.join(tool.name
                              for tool in self.function_map.values())
    tool_system = tool_desc_template.format(
        tool_descs=tool_descs, tool_names=tool_names)
    return tool_names, tool_system

def _run(self,
user_request,
history: Optional[List[Dict]] = None,
Expand All @@ -158,24 +132,11 @@ def _run(self,
chat_mode = kwargs.pop('chat_mode', False)
tools = kwargs.get('tools', None)
tool_choice = kwargs.get('tool_choice', 'auto')
parallel_tool_calls = kwargs.get('parallel_tool_calls',
True if chat_mode else False)

if tools is not None:
self.tool_descs = BaseTool.parser_function(tools)
tool_name_list = []
for tool in tools:
func_info = tool.get('function', {})
if func_info == {}:
continue
if 'name' in func_info:
tool_name_list.append(func_info['name'])
self.tool_names = ','.join(tool_name_list)
else:
self.tool_descs = '\n\n'.join(
tool.function_plain_text
for tool in self.function_map.values())
self.tool_names = ','.join(tool.name
for tool in self.function_map.values())

tool_names, tool_system = self._prepare_tool_system(
tools, parallel_tool_calls, lang)
self.system_prompt = ''
self.query_prefix = ''
self.query_prefix_dict = {'role': '', 'tool': '', 'knowledge': ''}
Expand All @@ -197,11 +158,10 @@ def _run(self,
'knowledge'] = SPECIAL_PREFIX_TEMPLATE_KNOWLEDGE[lang]

# concat tools information
if self.tool_descs and not self.llm.support_function_calling():
self.system_prompt += TOOL_TEMPLATE[lang].format(
tool_descs=self.tool_descs, tool_names=self.tool_names)
if tool_system and not self.llm.support_function_calling():
self.system_prompt += tool_system
self.query_prefix_dict['tool'] = SPECIAL_PREFIX_TEMPLATE_TOOL[
lang].format(tool_names=self.tool_names)
lang].format(tool_names=tool_names)

# concat instruction
if isinstance(self.instruction, dict):
Expand Down Expand Up @@ -242,7 +202,7 @@ def _run(self,
# concat the new messages
if chat_mode and tool_choice == 'required':
required_prefix = SPECIAL_PREFIX_TEMPLATE_TOOL_FOR_CHAT[
lang].format(tool_names=self.tool_names)
lang].format(tool_names=tool_names)
messages.append({
'role': 'user',
'content': required_prefix + user_request
Expand Down Expand Up @@ -295,33 +255,34 @@ def _run(self,
llm_result += s
yield s

use_tool = False
tool_list = []
if isinstance(llm_result, str):
use_tool, action, action_input, output = self._detect_tool(
llm_result)
use_tool, tool_list, output = detect_multi_tool(llm_result)
elif isinstance(llm_result, dict):
use_tool, action, action_input, output = super()._detect_tool(
llm_result)
use_tool, tool_list, output = super()._detect_tool(llm_result)
else:
assert 'llm_result must be an instance of dict or str'

if chat_mode:
if use_tool and tool_choice != 'none':
return f'Action: {action}\nAction Input: {action_input}\nResult: {output}'
return convert_tools_to_prompt(tool_list)
else:
return f'Result: {output}'

# yield output
if use_tool:
if self.llm.support_function_calling():
yield f'Action: {action}\nAction Input: {action_input}'
yield convert_tools_to_prompt(tool_list)

if self.use_tool_api:
# convert all files with base64, for the tool instance usage in case.
encoded_files = encode_files_to_base64(append_files)
kwargs['base64_files'] = encoded_files
kwargs['use_tool_api'] = True

observation = self._call_tool(action, action_input, **kwargs)
# currently only one observation execute, parallel
observation = self._call_tool(tool_list, **kwargs)
format_observation = DEFAULT_EXEC_TEMPLATE.format(
exec_result=observation)
yield format_observation
Expand Down Expand Up @@ -360,26 +321,6 @@ def _limit_observation_length(self, observation):
limited_observation = str(observation)[:int(reasonable_length)]
return DEFAULT_EXEC_TEMPLATE.format(exec_result=limited_observation)

def _detect_tool(self, message: Union[str,
                                      dict]) -> Tuple[bool, str, str, str]:
    """Parse a ReAct-style LLM reply for a single tool invocation.

    Looks for the last ``Action:`` / ``Action Input:`` pair in the text
    and slices out the tool name and its arguments.

    Returns:
        - bool: whether a tool call was detected
        - str: tool name (None when no call detected)
        - str: tool arguments (None when no call detected)
        - str: the reply text with the trailing ``Observation:`` removed
    """
    assert isinstance(message, str)
    text = message
    func_name, func_args = None, None
    action_pos = text.rfind(ACTION_TOKEN)
    args_pos = text.rfind(ARGS_TOKEN)
    obs_pos = text.rfind(OBSERVATION_TOKEN)
    # A call is present only when `Action:` appears before `Action Input:`.
    if 0 <= action_pos < args_pos:
        if obs_pos < args_pos:
            # `Observation:` was likely omitted by the LLM because the
            # stop word discarded it from the output — restore it so the
            # argument slice below has a right boundary.
            text = text.rstrip() + OBSERVATION_TOKEN
            obs_pos = text.rfind(OBSERVATION_TOKEN)
        func_name = text[action_pos + len(ACTION_TOKEN):args_pos].strip()
        func_args = text[args_pos + len(ARGS_TOKEN):obs_pos].strip()
        text = text[:obs_pos]  # discard the trailing '\nObservation:'

    return (func_name is not None), func_name, func_args, text

def _parse_role_config(self, config: dict, lang: str = 'zh') -> str:
"""
Parsing role config dict to str.
Expand Down
Loading

0 comments on commit e883333

Please sign in to comment.