Skip to content

Commit b5f2a0b

Browse files
authored
change pipeline plugin tool (#248)
* change pipeline plugin tool * Standardized input and output * Standardized input and output * Standardized input and output
1 parent 053ad15 commit b5f2a0b

File tree

5 files changed

+74
-74
lines changed

5 files changed

+74
-74
lines changed
Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,30 @@
1-
from modelscope.pipelines import pipeline
2-
from .tool import Tool
1+
from modelscope_agent.tools.base import BaseTool, register_tool
2+
import json
3+
import requests
4+
import os
35

4-
5-
class ModelscopePipelineTool(Tool):
6-
7-
default_model: str = ''
8-
task: str = ''
9-
model_revision = None
6+
@register_tool('pipeline')
7+
class ModelscopePipelineTool(BaseTool):
8+
API_URL = ""
9+
API_KEY = ""
1010

1111
def __init__(self, cfg):
12-
12+
"""
13+
初始化一个ModelscopePipelineTool类
14+
Initialize a ModelscopePipelineTool class.
15+
参数:
16+
cfg (Dict[str, object]): 配置字典,包含了初始化对象所需要的参数
17+
"""
1318
super().__init__(cfg)
14-
self.model = self.cfg.get('model', None) or self.default_model
15-
self.model_revision = self.cfg.get('model_revision',
16-
None) or self.model_revision
17-
18-
self.pipeline_params = self.cfg.get('pipeline_params', {})
19-
self.pipeline = None
20-
self.is_initialized = False
21-
22-
def setup(self):
23-
24-
# only initialize when this tool is really called to save memory
25-
if not self.is_initialized:
26-
self.pipeline = pipeline(
27-
task=self.task,
28-
model=self.model,
29-
model_revision=self.model_revision,
30-
**self.pipeline_params)
31-
self.is_initialized = True
32-
33-
def _local_call(self, *args, **kwargs):
34-
35-
self.setup()
19+
self.API_URL = self.cfg.get(self.name, {}).get('url',None) or self.API_URL
20+
self.API_KEY = os.getenv('MODELSCOPE_API_KEY', None) or self.API_KEY
21+
22+
23+
def call(self, params: str, **kwargs) -> str:
24+
params = self._verify_args(params)
25+
data = json.dumps(params)
26+
headers = {"Authorization": f"Bearer {self.API_KEY}"}
27+
response = requests.request("POST", self.API_URL, headers=headers,data=data)
28+
result = json.loads(response.content.decode("utf-8"))
29+
return result
3630

37-
parsed_args, parsed_kwargs = self._local_parse_input(*args, **kwargs)
38-
origin_result = self.pipeline(*parsed_args, **parsed_kwargs)
39-
final_result = self._parse_output(origin_result, remote=False)
40-
return final_result

modelscope_agent/tools/plugin_tool.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
from copy import deepcopy
2+
from modelscope_agent.tools.base import BaseTool, register_tool
23

3-
from .tool import Tool
4-
5-
6-
class LangchainTool(Tool):
4+
@register_tool('plugin')
5+
class LangchainTool(BaseTool):
6+
description = '通过调用langchain插件来支持对语言模型的输入输出格式进行处理,输入文本字符,输出经过格式处理的结果'
7+
name = 'plugin'
8+
parameters: list = [{
9+
'name': 'commands',
10+
'description': '需要进行格式处理的文本字符列表',
11+
'required': True,
12+
'type': "string"
13+
}]
714

815
def __init__(self, langchain_tool):
916
from langchain.tools import BaseTool
@@ -23,8 +30,11 @@ def parse_langchain_schema(self):
2330
tool_arg = deepcopy(arg)
2431
tool_arg['name'] = name
2532
tool_arg['required'] = True
33+
tool_arg['type'] = arg['anyOf'][0].get("type","string")
2634
tool_arg.pop('title')
2735
self.parameters.append(tool_arg)
2836

29-
def _local_call(self, *args, **kwargs):
30-
return {'result': self.langchain_tool.run(kwargs)}
37+
def call(self, params: str, **kwargs):
38+
params = self._verify_args(params)
39+
res = self.langchain_tool.run(params)
40+
return res
Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
from modelscope_agent.output_wrapper import AudioWrapper
2-
3-
from modelscope.utils.constant import Tasks
41
from .pipeline_tool import ModelscopePipelineTool
5-
2+
from modelscope_agent.output_wrapper import AudioWrapper
63

74
class TexttoSpeechTool(ModelscopePipelineTool):
85
default_model = 'damo/speech_sambert-hifigan_tts_zh-cn_16k'
@@ -11,35 +8,16 @@ class TexttoSpeechTool(ModelscopePipelineTool):
118
parameters: list = [{
129
'name': 'input',
1310
'description': '要转成语音的文本',
14-
'required': True
11+
'required': True,
12+
'type': 'string'
1513
}, {
1614
'name': 'gender',
1715
'description': '用户身份',
18-
'required': True
16+
'required': True,
17+
'type': 'string'
1918
}]
20-
task = Tasks.text_to_speech
21-
22-
def _local_parse_input(self, *args, **kwargs):
23-
if 'gender' not in kwargs:
24-
kwargs['gender'] = 'man'
25-
voice = 'zhizhe_emo' if kwargs['gender'] == 'man' else 'zhiyan_emo'
26-
kwargs['voice'] = voice
27-
if 'text' in kwargs and 'input' not in kwargs:
28-
kwargs['input'] = kwargs['text']
29-
kwargs.pop('text')
30-
kwargs.pop('gender')
31-
return args, kwargs
32-
33-
def _remote_parse_input(self, *args, **kwargs):
34-
if 'gender' not in kwargs:
35-
kwargs['gender'] = 'man'
36-
voice = 'zhizhe_emo' if kwargs['gender'] == 'man' or kwargs[
37-
'gender'] == 'male' else 'zhiyan_emo'
38-
kwargs['parameters'] = {'voice': voice}
39-
kwargs.pop('gender')
40-
return kwargs
41-
42-
def _parse_output(self, origin_result, remote=True):
4319

44-
audio = origin_result['output_wav']
45-
return {'result': AudioWrapper(audio)}
20+
def call(self, params: str, **kwargs) -> str:
21+
result = super().call(params, **kwargs)
22+
audio = result['Data']['output_wav']
23+
return AudioWrapper(audio)

tests/tools/test_langchain_tool.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,9 @@ def test_is_langchain_tool():
1414
def test_run_langchin_tool():
1515
# test run langchain tool
1616
shell_tool = LangchainTool(ShellTool())
17-
res = shell_tool(commands=["echo 'Hello World!'"])
18-
assert res['result'] == 'Hello World!\n'
17+
input = """{'commands': ["echo 'Hello World!'"]}"""
18+
res = shell_tool.call(input)
19+
print(res)
20+
assert res == 'Hello World!\n'
21+
22+
test_run_langchin_tool()

tests/tools/test_pipeline_tool.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from modelscope_agent.tools.pipeline_tool import ModelscopePipelineTool
2+
from modelscope.utils.config import Config
3+
import os
4+
5+
cfg = Config.from_file('config/cfg_tool_template.json')
6+
# 请用自己的SDK令牌替换{YOUR_MODELSCOPE_SDK_TOKEN}(包括大括号)
7+
os.environ['MODELSCOPE_API_KEY'] = f"{YOUR_MODELSCOPE_SDK_TOKEN}"
8+
9+
def test_modelscope_speech_generation():
10+
from modelscope_agent.tools.text_to_speech_tool import TexttoSpeechTool
11+
kwargs = """{'input': '北京今天天气怎样?', 'gender': 'man'}"""
12+
txt2speech = TexttoSpeechTool(cfg)
13+
res = txt2speech.call(kwargs)
14+
print(res)
15+
16+
17+
test_modelscope_speech_generation()
18+

0 commit comments

Comments
 (0)