Skip to content

feat: Support image generate model #1812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion apps/application/flow/step_node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

from .document_extract_node import *
from .image_understand_step_node import *
from .image_generate_step_node import *

from .search_dataset_node import *
from .start_node import *

node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
BaseDocumentExtractNode,
BaseImageUnderstandNode, BaseFormNode]
BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode]


def get_node(node_type):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .impl import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# coding=utf-8

from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage


class ImageGenerateNodeSerializer(serializers.Serializer):
    """Validates the configuration parameters of an image-generation workflow node."""

    # Primary key of the text-to-image model to invoke.
    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))

    # Positive prompt template (rendered by the workflow before use).
    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词(正向)"))

    # Negative prompt; optional, defaults to an empty string.
    negative_prompt = serializers.CharField(required=False, default='', error_messages=ErrMessage.char("提示词(负向)"))
    # Number of history dialogue rounds to include as context.
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

    # How the dialogue history is stored (e.g. 'WORKFLOW' vs node-level) — see history replay logic.
    dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))

    # Whether this node's output is returned as part of the chat answer.
    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

    # Extra keyword settings forwarded to the model instance.
    model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置"))


class IImageGenerateNode(INode):
    """Abstract workflow node for image generation.

    Declares the node type and parameter serializer; concrete behavior is
    supplied by subclasses overriding ``execute``.
    """
    type = 'image-generate-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return ImageGenerateNodeSerializer

    def _run(self):
        # Merge validated node-level params with flow-level params and delegate
        # to the subclass implementation of execute().
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                **kwargs) -> NodeResult:
        """Run the image generation step; implemented by subclasses, returns a NodeResult."""
        pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .base_image_generate_node import BaseImageGenerateNode
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# coding=utf-8
from functools import reduce
from typing import List

import requests
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage

from application.flow.i_step_node import NodeResult
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.tools import get_model_instance_by_model_user_id


