Skip to content

Commit

Permalink
fix code style (#289)
Browse files Browse the repository at this point in the history
  • Loading branch information
wenmengzhou authored Feb 1, 2024
1 parent 17d1171 commit 7a805e6
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 60 deletions.
4 changes: 2 additions & 2 deletions modelscope_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from .code_interpreter_jupyter import CodeInterpreterJupyter
from .hf_tool import HFTool
from .image_chat_tool import ImageChatTool
from .paraformer_asr_tool import ParaformerAsrTool
from .phantom_tool import Phantom
from .pipeline_tool import ModelscopePipelineTool
from .plugin_tool import LangchainTool
from .qwen_vl import QWenVL
from .sambert_tts_tool import SambertTtsTool
from .style_repaint import StyleRepaint
from .text_address_tool import TextAddressTool
from .text_ie_tool import TextInfoExtractTool
Expand All @@ -20,8 +22,6 @@
from .web_browser import WebBrowser
from .web_search import WebSearch
from .wordart_tool import WordArtTexture
from .paraformer_asr_tool import ParaformerAsrTool
from .sambert_tts_tool import SambertTtsTool

TOOL_INFO_LIST = {
'text-translation-zh2en': 'TranslationZh2EnTool',
Expand Down
25 changes: 17 additions & 8 deletions modelscope_agent/tools/paraformer_asr_tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import subprocess
from http import HTTPStatus
from typing import List, Any
from typing import Any, List

from modelscope_agent.tools.tool import Tool, ToolSchema
from pydantic import ValidationError
Expand All @@ -10,9 +10,11 @@


def _preprocess(input_file, output_file):
ret = subprocess.call(['ffmpeg', '-y', '-i', input_file,
'-f', 's16le', '-acodec', 'pcm_s16le', '-ac', '1', '-ar', '16000',
'-loglevel', 'quiet', output_file])
ret = subprocess.call([
'ffmpeg', '-y', '-i', input_file, '-f', 's16le', '-acodec',
'pcm_s16le', '-ac', '1', '-ar', '16000', '-loglevel', 'quiet',
output_file
])
if ret != 0:
raise ValueError(f'Failed to preprocess audio file {input_file}')

Expand All @@ -29,7 +31,8 @@ class ParaformerAsrTool(Tool):
def __init__(self, cfg={}):
self.cfg = cfg.get(self.name, {})

self.api_key = self.cfg.get('dashscope_api_key', os.environ.get('DASHSCOPE_API_KEY'))
self.api_key = self.cfg.get('dashscope_api_key',
os.environ.get('DASHSCOPE_API_KEY'))
if self.api_key is None:
raise ValueError('Please set valid DASHSCOPE_API_KEY!')

Expand All @@ -55,8 +58,12 @@ def __call__(self, *args, **kwargs):
pcm_file = WORK_DIR + '/' + 'audio.pcm'
_preprocess(raw_audio_file, pcm_file)
if not os.path.exists(pcm_file):
raise ValueError(f'convert audio to pcm file failed')
recognition = Recognition(model='paraformer-realtime-v1', format='pcm', sample_rate=16000, callback=None)
raise ValueError(f'convert audio to pcm file {pcm_file} failed')
recognition = Recognition(
model='paraformer-realtime-v1',
format='pcm',
sample_rate=16000,
callback=None)
response = recognition.call(pcm_file)
result = ''
if response.status_code == HTTPStatus.OK:
Expand All @@ -65,5 +72,7 @@ def __call__(self, *args, **kwargs):
for sentence in sentences:
result += sentence['text']
else:
raise ValueError(f'call paraformer asr failed, request id: {response.get_request_id()}')
raise ValueError(
f'call paraformer asr failed, request id: {response.get_request_id()}'
)
return {'result': result}
77 changes: 37 additions & 40 deletions modelscope_agent/tools/phantom_tool.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
import os
import pandas as pd
import time

import json
import pandas as pd
import requests
from modelscope_agent.tools.localfile2url_utils.localfile2url import get_upload_url
from modelscope_agent.tools.localfile2url_utils.localfile2url import \
get_upload_url
from modelscope_agent.tools.tool import Tool, ToolSchema
from pydantic import ValidationError
from requests.exceptions import RequestException, Timeout
import time

MAX_RETRY_TIMES = 3
WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/ci_workspace')

class Phantom(Tool):# 继承基础类Tool,新建一个继承类
description = '追影-放大镜'# 对这个tool的功能描述
name = 'phantom_image_enhancement'#tool name

class Phantom(Tool): # 继承基础类Tool,新建一个继承类
description = '追影-放大镜' # 对这个tool的功能描述
name = 'phantom_image_enhancement' # tool name
"""
parameters是需要传入api tool的参数,通过api详情获取需要哪些必要入参
其中每一个参数都是一个字典,包含name,description,required三个字段
Expand All @@ -23,20 +26,20 @@ class Phantom(Tool):# 继承基础类Tool,新建一个继承类
'name': 'input.image_path',
'description': '输入的待增强图片的本地相对路径',
'required': True
},
{
}, {
'name': 'parameters.upscale',
'description': '选择需要超分的倍率,可选择1、2、3、4',
'required': False
}]

def __init__(self, cfg={}):
self.cfg = cfg.get(self.name, {})# cfg注册见下一条说明,这里是通过name找到对应的cfg
self.cfg = cfg.get(self.name, {}) # cfg注册见下一条说明,这里是通过name找到对应的cfg
# api url
self.url = 'https://dashscope.aliyuncs.com/api/v1/services/enhance/image-enhancement/generation'
# api token,可以选择注册在下面的cfg里,也可以选择将'API_TOKEN'导入环境变量
self.token = self.cfg.get('token', os.environ.get('DASHSCOPE_API_KEY', ''))
assert self.token != '', 'dashscope api token must be acquired'
self.token = self.cfg.get('token',
os.environ.get('DASHSCOPE_API_KEY', ''))
assert self.token != '', 'dashscope api token must be acquired'
# 验证,转换参数格式,保持即可
try:
all_param = {
Expand All @@ -51,12 +54,12 @@ def __init__(self, cfg={}):
self._str = self.tool_schema.model_dump_json()
self._function = self.parse_pydantic_model_to_openai_function(
all_param)
# 调用api操作函数,kwargs里是llm根据上面的parameters说明得到的对应参数

# 调用api操作函数,kwargs里是llm根据上面的parameters说明得到的对应参数
def __call__(self, *args, **kwargs):
# 对入参格式调整和补充,比如解开嵌套的'.'连接的参数,还有导入你默认的一些参数,
# 比如model,参考下面的_remote_parse_input函数。

remote_parsed_input = json.dumps(
self._remote_parse_input(*args, **kwargs))
origin_result = None
Expand All @@ -76,14 +79,11 @@ def __call__(self, *args, **kwargs):
try:
# requests请求
response = requests.post(
url=self.url,
headers=headers,
data=remote_parsed_input)
url=self.url, headers=headers, data=remote_parsed_input)

if response.status_code != requests.codes.ok:
response.raise_for_status()
origin_result = json.loads(
response.content.decode('utf-8'))
origin_result = json.loads(response.content.decode('utf-8'))
# self._parse_output是基础类Tool对output结果的一个格式调整,你可 # 以在这里按需调整返回格式
self.final_result = self._parse_output(
origin_result, remote=True)
Expand All @@ -95,11 +95,11 @@ def __call__(self, *args, **kwargs):
raise ValueError(
f'Remote call failed with error code: {e.response.status_code},\
error message: {e.response.content.decode("utf-8")}')

raise ValueError(
'Remote call max retry times exceeded! Please try to use local call.'
)

def _remote_parse_input(self, *args, **kwargs):
restored_dict = {}
for key, value in kwargs.items():
Expand All @@ -114,9 +114,10 @@ def _remote_parse_input(self, *args, **kwargs):
# f the key does not contain ".", directly store the key-value pair into restored_dict
restored_dict[key] = value
kwargs = restored_dict

image_path = kwargs['input'].pop('image_path', None)
if image_path and image_path.endswith(('.jpeg', '.png', '.jpg', '.bmp')):
if image_path and \
image_path.endswith(('.jpeg', '.png', '.jpg', '.bmp')):
# 生成 image_url,然后设置到 kwargs['input'] 中
# 复用dashscope公共oss
image_path = f'file://{os.path.join(WORK_DIR, image_path)}'
Expand All @@ -127,47 +128,42 @@ def _remote_parse_input(self, *args, **kwargs):
kwargs['input']['image_url'] = image_url
else:
raise ValueError('请先上传一张正确格式的图片')

kwargs['model'] = 'wanx-image-enhancement-v1'
print('传给tool的参数:', kwargs)
return kwargs

def get_result(self):
    """Query the DashScope task endpoint for the task in ``final_result``.

    Reads ``task_id`` from ``self.final_result['result']['output']`` and
    sends a GET request to ``/api/v1/tasks/{task_id}``, authorised with
    ``self.token``.  A request that times out is retried, for at most
    ``MAX_RETRY_TIMES`` attempts in total.

    Returns:
        The value of ``self._parse_output(payload, remote=True)``, where
        ``payload`` is the JSON-decoded response body.

    Raises:
        ValueError: If ``final_result`` carries no ``task_id``, if the
            request fails for a reason other than a timeout, or if every
            attempt timed out.
    """
    # The stored result is only read here, never mutated, so it is used
    # directly instead of being copied through a JSON round-trip.
    output = self.final_result['result']['output']
    if 'task_id' not in output:
        raise ValueError('No task_id found in the previous call result.')
    task_id = output['task_id']

    get_url = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
    get_header = {'Authorization': f'Bearer {self.token}'}

    for _ in range(MAX_RETRY_TIMES):
        try:
            response = requests.request(
                'GET', url=get_url, headers=get_header)
            if response.status_code != requests.codes.ok:
                response.raise_for_status()
            origin_result = json.loads(response.content.decode('utf-8'))
            return self._parse_output(origin_result, remote=True)
        except Timeout:
            # Only timeouts are retried; any other request error is fatal.
            continue
        except RequestException as e:
            failed_response = e.response
            # Compare against None explicitly: an error Response is falsy,
            # and connection-level failures carry no response at all.
            if failed_response is None:
                raise ValueError(f'Remote call failed: {e}') from e
            raise ValueError(
                'Remote call failed with error code: '
                f'{failed_response.status_code}, error message: '
                f'{failed_response.content.decode("utf-8")}') from e

    raise ValueError(
        'Remote call max retry times exceeded! Please try to use local call.'
    )

def get_phantom_result(self):
try:
result = self.get_result()
Expand All @@ -186,16 +182,17 @@ def get_phantom_result(self):
# output_url = self._parse_output(result['result']['output']['result_url'])
output_url = {}
output_url['result'] = {}
output_url['result']['url'] = result['result']['output']['result_url']
output_url['result']['url'] = result['result']['output'][
'result_url']
# print(output_url)
print(output_url)
return output_url

elif task_status in ['FAILED', 'ERROR']:
raise("任务失败")
raise ('任务失败')

# 继续轮询,等待一段时间后再次调用
time.sleep(1) # 等待 1 秒钟
result = self.get_result()
except Exception as e:
print('get request Error:', str(e))
print('get request Error:', str(e))
12 changes: 8 additions & 4 deletions modelscope_agent/tools/sambert_tts_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class SambertTtsTool(Tool):
def __init__(self, cfg={}):
self.cfg = cfg.get(self.name, {})

self.api_key = self.cfg.get('dashscope_api_key', os.environ.get('DASHSCOPE_API_KEY'))
self.api_key = self.cfg.get('dashscope_api_key',
os.environ.get('DASHSCOPE_API_KEY'))
if self.api_key is None:
raise ValueError('Please set valid DASHSCOPE_API_KEY!')

Expand All @@ -41,13 +42,16 @@ def __call__(self, *args, **kwargs):
from dashscope.audio.tts import SpeechSynthesizer
tts_text = kwargs['text']
if tts_text is None or len(tts_text) == 0 or tts_text == '':
raise ValueError(f'tts input text is valid')
raise ValueError('tts input text is valid')
os.makedirs(WORK_DIR, exist_ok=True)
wav_file = WORK_DIR + '/sambert_tts_audio.wav'
response = SpeechSynthesizer.call(model='sambert-zhijia-v1', format='wav', text=tts_text)
response = SpeechSynthesizer.call(
model='sambert-zhijia-v1', format='wav', text=tts_text)
if response.get_audio_data() is not None:
with open(wav_file, 'wb') as f:
f.write(response.get_audio_data())
else:
raise ValueError(f'call sambert tts failed, request id: {response.get_response().request_id}')
raise ValueError(
f'call sambert tts failed, request id: {response.get_response().request_id}'
)
return {'result': AudioWrapper(wav_file)}
6 changes: 3 additions & 3 deletions tests/tools/test_dashscope_asr_tts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from modelscope_agent.agent import AgentExecutor
from modelscope_agent.tools import ParaformerAsrTool
from modelscope_agent.tools import SambertTtsTool
from modelscope_agent.tools import ParaformerAsrTool, SambertTtsTool
from tests.utils import MockLLM, MockOutParser, MockPromptGenerator


Expand All @@ -21,7 +20,8 @@ def test_sambert_tts():
def test_paraformer_asr_agent():
responses = [
"<|startofthink|>{\"api_name\": \"paraformer_asr_utils\", \"parameters\": "
"{\"audio_path\": \"16k-xwlb3_local_user.wav\"}}<|endofthink|>", 'summarize'
"{\"audio_path\": \"16k-xwlb3_local_user.wav\"}}<|endofthink|>",
'summarize'
]
llm = MockLLM(responses)

Expand Down
12 changes: 9 additions & 3 deletions tests/tools/test_phantom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,23 @@

def test_phantom():
    """Call the Phantom tool directly and check the returned result URL."""
    # Named image_path rather than `input` so the builtin is not shadowed.
    image_path = '2_local_user.png'
    kwargs = {
        'input.image_path': image_path,
        'parameters.upscale': 2,
        'remote': False,
    }
    phantom = Phantom()
    res = phantom(**kwargs)

    print(res)
    assert res['result']['url'].startswith('http')


def test_phantom_agent():
responses = [
"<|startofthink|>{\"api_name\": \"phantom_image_enhancement\", \"parameters\": "
"{\"input.image_path\": \"2_local_user.png\"}}<|endofthink|>", 'summarize'
"{\"input.image_path\": \"2_local_user.png\"}}<|endofthink|>",
'summarize'
]
llm = MockLLM(responses)

Expand All @@ -34,4 +40,4 @@ def test_phantom_agent():
res = agent.run('2倍超分')
print(res)

assert res[0]['result']['url'].startswith('http')
assert res[0]['result']['url'].startswith('http')

0 comments on commit 7a805e6

Please sign in to comment.