Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/tool/image_gen #247

Merged
merged 3 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 17 additions & 79 deletions modelscope_agent/tools/text_to_image_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,111 +5,49 @@
import dashscope
import json
from dashscope import ImageSynthesis
from modelscope_agent.output_wrapper import ImageWrapper
from modelscope_agent.tools.base import BaseTool, register_tool

from modelscope.utils.constant import Tasks
from .pipeline_tool import ModelscopePipelineTool


class TextToImageTool(ModelscopePipelineTool):
default_model = 'AI-ModelScope/stable-diffusion-xl-base-1.0'
@register_tool('image_gen')
class TextToImageTool(BaseTool):
description = 'AI绘画(图像生成)服务,输入文本描述和图像分辨率,返回根据文本信息绘制的图片URL。'
name = 'image_gen'
parameters: list = [{
'name': 'text',
'description': '详细描述了希望生成的图像具有什么内容,例如人物、环境、动作等细节描述',
'required': True,
'schema': {
'type': 'string'
}
'type': 'string'
}, {
'name': 'resolution',
'description':
'格式是 数字*数字,表示希望生成的图像的分辨率大小,选项有[1024*1024, 720*1280, 1280*720]',
'required': True,
'schema': {
'type': 'string'
}
'type': 'string'
}]
model_revision = 'v1.0.0'
task = Tasks.text_to_image_synthesis

# def _remote_parse_input(self, *args, **kwargs):
# params = {
# 'input': {
# 'text': kwargs['text'],
# 'resolution': kwargs['resolution']
# }
# }
# if kwargs.get('seed', None):
# params['input']['seed'] = kwargs['seed']
# return params

def _remote_call(self, *args, **kwargs):
def call(self, params: str, **kwargs) -> str:
params = self._verify_args(params)
if isinstance(params, str):
return 'Parameter Error'

if ('resolution' in kwargs) and (kwargs['resolution'] in [
'1024*1024', '720*1280', '1280*720'
]):
resolution = kwargs['resolution']
if params['resolution'] in ['1024*1024', '720*1280', '1280*720']:
resolution = params['resolution']
else:
resolution = '1280*720'

prompt = kwargs['text']
seed = kwargs.get('seed', None)
prompt = params['text']
if prompt is None:
return None
dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')
seed = kwargs.get('seed', None)
model = kwargs.get('model', 'wanx-v1')
dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')

response = ImageSynthesis.call(
model=model,
prompt=prompt,
n=1,
size=resolution,
steps=10,
seed=seed)
final_result = self._parse_output(response, remote=True)
return final_result

def _local_parse_input(self, *args, **kwargs):

text = kwargs.pop('text', '')

parsed_args = ({'text': text}, )

return parsed_args, {}

def _parse_output(self, origin_result, remote=True):
if not remote:
image = cv2.cvtColor(origin_result['output_imgs'][0],
cv2.COLOR_BGR2RGB)
else:
image = origin_result.output['results'][0]['url']

return {'result': ImageWrapper(image)}

def _handle_input_fallback(self, **kwargs):
"""
an alternative method is to parse image is that get item between { and }
for last try

:param fallback_text:
:return: language, cocde
"""

text = kwargs.get('text', None)
fallback = kwargs.get('fallback', None)

if text:
return text
elif fallback:
try:
text = fallback
json_block = re.search(r'\{([\s\S]+)\}', text) # noqa W^05
if json_block:
result = json_block.group(1)
result_json = json.loads('{' + result + '}')
return result_json['text']
except ValueError:
return text
else:
return text
image_url = response.output['results'][0]['url']
return image_url
39 changes: 39 additions & 0 deletions tests/tools/test_image_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from modelscope_agent.agent import Agent
from modelscope_agent.tools.text_to_image_tool import TextToImageTool

from modelscope_agent.prompts.role_play import RolePlay # NOQA


def test_image_gen():
    """Call the image_gen tool with a listed resolution and expect a URL-like string."""
    tool = TextToImageTool()
    # The argument string is passed to call() unchanged; call() is responsible
    # for parsing it.
    result = tool.call(
        """{'text': '画一只小猫', 'resolution': '1024*1024'}""")
    assert result.startswith('http')


def test_image_gen_wrong_resolution():
    """Call the image_gen tool with an unlisted resolution and still expect a URL-like string."""
    tool = TextToImageTool()
    # '1024' is not one of the resolutions named in the tool's parameter
    # description, yet the call is still expected to produce a link.
    result = tool.call("""{'text': '画一只小猫', 'resolution': '1024'}""")
    assert result.startswith('http')


def test_image_gen_role():
    """Run a RolePlay agent configured with the image_gen tool and print its reply.

    NOTE(review): this test makes no assertion; it only prints the collected
    output.
    """
    bot = RolePlay(
        function_list=[{'name': 'image_gen'}],
        llm={'model': 'qwen-max', 'model_server': 'dashscope'},
        instruction='你扮演一个画家,用尽可能丰富的描述调用工具绘制图像。')

    # NOTE(review): the query asks about the weather rather than requesting a
    # drawing - confirm this is the intended prompt for an image_gen test.
    # run() is iterated chunk by chunk; the pieces are joined before printing.
    reply = ''.join(bot.run('朝阳区天气怎样?'))
    print(reply)
30 changes: 3 additions & 27 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from modelscope_agent.action_parser import ActionParser
from modelscope_agent.llm import LLM
from modelscope_agent.prompts import PromptGenerator
from modelscope_agent.tools import Tool
from agent_scope.llm import BaseChatModel
from agent_scope.tools import Tool


class MockLLM(LLM):
class MockLLM(BaseChatModel):

def __init__(self, responses=['mock_llm_response']):
super().__init__({})
Expand All @@ -21,28 +19,6 @@ def stream_generate(self, prompt: str, function_list=[], **kwargs) -> str:
yield 'mock llm response'


class MockPromptGenerator(PromptGenerator):
    """Test stand-in that adds nothing on top of PromptGenerator."""

    def __init__(self):
        # Delegate straight to the base initializer with no arguments.
        super().__init__()


class MockOutParser(ActionParser):
    """Action parser stub that replays one fixed (action, args) pair.

    The pair is returned for the first ``count`` calls to ``parse_response``;
    every later call returns ``(None, None)``.
    """

    def __init__(self, action, args, count=1):
        super().__init__()
        self.action = action
        self.args = args
        # Remaining number of calls that still yield the canned pair.
        self.count = count

    def parse_response(self, response: str):
        """Return the canned (action, args) while the budget lasts.

        The ``response`` text is ignored. Each successful call decrements
        ``count`` by one.
        """
        if not self.count > 0:
            return None, None
        self.count -= 1
        return self.action, self.args


class MockTool(Tool):

def __init__(self, name, func, description, parameters=[]):
Expand Down
Loading