class BaseImageGenerateNode(IImageGenerateNode):
    """Concrete image-generation workflow node.

    Renders the prompt, calls a text-to-image model, downloads every
    generated image, persists it via ``FileSerializer`` and answers with
    markdown image links.
    """

    def save_context(self, details, workflow_manage):
        """Restore this node's answer/question context from persisted run details."""
        self.context['answer'] = details.get('answer')
        self.context['question'] = details.get('question')
        self.answer_text = details.get('answer')

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                **kwargs) -> NodeResult:
        """Generate images and return a NodeResult whose answer embeds them.

        :param model_id:             id of the text-to-image model to use
        :param prompt:               positive prompt template (rendered by the workflow)
        :param negative_prompt:      negative prompt passed straight to the model
        :param dialogue_number:      number of history chat rounds used as context
        :param dialogue_type:        dialogue storage type ('WORKFLOW' or node-level)
        :param history_chat_record:  previous chat records of this conversation
        :param chat_id:              current chat id (stored in the uploaded file meta)
        :param model_params_setting: extra keyword settings for the model instance
        :param chat_record_id:       current chat record id (accepted, unused here)
        """
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
                                                        **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question
        message_list = self.generate_message_list(question, history_message)
        self.context['message_list'] = message_list
        self.context['dialogue_type'] = dialogue_type
        image_urls = tti_model.generate_image(question, negative_prompt)
        # Download and persist every generated image.
        file_urls = []
        for image_url in image_urls:
            file_name = 'generated_image.png'
            # NOTE(review): requests.get has no timeout — a stalled download
            # would hang the workflow; consider adding one.
            file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
            meta = {
                'debug': not application.id,  # debug runs have no application id
                'chat_id': chat_id,
                'application_id': str(application.id) if application.id else None,
            }
            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
            file_urls.append(file_url)
        self.context['image_list'] = file_urls
        answer = '\n'.join([f"![Image]({path})" for path in file_urls])
        return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
                           'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls],
                           'history_message': history_message, 'question': question}, {})

    def generate_history_ai_message(self, chat_record):
        """Build the AI-side history message for one chat record.

        If this node produced images in that record and the dialogue is stored
        per-node, replay the node's own answer; otherwise fall back to the
        record's overall AI message.
        """
        for val in chat_record.details.values():
            if self.node.id == val['node_id'] and 'image_list' in val:
                if val['dialogue_type'] == 'WORKFLOW':
                    return chat_record.get_ai_message()
                return AIMessage(content=val['answer'])
        return chat_record.get_ai_message()

    def get_history_message(self, history_chat_record, dialogue_number):
        """Return the last ``dialogue_number`` rounds as a flat human/AI message list."""
        start_index = len(history_chat_record) - dialogue_number
        history_message = reduce(lambda x, y: [*x, *y], [
            [self.generate_history_human_message(history_chat_record[index]),
             self.generate_history_ai_message(history_chat_record[index])]
            for index in
            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
        return history_message

    def generate_history_human_message(self, chat_record):
        """Build the human-side history message for one chat record.

        Uses the node's rendered question when this node ran in that record
        with node-level dialogue storage; otherwise the user's original text.
        """
        for data in chat_record.details.values():
            if self.node.id == data['node_id'] and 'image_list' in data:
                image_list = data['image_list']
                if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                    return HumanMessage(content=chat_record.problem_text)
                return HumanMessage(content=data['question'])
        return HumanMessage(content=chat_record.problem_text)

    def generate_prompt_question(self, prompt):
        """Render the prompt template with the workflow's variables."""
        return self.workflow_manage.generate_prompt(prompt)

    def generate_message_list(self, question: str, history_message):
        """Append the current question to the history messages."""
        return [
            *history_message,
            question
        ]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert messages to role/content dicts and append the final answer."""
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                  message
                  in
                  message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """Return the execution details persisted for this node run."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'history_message': [{'content': message.content, 'role': message.type} for message in
                                (self.context.get('history_message') if self.context.get(
                                    'history_message') is not None else [])],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            'image_list': self.context.get('image_list'),
            'dialogue_type': self.context.get('dialogue_type')
        }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该代码存在以下不足和改进点:

  1. 编码格式:文件使用了UTF-8编码,但缺少BOM(Byte Order Mark)。可以在头部添加 # -*- coding: utf-8 -*-

  2. 函数命名:一些方法名如 generate_prompt_question, get_details 可以考虑是否与现有API或库中的函数冲突。

  3. 错误处理:没有对所有可能的异常进行适当的捕获和处理,可能会导致程序崩溃或数据丢失。

  4. 日志记录:在某些地方打印了调试信息,但在生产环境中应该移除这些打印语句或替换为更安全的日志机制。

  5. 依赖项管理:导入的部分模块(如 reduce)来自内置函数和标准库,无需额外安装。可以将这部分移除。

  6. 变量命名:部分变量取名为 details 和 context,可能导致混淆。推荐使用更具描述性的名称,例如 query_details、current_context 等。

  7. 文档字符串:虽然有一些注释解释了每一步的作用,但整体上的文档编写还有待提高,特别是在类层次结构、参数说明等方面。

  8. 可读性:代码逻辑清晰,但对于大型项目来说,良好的分隔和缩进仍然很重要。

以下是修改后的一些示例:

修改后的版本

def generate_prompt(self, query):
    """根据查询生成提示问题"""
    return self.workflow_manage.generate_prompt(query)

@staticmethod
def process_generated_images(image_urls, app_id=None):
    """处理生成的图像并上传到服务器"""
    result = []
    
    for url in image_urls:
        response = requests.get(url)
        image_data = response.content
        
        file_name = f"generated_image_{len(result)}.png"
        uploaded_file = bytes_to_uploaded_file(image_data, file_name)
        
        metadata = {
            'debug': bool(app_id),
            'app_id': str(app_id) if app_id else None,
        }
        upload_response = FileSerializer(data={'file': uploaded_file, 'meta': metadata}).upload()
        result.append(upload_response.url)
    
    return result

def save_result(self, task_detail, run_time):
    """保存任务结果"""
    return NodeResult({
        'answer': task_detail.answer,
        'chat_model': task_detail.chat_model,
        'message_list': task_detail.message_list,
        'image': [{'file_id': path.rsplit('/', 1)[-1], 'file_url': path} for path in task_detail.image_urls],
        'history_message': task_detail.history_message,
        'question': task_detail.question
    }, '')

def handle_workflow_messages(task_detail):
    """处理流程消息"""
    for val in task_detail.details.values():
        if val['node_id'] == self.node.id and 'image_list' in val:
            if val['dialogue_type'] == 'WORKFLOW':
                return task_detail.ai_message
            elif not (val['dialogue_type'] == 'IMAGE_GENERATION'):
                return AIMessage(content=val['answer'].strip())

