Skip to content

Commit

Permalink
Merge pull request #247 from suluyana/refactor/tool/image_gen
Browse files Browse the repository at this point in the history
Refactor/tool/image_gen
  • Loading branch information
tuhahaha authored Jan 5, 2024
2 parents d4b68fc + 4689856 commit b4bb3b8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 106 deletions.
96 changes: 17 additions & 79 deletions modelscope_agent/tools/text_to_image_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,111 +5,49 @@
import dashscope
import json
from dashscope import ImageSynthesis
from modelscope_agent.output_wrapper import ImageWrapper
from modelscope_agent.tools.base import BaseTool, register_tool

from modelscope.utils.constant import Tasks
from .pipeline_tool import ModelscopePipelineTool


class TextToImageTool(ModelscopePipelineTool):
default_model = 'AI-ModelScope/stable-diffusion-xl-base-1.0'
@register_tool('image_gen')
class TextToImageTool(BaseTool):
description = 'AI绘画(图像生成)服务,输入文本描述和图像分辨率,返回根据文本信息绘制的图片URL。'
name = 'image_gen'
parameters: list = [{
'name': 'text',
'description': '详细描述了希望生成的图像具有什么内容,例如人物、环境、动作等细节描述',
'required': True,
'schema': {
'type': 'string'
}
'type': 'string'
}, {
'name': 'resolution',
'description':
'格式是 数字*数字,表示希望生成的图像的分辨率大小,选项有[1024*1024, 720*1280, 1280*720]',
'required': True,
'schema': {
'type': 'string'
}
'type': 'string'
}]
model_revision = 'v1.0.0'
task = Tasks.text_to_image_synthesis

# def _remote_parse_input(self, *args, **kwargs):
# params = {
# 'input': {
# 'text': kwargs['text'],
# 'resolution': kwargs['resolution']
# }
# }
# if kwargs.get('seed', None):
# params['input']['seed'] = kwargs['seed']
# return params

def _remote_call(self, *args, **kwargs):
def call(self, params: str, **kwargs) -> str:
params = self._verify_args(params)
if isinstance(params, str):
return 'Parameter Error'

if ('resolution' in kwargs) and (kwargs['resolution'] in [
'1024*1024', '720*1280', '1280*720'
]):
resolution = kwargs['resolution']
if params['resolution'] in ['1024*1024', '720*1280', '1280*720']:
resolution = params['resolution']
else:
resolution = '1280*720'

prompt = kwargs['text']
seed = kwargs.get('seed', None)
prompt = params['text']
if prompt is None:
return None
dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')
seed = kwargs.get('seed', None)
model = kwargs.get('model', 'wanx-v1')
dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')

response = ImageSynthesis.call(
model=model,
prompt=prompt,
n=1,
size=resolution,
steps=10,
seed=seed)
final_result = self._parse_output(response, remote=True)
return final_result

def _local_parse_input(self, *args, **kwargs):

text = kwargs.pop('text', '')

parsed_args = ({'text': text}, )

return parsed_args, {}

def _parse_output(self, origin_result, remote=True):
if not remote:
image = cv2.cvtColor(origin_result['output_imgs'][0],
cv2.COLOR_BGR2RGB)
else:
image = origin_result.output['results'][0]['url']

return {'result': ImageWrapper(image)}

def _handle_input_fallback(self, **kwargs):
"""
an alternative method is to parse image is that get item between { and }
for last try
:param fallback_text:
:return: language, cocde
"""

text = kwargs.get('text', None)
fallback = kwargs.get('fallback', None)

if text:
return text
elif fallback:
try:
text = fallback
json_block = re.search(r'\{([\s\S]+)\}', text) # noqa W^05
if json_block:
result = json_block.group(1)
result_json = json.loads('{' + result + '}')
return result_json['text']
except ValueError:
return text
else:
return text
image_url = response.output['results'][0]['url']
return image_url
39 changes: 39 additions & 0 deletions tests/tools/test_image_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from modelscope_agent.agent import Agent
from modelscope_agent.tools.text_to_image_tool import TextToImageTool

from modelscope_agent.prompts.role_play import RolePlay # NOQA


def test_image_gen():
params = """{'text': '画一只小猫', 'resolution': '1024*1024'}"""

t2i = TextToImageTool()
res = t2i.call(params)
assert (res.startswith('http'))


def test_image_gen_wrong_resolution():
params = """{'text': '画一只小猫', 'resolution': '1024'}"""

t2i = TextToImageTool()
res = t2i.call(params)
assert (res.startswith('http'))


def test_image_gen_role():
role_template = '你扮演一个画家,用尽可能丰富的描述调用工具绘制图像。'

llm_config = {'model': 'qwen-max', 'model_server': 'dashscope'}

# input tool args
function_list = [{'name': 'image_gen'}]

bot = RolePlay(
function_list=function_list, llm=llm_config, instruction=role_template)

response = bot.run('朝阳区天气怎样?')

text = ''
for chunk in response:
text += chunk
print(text)
30 changes: 3 additions & 27 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from modelscope_agent.action_parser import ActionParser
from modelscope_agent.llm import LLM
from modelscope_agent.prompts import PromptGenerator
from modelscope_agent.tools import Tool
from agent_scope.llm import BaseChatModel
from agent_scope.tools import Tool


class MockLLM(LLM):
class MockLLM(BaseChatModel):

def __init__(self, responses=['mock_llm_response']):
super().__init__({})
Expand All @@ -21,28 +19,6 @@ def stream_generate(self, prompt: str, function_list=[], **kwargs) -> str:
yield 'mock llm response'


class MockPromptGenerator(PromptGenerator):

def __init__(self):
super().__init__()


class MockOutParser(ActionParser):

def __init__(self, action, args, count=1):
super().__init__()
self.action = action
self.args = args
self.count = count

def parse_response(self, response: str):
if self.count > 0:
self.count -= 1
return self.action, self.args
else:
return None, None


class MockTool(Tool):

def __init__(self, name, func, description, parameters=[]):
Expand Down

0 comments on commit b4bb3b8

Please sign in to comment.