diff --git a/modelscope_agent/tools/text_to_image_tool.py b/modelscope_agent/tools/text_to_image_tool.py
index d65fb13fc..f17b3fc6a 100644
--- a/modelscope_agent/tools/text_to_image_tool.py
+++ b/modelscope_agent/tools/text_to_image_tool.py
@@ -5,61 +5,43 @@
 import dashscope
 import json
 from dashscope import ImageSynthesis
-from modelscope_agent.output_wrapper import ImageWrapper
+from modelscope_agent.tools.base import BaseTool, register_tool
 
-from modelscope.utils.constant import Tasks
-from .pipeline_tool import ModelscopePipelineTool
 
-
-class TextToImageTool(ModelscopePipelineTool):
-    default_model = 'AI-ModelScope/stable-diffusion-xl-base-1.0'
+@register_tool('image_gen')
+class TextToImageTool(BaseTool):
     description = 'AI绘画(图像生成)服务,输入文本描述和图像分辨率,返回根据文本信息绘制的图片URL。'
     name = 'image_gen'
     parameters: list = [{
         'name': 'text',
         'description': '详细描述了希望生成的图像具有什么内容,例如人物、环境、动作等细节描述',
         'required': True,
-        'schema': {
-            'type': 'string'
-        }
+        'type': 'string'
     }, {
         'name': 'resolution',
         'description': '格式是 数字*数字,表示希望生成的图像的分辨率大小,选项有[1024*1024, 720*1280, 1280*720]',
         'required': True,
-        'schema': {
-            'type': 'string'
-        }
+        'type': 'string'
     }]
-    model_revision = 'v1.0.0'
-    task = Tasks.text_to_image_synthesis
-
-    # def _remote_parse_input(self, *args, **kwargs):
-    #     params = {
-    #         'input': {
-    #             'text': kwargs['text'],
-    #             'resolution': kwargs['resolution']
-    #         }
-    #     }
-    #     if kwargs.get('seed', None):
-    #         params['input']['seed'] = kwargs['seed']
-    #     return params
-
-    def _remote_call(self, *args, **kwargs):
+    def call(self, params: str, **kwargs) -> str:
+        params = self._verify_args(params)
+        if isinstance(params, str):
+            return 'Parameter Error'
 
-        if ('resolution' in kwargs) and (kwargs['resolution'] in [
-                '1024*1024', '720*1280', '1280*720'
-        ]):
-            resolution = kwargs['resolution']
+        if params['resolution'] in ['1024*1024', '720*1280', '1280*720']:
+            resolution = params['resolution']
         else:
             resolution = '1280*720'
 
-        prompt = kwargs['text']
-        seed = kwargs.get('seed', None)
+        prompt = params['text']
         if prompt is None:
             return None
-        dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')
+        seed = kwargs.get('seed', None)
         model = kwargs.get('model', 'wanx-v1')
+        dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')
+
         response = ImageSynthesis.call(
             model=model,
             prompt=prompt,
@@ -67,49 +49,5 @@ def _remote_call(self, *args, **kwargs):
             size=resolution,
             steps=10,
             seed=seed)
-        final_result = self._parse_output(response, remote=True)
-        return final_result
-
-    def _local_parse_input(self, *args, **kwargs):
-
-        text = kwargs.pop('text', '')
-
-        parsed_args = ({'text': text}, )
-
-        return parsed_args, {}
-
-    def _parse_output(self, origin_result, remote=True):
-        if not remote:
-            image = cv2.cvtColor(origin_result['output_imgs'][0],
-                                 cv2.COLOR_BGR2RGB)
-        else:
-            image = origin_result.output['results'][0]['url']
-
-        return {'result': ImageWrapper(image)}
-
-    def _handle_input_fallback(self, **kwargs):
-        """
-        an alternative method is to parse image is that get item between { and }
-        for last try
-
-        :param fallback_text:
-        :return: language, cocde
-        """
-
-        text = kwargs.get('text', None)
-        fallback = kwargs.get('fallback', None)
-
-        if text:
-            return text
-        elif fallback:
-            try:
-                text = fallback
-                json_block = re.search(r'\{([\s\S]+)\}', text)  # noqa W^05
-                if json_block:
-                    result = json_block.group(1)
-                    result_json = json.loads('{' + result + '}')
-                    return result_json['text']
-            except ValueError:
-                return text
-        else:
-            return text
+        image_url = response.output['results'][0]['url']
+        return image_url
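Note: the refactor above collapses the pipeline-style _remote_call / _parse_output / _handle_input_fallback flow into a single call(params: str, **kwargs) -> str entry point that returns the image URL string directly. A minimal usage sketch follows; it assumes that BaseTool._verify_args parses the serialized params string into a dict and hands back a plain string on validation failure (which is what the isinstance(params, str) branch checks), and that DASHSCOPE_API_KEY is set in the environment.

    import json

    from modelscope_agent.tools.text_to_image_tool import TextToImageTool

    # Parameters are passed as one serialized string matching the declared
    # parameters schema of the tool.
    params = json.dumps({
        'text': 'a watercolor painting of a kitten',
        'resolution': '1024*1024',
    })

    tool = TextToImageTool()
    # seed and model are optional keyword overrides read inside call();
    # the model defaults to 'wanx-v1'.
    url = tool.call(params, seed=42)
    print(url)  # an http(s) URL of the generated image, or 'Parameter Error'

The new tests below exercise exactly this path, including the silent fallback to 1280*720 when an unsupported resolution is supplied.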
diff --git a/tests/tools/test_image_gen.py b/tests/tools/test_image_gen.py
new file mode 100644
index 000000000..b76b401a8
--- /dev/null
+++ b/tests/tools/test_image_gen.py
@@ -0,0 +1,39 @@
+from modelscope_agent.agent import Agent
+from modelscope_agent.tools.text_to_image_tool import TextToImageTool
+
+from modelscope_agent.prompts.role_play import RolePlay  # NOQA
+
+
+def test_image_gen():
+    params = """{'text': '画一只小猫', 'resolution': '1024*1024'}"""
+
+    t2i = TextToImageTool()
+    res = t2i.call(params)
+    assert (res.startswith('http'))
+
+
+def test_image_gen_wrong_resolution():
+    params = """{'text': '画一只小猫', 'resolution': '1024'}"""
+
+    t2i = TextToImageTool()
+    res = t2i.call(params)
+    assert (res.startswith('http'))
+
+
+def test_image_gen_role():
+    role_template = '你扮演一个画家,用尽可能丰富的描述调用工具绘制图像。'
+
+    llm_config = {'model': 'qwen-max', 'model_server': 'dashscope'}
+
+    # input tool args
+    function_list = [{'name': 'image_gen'}]
+
+    bot = RolePlay(
+        function_list=function_list, llm=llm_config, instruction=role_template)
+
+    response = bot.run('朝阳区天气怎样?')
+
+    text = ''
+    for chunk in response:
+        text += chunk
+    print(text)
diff --git a/tests/utils.py b/tests/utils.py
index 13c1ccf5a..b45d1b19a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,10 +1,8 @@
-from modelscope_agent.action_parser import ActionParser
-from modelscope_agent.llm import LLM
-from modelscope_agent.prompts import PromptGenerator
-from modelscope_agent.tools import Tool
+from agent_scope.llm import BaseChatModel
+from agent_scope.tools import Tool
 
 
-class MockLLM(LLM):
+class MockLLM(BaseChatModel):
 
     def __init__(self, responses=['mock_llm_response']):
         super().__init__({})
@@ -21,28 +19,6 @@ def stream_generate(self, prompt: str, function_list=[], **kwargs) -> str:
         yield 'mock llm response'
 
 
-class MockPromptGenerator(PromptGenerator):
-
-    def __init__(self):
-        super().__init__()
-
-
-class MockOutParser(ActionParser):
-
-    def __init__(self, action, args, count=1):
-        super().__init__()
-        self.action = action
-        self.args = args
-        self.count = count
-
-    def parse_response(self, response: str):
-        if self.count > 0:
-            self.count -= 1
-            return self.action, self.args
-        else:
-            return None, None
-
-
 class MockTool(Tool):
 
     def __init__(self, name, func, description, parameters=[]):
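Note: both the tool refactor and the new tests lean on two pieces of modelscope_agent.tools.base that this diff does not show: the register_tool decorator, which is what lets RolePlay resolve {'name': 'image_gen'} from function_list, and BaseTool._verify_args. The sketch below is an illustrative assumption of how such a registry and argument check could look; it is not the repository's actual implementation, and the real parser is evidently more lenient than json.loads, since the tests above pass single-quoted, Python-style dict strings.

    import json

    # Hypothetical registry sketch; the real BaseTool and register_tool live in
    # modelscope_agent.tools.base and may differ in detail.
    TOOL_REGISTRY = {}


    def register_tool(name):
        """Class decorator that records a tool class under a string name."""

        def decorator(cls):
            cls.name = name
            TOOL_REGISTRY[name] = cls
            return cls

        return decorator


    class BaseTool:
        name: str = ''
        description: str = ''
        parameters: list = []

        def _verify_args(self, params: str):
            """Return a dict of parsed args, or the original string on failure."""
            try:
                parsed = json.loads(params)
            except json.JSONDecodeError:
                return params
            for p in self.parameters:
                if p.get('required') and p['name'] not in parsed:
                    return params
            return parsed

        def call(self, params: str, **kwargs) -> str:
            raise NotImplementedError

Under this model, an agent configured with function_list=[{'name': 'image_gen'}] only needs the registry lookup TOOL_REGISTRY['image_gen'] to instantiate the tool registered by the decorator in the first file of this diff.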