def build_history_messages(chat_records, count):
    """构建历史消息列表"""
    history_messages = []
    for record in chat_records[-count:]:
        history_messages.extend([
            HumanMessage(content=record.problem_text),
            self.handle_workflow_messages(record)
        ])
    return history_messages

class ImageGenerationWorkflow(BaseStepNode, iImageGenerateNode):
    """图像生成流程节点"""

    def execute(
            self,
            model_id,
            prompt,
            negative_prompt='',
            dialogue_number=1,
            dialect='DEFAULT',
            history=None,
            dialog_type='CHAT',
            flow_params_serializer=None,
            **kwargs
    ) -> TaskDetailResponseModel_v3_0:

        try:
            task_details = {}
            task_results = {}

            # 获取模型实例,并获取上下文。
            # ……(原实现保持不变)
            generated_images = task_details.get('generated_images')

            if generated_images:
                image_upload_urls = ImageGenerationWorkflow.process_generated_images(generated_images, app_id)

                node_result = self.save_result(
                    {'id': task_details.get('task_id', ''), 'question': task_details.get('prompt', ''),
                     'generated_answers': task_details.get('response', ''), 'image_list': image_upload_urls},
                    datetime.datetime.now()
                )

            return node_result

        except Exception as e:
            self.log_error(f"Exception during execution: {e}")
            raise

通过上述改进建议,可以使代码更加健壮、易于维护和扩展。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段代码存在一些问题和可以优化的地方:

1. save_context 方法中的 self.answer_text 赋值

execute 方法中赋值了 self.answer_text = details.get('answer'),但在 save_context 中又重新赋值了一次,这可能造成不必要的重复。

self.save_context(details, workflow_manage)

可以在调用 save_context 后删除 self.answer_text 这一行:

self.save_context(details, workflow_manage)
del self.answer_text

2. execute 方法中的图片处理逻辑

图片下载和上传操作使用了一个循环来遍历图像列表,并对每个 URL 进行请求。这种方法可能会导致性能问题,尤其是在大规模生成或需要频繁访问多个服务时。

推荐使用多线程或多进程来并行处理图像下载任务,以提高效率。

此外,考虑将请求头设置为允许 CORS 或使用代理服务器来避免直接访问外部资源造成的限制。

3. generate_message_list 方法中的历史消息拼接

在生成 message 列表时,会将 history 消息与新生成的消息相加。这可能导致 message 数量超过预期。

示例:

self.context['message_list'] = [
    *self.context['history_message'], 
    question
]

如果历史消息数量过多,可能会影响界面展示效果或 API 请求的响应时间。

可以通过增加条件判断来防止 message 数量过大:

if len(self.context['message_list']) + 1 <= max_messages_per_request:
    self.context['message_list'].append(question)
else:
    # 处理超过最大数量的情况
    pass

4. reset_message_list 方法

此方法用于重置消息列表,但它返回一个包含旧消息内容的新列表,而不是原地修改。

return [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for message in message_list]
result.append({'role': 'ai', 'content': answer_text})
return result

更好的做法是返回一个新的 list 并替换原来的 list 而不是覆盖其引用:

self.context['message_list'] = [
    {'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
    for message in message_list[:-1] + [AIMessage(content=answer_text)]
]

这些改进建议可以帮助你更好地管理和优化代码的质量。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有明显的不规范、潜在问题或优化建议。

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class ImageUnderstandNodeSerializer(serializers.Serializer):

image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))

model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置"))



class IImageUnderstandNode(INode):
type = 'image-understand-node'
Expand All @@ -35,6 +38,7 @@ def _run(self):
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)

def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
model_params_setting,
chat_record_id,
image,
**kwargs) -> NodeResult:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
from dataset.models import File
from setting.models_provider.tools import get_model_instance_by_model_user_id
from imghdr import what


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
Expand Down Expand Up @@ -59,8 +60,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor

def file_id_to_base64(file_id: str):
    """Load a stored File by id and return ``[base64_content, image_format]``.

    The image format (e.g. 'png', 'jpeg') is detected from the raw bytes so
    the caller can build a data URL whose mime type matches the actual image.
    """
    file = QuerySet(File).filter(id=file_id).first()
    file_bytes = file.get_byte()
    base64_image = base64.b64encode(file_bytes).decode("utf-8")
    # NOTE(review): imghdr is deprecated and removed in Python 3.13 (PEP 594);
    # consider an alternative format sniffer. `file_bytes` is assumed to be
    # memoryview-like (exposes tobytes()) — TODO confirm what get_byte() returns.
    return [base64_image, what(None, file_bytes.tobytes())]


class BaseImageUnderstandNode(IImageUnderstandNode):
Expand All @@ -70,14 +72,15 @@ def save_context(self, details, workflow_manage):
self.answer_text = details.get('answer')

def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
model_params_setting,
chat_record_id,
image,
**kwargs) -> NodeResult:
# 处理不正确的参数
if image is None or not isinstance(image, list):
image = []

image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
print(model_params_setting)
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
# 执行详情中的历史消息不需要图片内容
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
Expand Down Expand Up @@ -151,7 +154,7 @@ def generate_history_human_message(self, chat_record):
return HumanMessage(
content=[
{'type': 'text', 'text': data['question']},
*[{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}} for
*[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
base64_image in image_base64_list]
])
return HumanMessage(content=chat_record.problem_text)
Expand All @@ -166,8 +169,10 @@ def generate_message_list(self, image_model, system: str, prompt: str, history_m
for img in image:
file_id = img['file_id']
file = QuerySet(File).filter(id=file_id).first()
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}})
image_bytes = file.get_byte()
base64_image = base64.b64encode(image_bytes).decode("utf-8")
image_format = what(None, image_bytes.tobytes())
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
messages = [HumanMessage(
content=[
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
Expand Down
2 changes: 1 addition & 1 deletion apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa


end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
'image-understand-node']
'image-understand-node', 'image-generate-node']


class Flow:
Expand Down
3 changes: 2 additions & 1 deletion apps/common/forms/text_input_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""
from typing import Dict

from common.forms import BaseLabel
from common.forms.base_field import BaseField, TriggerType


Expand All @@ -16,7 +17,7 @@ class TextInputField(BaseField):
文本输入框
"""

def __init__(self, label: str,
def __init__(self, label: str or BaseLabel,
required: bool = False,
default_value=None,
relation_show_field_dict: Dict = None,
Expand Down
25 changes: 25 additions & 0 deletions apps/common/util/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
"""
import hashlib
import importlib
import mimetypes
import io
from functools import reduce
from typing import Dict, List

from django.core.files.uploadedfile import InMemoryUploadedFile
from django.db.models import QuerySet

from ..exception.app_exception import AppApiException
Expand Down Expand Up @@ -111,3 +114,25 @@ def bulk_create_in_batches(model, data, batch_size=1000):
batch = data[i:i + batch_size]
model.objects.bulk_create(batch)


def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
    """Wrap raw bytes into a Django ``InMemoryUploadedFile``.

    The content type is guessed from ``file_name``; unrecognized extensions
    fall back to the generic binary type ``application/octet-stream``.
    """
    guessed_type, _ = mimetypes.guess_type(file_name)
    return InMemoryUploadedFile(
        # In-memory byte stream backing the uploaded file.
        file=io.BytesIO(file_bytes),
        field_name=None,
        name=file_name,
        content_type=guessed_type or "application/octet-stream",
        size=len(file_bytes),
        charset=None,
    )
1 change: 1 addition & 0 deletions apps/setting/models_provider/base_model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class ModelTypeConst(Enum):
STT = {'code': 'STT', 'message': '语音识别'}
TTS = {'code': 'TTS', 'message': '语音合成'}
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
TTI = {'code': 'TTI', 'message': '图片生成'}
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}


Expand Down
14 changes: 14 additions & 0 deletions apps/setting/models_provider/impl/base_tti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod

from pydantic import BaseModel


class BaseTextToImage(BaseModel):
    """Abstract base for text-to-image (TTI) model providers."""

    @abstractmethod
    def check_auth(self):
        """Validate the provider credentials; expected to raise on failure."""
        pass

    @abstractmethod
    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Generate images for ``prompt``; callers iterate the result as image URLs."""
        pass
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,26 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode

class OpenAIImageModelParams(BaseForm):
    """Tunable generation parameters exposed in the UI for OpenAI image models."""

    # Sampling temperature in [0.1, 1.0]; higher values make output more random.
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    # Maximum number of tokens the model may generate (1 – 100000).
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)



class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
Expand Down Expand Up @@ -45,4 +62,4 @@ def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

def get_model_params_setting_form(self, model_name):
pass
return OpenAIImageModelParams()
Loading
Loading