From 977e68f86b641b60773c166945e3558fb4ccab7d Mon Sep 17 00:00:00 2001 From: CaptainB Date: Mon, 9 Dec 2024 11:00:51 +0800 Subject: [PATCH 1/5] feat: Support image generate model --- apps/application/flow/step_node/__init__.py | 3 +- .../image_generate_step_node/__init__.py | 3 + .../i_image_generate_node.py | 37 ++ .../image_generate_step_node/impl/__init__.py | 3 + .../impl/base_image_generate_node.py | 101 ++++++ apps/application/flow/workflow_manage.py | 2 +- apps/common/util/common.py | 25 ++ .../models_provider/base_model_provider.py | 1 + apps/setting/models_provider/impl/base_tti.py | 14 + .../openai_model_provider/credential/tti.py | 47 +++ .../impl/openai_model_provider/model/tti.py | 67 ++++ .../openai_model_provider.py | 33 +- .../qwen_model_provider/credential/tti.py | 70 ++++ .../impl/qwen_model_provider/model/tti.py | 64 ++++ .../qwen_model_provider.py | 23 +- .../tencent_model_provider/credential/tti.py | 108 ++++++ .../impl/tencent_model_provider/model/tti.py | 98 ++++++ .../tencent_model_provider.py | 10 + .../zhipu_model_provider/credential/tti.py | 44 +++ .../impl/zhipu_model_provider/model/tti.py | 73 ++++ .../zhipu_model_provider.py | 13 + ui/src/api/application.ts | 8 + .../ai-chat/ExecutionDetailDialog.vue | 60 ++++ ui/src/enums/model.ts | 1 + ui/src/enums/workflow.ts | 1 + .../component/SelectProviderDialog.vue | 1 + ui/src/views/template/index.vue | 1 + ui/src/workflow/common/data.ts | 26 +- ui/src/workflow/common/validate.ts | 1 + .../icons/image-generate-node-icon.vue | 6 + ui/src/workflow/nodes/image-generate/index.ts | 14 + .../workflow/nodes/image-generate/index.vue | 323 ++++++++++++++++++ 32 files changed, 1265 insertions(+), 16 deletions(-) create mode 100644 apps/application/flow/step_node/image_generate_step_node/__init__.py create mode 100644 apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py create mode 100644 apps/application/flow/step_node/image_generate_step_node/impl/__init__.py create mode 
100644 apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py create mode 100644 apps/setting/models_provider/impl/base_tti.py create mode 100644 apps/setting/models_provider/impl/openai_model_provider/credential/tti.py create mode 100644 apps/setting/models_provider/impl/openai_model_provider/model/tti.py create mode 100644 apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py create mode 100644 apps/setting/models_provider/impl/qwen_model_provider/model/tti.py create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/credential/tti.py create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/model/tti.py create mode 100644 apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py create mode 100644 apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py create mode 100644 ui/src/workflow/icons/image-generate-node-icon.vue create mode 100644 ui/src/workflow/nodes/image-generate/index.ts create mode 100644 ui/src/workflow/nodes/image-generate/index.vue diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index cd8b08a974a..535560b5fcd 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -18,6 +18,7 @@ from .document_extract_node import * from .image_understand_step_node import * +from .image_generate_step_node import * from .search_dataset_node import * from .start_node import * @@ -25,7 +26,7 @@ node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode, BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode, - BaseImageUnderstandNode, BaseFormNode] + BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/image_generate_step_node/__init__.py 
b/apps/application/flow/step_node/image_generate_step_node/__init__.py new file mode 100644 index 00000000000..f3feecc9ce2 --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py new file mode 100644 index 00000000000..48a0840e7cb --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py @@ -0,0 +1,37 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class ImageGenerateNodeSerializer(serializers.Serializer): + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) + + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词(正向)")) + + negative_prompt = serializers.CharField(required=False, default='', error_messages=ErrMessage.char("提示词(负向)")) + # 多轮对话数量 + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + + dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型")) + + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + + +class IImageGenerateNode(INode): + type = 'image-generate-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ImageGenerateNodeSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, + chat_record_id, + **kwargs) -> NodeResult: + pass diff --git 
a/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py b/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py new file mode 100644 index 00000000000..14a21a9159c --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .base_image_generate_node import BaseImageGenerateNode diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py new file mode 100644 index 00000000000..2933c46ec8c --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py @@ -0,0 +1,101 @@ +# coding=utf-8 +from functools import reduce +from typing import List + +from langchain_core.messages import BaseMessage, HumanMessage, AIMessage + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode +from setting.models_provider.tools import get_model_instance_by_model_user_id + + +class BaseImageGenerateNode(IImageGenerateNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + self.answer_text = details.get('answer') + + def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, + chat_record_id, + **kwargs) -> NodeResult: + + tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) + history_message = self.get_history_message(history_chat_record, dialogue_number) + self.context['history_message'] = history_message + question = self.generate_prompt_question(prompt) + self.context['question'] = question + message_list = self.generate_message_list(question, history_message) + 
self.context['message_list'] = message_list + self.context['dialogue_type'] = dialogue_type + print(message_list) + print(negative_prompt) + image_urls = tti_model.generate_image(question, negative_prompt) + self.context['image_list'] = image_urls + answer = '\n'.join([f"![Image]({path})" for path in image_urls]) + return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list, + 'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls], + 'history_message': history_message, 'question': question}, {}) + + def generate_history_ai_message(self, chat_record): + for val in chat_record.details.values(): + if self.node.id == val['node_id'] and 'image_list' in val: + if val['dialogue_type'] == 'WORKFLOW': + return chat_record.get_ai_message() + return AIMessage(content=val['answer']) + return chat_record.get_ai_message() + + def get_history_message(self, history_chat_record, dialogue_number): + start_index = len(history_chat_record) - dialogue_number + history_message = reduce(lambda x, y: [*x, *y], [ + [self.generate_history_human_message(history_chat_record[index]), + self.generate_history_ai_message(history_chat_record[index])] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))], []) + return history_message + + def generate_history_human_message(self, chat_record): + + for data in chat_record.details.values(): + if self.node.id == data['node_id'] and 'image_list' in data: + image_list = data['image_list'] + if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW': + return HumanMessage(content=chat_record.problem_text) + return HumanMessage(content=data['question']) + return HumanMessage(content=chat_record.problem_text) + + def generate_prompt_question(self, prompt): + return self.workflow_manage.generate_prompt(prompt) + + def generate_message_list(self, question: str, history_message): + return [ + *history_message, + question + ] + + @staticmethod + def 
reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'history_message': [{'content': message.content, 'role': message.type} for message in + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message, + 'image_list': self.context.get('image_list'), + 'dialogue_type': self.context.get('dialogue_type') + } diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index c62719b6e3f..4a8e0b92279 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -54,7 +54,7 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node', - 'image-understand-node'] + 'image-understand-node', 'image-generate-node'] class Flow: diff --git a/apps/common/util/common.py b/apps/common/util/common.py index 8571c91e33c..230727622a7 100644 --- a/apps/common/util/common.py +++ b/apps/common/util/common.py @@ -8,9 +8,12 @@ """ import hashlib import importlib +import mimetypes +import io from functools import reduce from typing import Dict, List +from django.core.files.uploadedfile import InMemoryUploadedFile from django.db.models import QuerySet from ..exception.app_exception import 
AppApiException @@ -111,3 +114,25 @@ def bulk_create_in_batches(model, data, batch_size=1000): batch = data[i:i + batch_size] model.objects.bulk_create(batch) + +def bytes_to_uploaded_file(file_bytes, file_name="file.txt"): + content_type, _ = mimetypes.guess_type(file_name) + if content_type is None: + # 如果未能识别,设置为默认的二进制文件类型 + content_type = "application/octet-stream" + # 创建一个内存中的字节流对象 + file_stream = io.BytesIO(file_bytes) + + # 获取文件大小 + file_size = len(file_bytes) + + # 创建 InMemoryUploadedFile 对象 + uploaded_file = InMemoryUploadedFile( + file=file_stream, + field_name=None, + name=file_name, + content_type=content_type, + size=file_size, + charset=None, + ) + return uploaded_file diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index c86068c68a8..39a759a6548 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -150,6 +150,7 @@ class ModelTypeConst(Enum): STT = {'code': 'STT', 'message': '语音识别'} TTS = {'code': 'TTS', 'message': '语音合成'} IMAGE = {'code': 'IMAGE', 'message': '图片理解'} + TTI = {'code': 'TTI', 'message': '图片生成'} RERANKER = {'code': 'RERANKER', 'message': '重排模型'} diff --git a/apps/setting/models_provider/impl/base_tti.py b/apps/setting/models_provider/impl/base_tti.py new file mode 100644 index 00000000000..5e34d12cd11 --- /dev/null +++ b/apps/setting/models_provider/impl/base_tti.py @@ -0,0 +1,14 @@ +# coding=utf-8 +from abc import abstractmethod + +from pydantic import BaseModel + + +class BaseTextToImage(BaseModel): + @abstractmethod + def check_auth(self): + pass + + @abstractmethod + def generate_image(self, prompt: str, negative_prompt: str = None): + pass diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py new file mode 100644 index 00000000000..480c20a8c02 --- /dev/null +++ 
b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py @@ -0,0 +1,47 @@ +# coding=utf-8 +import base64 +import os +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + res = model.check_auth() + print(res) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + pass diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/tti.py b/apps/setting/models_provider/impl/openai_model_provider/model/tti.py new file mode 100644 index 00000000000..08ceede0235 --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/model/tti.py @@ -0,0 +1,67 @@ +from 
typing import Dict + +import requests +from langchain_core.messages import HumanMessage +from langchain_openai import ChatOpenAI +from openai import OpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from common.util.common import bytes_to_uploaded_file +from dataset.serializers.file_serializers import FileSerializer +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tti import BaseTextToImage + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage): + api_base: str + api_key: str + model: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') + self.model = kwargs.get('model') + self.params = kwargs.get('params') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params['params'][key] = value + return OpenAITextToImage( + model=model_name, + api_base=model_credential.get('api_base'), + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def check_auth(self): + chat = OpenAI(api_key=self.api_key, base_url=self.api_base) + response_list = chat.models.with_raw_response.list() + + # self.generate_image('生成一个小猫图片') + + def generate_image(self, prompt: str, negative_prompt: str = None): + + chat = OpenAI(api_key=self.api_key, base_url=self.api_base) + res = chat.images.generate(model=self.model, prompt=prompt, **self.params) + + file_urls = [] + for content in res.data: + url = content.url + print(url) + file_name = 'generated_image.png' + file = bytes_to_uploaded_file(requests.get(url).content, file_name) + meta = {'debug': True} + 
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload() + file_urls.append(file_url) + + return file_urls diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index 974599cc89c..be659291efc 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -15,11 +15,13 @@ from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential +from setting.models_provider.impl.openai_model_provider.credential.tti import OpenAITextToImageModelCredential from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText +from setting.models_provider.impl.openai_model_provider.model.tti import OpenAITextToImage from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech from smartdoc.conf import PROJECT_DIR @@ -27,6 +29,7 @@ openai_stt_model_credential = OpenAISTTModelCredential() openai_tts_model_credential = OpenAITTSModelCredential() openai_image_model_credential = OpenAIImageModelCredential() +openai_tti_model_credential = OpenAITextToImageModelCredential() model_info_list = [ ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', 
ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel @@ -37,8 +40,8 @@ ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel), ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新', - ModelTypeConst.LLM, openai_llm_model_credential, - OpenAIChatModel), + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel), @@ -100,11 +103,27 @@ OpenAIImage), ] -model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( - ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, - openai_llm_model_credential, OpenAIChatModel - )).append_model_info_list(model_info_embedding_list).append_default_model_info( - model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build() +model_info_tti_list = [ + ModelInfo('dall-e-2', '', + ModelTypeConst.TTI, openai_tti_model_credential, + OpenAITextToImage), + ModelInfo('dall-e-3', '', + ModelTypeConst.TTI, openai_tti_model_credential, + OpenAITextToImage), +] + +model_info_manage = ( + ModelInfoManage.builder() + .append_model_info_list(model_info_list) + .append_default_model_info(ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, + openai_llm_model_credential, OpenAIChatModel + )) + .append_model_info_list(model_info_embedding_list) + .append_default_model_info(model_info_embedding_list[0]) + .append_model_info_list(model_info_image_list) + .append_model_info_list(model_info_tti_list) + .build() +) class OpenAIModelProvider(IModelProvider): diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py new file mode 100644 index 00000000000..dc4779f6b05 --- /dev/null +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py @@ -0,0 
+1,70 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 18:41 + @desc: +""" +import base64 +import os +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class QwenModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=1.0, + _min=0.1, + _max=1.9, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + res = model.check_auth() + print(res) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, 
model_name): + return QwenModelParams() diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py b/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py new file mode 100644 index 00000000000..593171dfe59 --- /dev/null +++ b/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py @@ -0,0 +1,64 @@ +# coding=utf-8 +from http import HTTPStatus +from pathlib import PurePosixPath +from typing import Dict +from urllib.parse import unquote, urlparse + +import requests +from dashscope import ImageSynthesis +from langchain_community.chat_models import ChatTongyi +from langchain_core.messages import HumanMessage + +from common.util.common import bytes_to_uploaded_file +from dataset.serializers.file_serializers import FileSerializer +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tti import BaseTextToImage + + +class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage): + api_key: str + model_name: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.model_name = kwargs.get('model_name') + self.params = kwargs.get('params') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params['params'][key] = value + chat_tong_yi = QwenTextToImageModel( + model_name=model_name, + api_key=model_credential.get('api_key'), + **optional_params, + ) + return chat_tong_yi + + def check_auth(self): + chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max') + chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])]) + + def generate_image(self, prompt: str, negative_prompt: str = None): + # api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', + rsp = 
ImageSynthesis.call(api_key=self.api_key, + model=self.model_name, + prompt=prompt, + negative_prompt=negative_prompt, + **self.params) + file_urls = [] + if rsp.status_code == HTTPStatus.OK: + for result in rsp.output.results: + file_name = PurePosixPath(unquote(urlparse(result.url).path)).parts[-1] + file = bytes_to_uploaded_file(requests.get(result.url).content, file_name) + meta = {'debug': True} + file_url = FileSerializer(data={'file': file, 'meta': meta}).upload() + file_urls.append(file_url) + else: + print('sync_call Failed, status_code: %s, code: %s, message: %s' % + (rsp.status_code, rsp.code, rsp.message)) + return file_urls diff --git a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py index 0a24ca35ce8..fc2506a9c59 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py @@ -13,13 +13,16 @@ ModelInfoManage from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential +from setting.models_provider.impl.qwen_model_provider.credential.tti import QwenTextToImageModelCredential from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel +from setting.models_provider.impl.qwen_model_provider.model.tti import QwenTextToImageModel from smartdoc.conf import PROJECT_DIR qwen_model_credential = OpenAILLMModelCredential() qwenvl_model_credential = QwenVLModelCredential() +qwentti_model_credential = QwenTextToImageModelCredential() module_info_list = [ ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel), @@ -31,13 +34,21 @@ ModelInfo('qwen-vl-max-0809', '', 
ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel), ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel), ] +module_info_tti_list = [ + ModelInfo('wanx-v1', + '通义万相-文本生成图像大模型,支持中英文双语输入,支持输入参考图片进行参考内容或者参考风格迁移,重点风格包括但不限于水彩、油画、中国画、素描、扁平插画、二次元、3D卡通。', + ModelTypeConst.TTI, qwentti_model_credential, QwenTextToImageModel), +] -model_info_manage = (ModelInfoManage.builder() - .append_model_info_list(module_info_list) - .append_default_model_info( - ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)) - .append_model_info_list(module_info_vl_list) - .build()) +model_info_manage = ( + ModelInfoManage.builder() + .append_model_info_list(module_info_list) + .append_default_model_info( + ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)) + .append_model_info_list(module_info_vl_list) + .append_model_info_list(module_info_tti_list) + .build() +) class QwenModelProvider(IModelProvider): diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/tti.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/tti.py new file mode 100644 index 00000000000..1b6183e8a11 --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/tti.py @@ -0,0 +1,108 @@ +# coding=utf-8 +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class TencentTTIModelParams(BaseForm): + Style = forms.SingleSelect( + TooltipLabel('绘画风格', '不传默认使用201(日系动漫风格)'), + required=True, + default_value='201', + option_list=[ + {'value': '000', 'label': '不限定风格'}, + {'value': '101', 'label': '水墨画'}, + {'value': '102', 'label': '概念艺术'}, + {'value': '103', 'label': '油画1'}, + {'value': '118', 'label': '油画2(梵高)'}, + 
{'value': '104', 'label': '水彩画'}, + {'value': '105', 'label': '像素画'}, + {'value': '106', 'label': '厚涂风格'}, + {'value': '107', 'label': '插图'}, + {'value': '108', 'label': '剪纸风格'}, + {'value': '109', 'label': '印象派1(莫奈)'}, + {'value': '119', 'label': '印象派2'}, + {'value': '110', 'label': '2.5D'}, + {'value': '111', 'label': '古典肖像画'}, + {'value': '112', 'label': '黑白素描画'}, + {'value': '113', 'label': '赛博朋克'}, + {'value': '114', 'label': '科幻风格'}, + {'value': '115', 'label': '暗黑风格'}, + {'value': '116', 'label': '3D'}, + {'value': '117', 'label': '蒸汽波'}, + {'value': '201', 'label': '日系动漫'}, + {'value': '202', 'label': '怪兽风格'}, + {'value': '203', 'label': '唯美古风'}, + {'value': '204', 'label': '复古动漫'}, + {'value': '301', 'label': '游戏卡通手绘'}, + {'value': '401', 'label': '通用写实风格'}, + ], + value_field='value', + text_field='label' + ) + + Resolution = forms.SingleSelect( + TooltipLabel('生成图分辨率', '不传默认使用768:768。'), + required=True, + default_value='768:768', + option_list=[ + {'value': '768:768', 'label': '768:768(1:1)'}, + {'value': '768:1024', 'label': '768:1024(3:4)'}, + {'value': '1024:768', 'label': '1024:768(4:3)'}, + {'value': '1024:1024', 'label': '1024:1024(1:1)'}, + {'value': '720:1280', 'label': '720:1280(9:16)'}, + {'value': '1280:720', 'label': '1280:720(16:9)'}, + {'value': '768:1280', 'label': '768:1280(3:5)'}, + {'value': '1280:768', 'label': '1280:768(5:3)'}, + {'value': '1080:1920', 'label': '1080:1920(9:16)'}, + {'value': '1920:1080', 'label': '1920:1080(16:9)'}, + ], + value_field='value', + text_field='label' + ) + + +class TencentTTIModelCredential(BaseForm, BaseModelCredential): + REQUIRED_FIELDS = ['hunyuan_secret_id', 'hunyuan_secret_key'] + + @classmethod + def _validate_model_type(cls, model_type, provider, raise_exception=False): + if not any(mt['value'] == model_type for mt in provider.get_model_type_list()): + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + return False + return True + + @classmethod 
+ def _validate_credential_fields(cls, model_credential, raise_exception=False): + missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential] + if missing_keys: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段') + return False + return True + + def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False): + if not (self._validate_model_type(model_type, provider, raise_exception) and + self._validate_credential_fields(model_credential, raise_exception)): + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + return False + return True + + def encryption_dict(self, model): + return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))} + + hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True) + hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True) + + def get_model_params_setting_form(self, model_name): + return TencentTTIModelParams() diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/tti.py b/apps/setting/models_provider/impl/tencent_model_provider/model/tti.py new file mode 100644 index 00000000000..e2a9976f012 --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/model/tti.py @@ -0,0 +1,98 @@ +# coding=utf-8 + +import json +from typing import Dict + +import requests +from tencentcloud.common import credential +from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException +from tencentcloud.common.profile.client_profile import ClientProfile +from tencentcloud.common.profile.http_profile import HttpProfile +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models + +from common.util.common import 
class TencentTextToImageModel(MaxKBBaseModel, BaseTextToImage):
    """Text-to-image model that calls Tencent Hunyuan's ``TextToImageLite`` API."""

    hunyuan_secret_id: str
    hunyuan_secret_key: str
    model: str
    params: dict

    @staticmethod
    def is_cache_model():
        """Always return False (presumably opts this model out of instance caching)."""
        return False

    def __init__(self, **kwargs):
        """Store the credentials, model name and request parameters from ``kwargs``."""
        super().__init__(**kwargs)
        self.hunyuan_secret_id = kwargs.get('hunyuan_secret_id')
        self.hunyuan_secret_key = kwargs.get('hunyuan_secret_key')
        # new_instance() passes the name as ``model``. The previous code read
        # only ``model_name``, which is never supplied, so ``self.model`` was
        # always None. ``model_name`` is still accepted for direct callers.
        self.model = kwargs.get('model') or kwargs.get('model_name')
        # Normalise a missing value to {} so it can always be merged safely.
        self.params = kwargs.get('params') or {}

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
                     **model_kwargs) -> 'TencentTextToImageModel':
        """Build an instance from a credential mapping and optional model kwargs.

        ``Style`` and ``Resolution`` start from defaults. Every entry of
        ``model_kwargs`` except ``model_id``, ``use_local`` and ``streaming``
        is copied into the request parameters and may override the defaults.
        """
        params = {'Style': '201', 'Resolution': '768:768'}
        params.update({key: value for key, value in model_kwargs.items()
                       if key not in ('model_id', 'use_local', 'streaming')})
        return TencentTextToImageModel(
            model=model_name,
            hunyuan_secret_id=model_credential.get('hunyuan_secret_id'),
            hunyuan_secret_key=model_credential.get('hunyuan_secret_key'),
            params=params,
        )

    def check_auth(self):
        """Verify the secret id/key by sending one short chat message via ChatHunyuan.

        Any exception raised by the chat call propagates to the caller.
        """
        # NOTE(review): hunyuan_app_id is a hard-coded placeholder here; confirm
        # the chat client does not need a real app id for this probe.
        chat = ChatHunyuan(hunyuan_app_id='111111',
                           hunyuan_secret_id=self.hunyuan_secret_id,
                           hunyuan_secret_key=self.hunyuan_secret_key,
                           model="hunyuan-standard")
        chat.invoke('你好')

    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Generate one image for ``prompt`` and return the uploaded file URLs.

        Calls ``TextToImageLite`` asking for a URL result, downloads the
        image, uploads it through ``FileSerializer`` and returns a
        one-element list holding whatever ``upload()`` returns.

        Raises:
            TencentCloudSDKException: if the SDK call fails. It was previously
                printed and swallowed, which made this method return None.
            requests.RequestException: if the image download fails, times out
                or returns an HTTP error status.
        """
        cred = credential.Credential(self.hunyuan_secret_id, self.hunyuan_secret_key)

        http_profile = HttpProfile()
        http_profile.endpoint = "hunyuan.tencentcloudapi.com"
        client_profile = ClientProfile()
        client_profile.httpProfile = http_profile
        client = hunyuan_client.HunyuanClient(cred, "ap-guangzhou", client_profile)

        request_params = {"Prompt": prompt, "RspImgType": "url"}
        if negative_prompt is not None:
            # Leave the key out entirely rather than sending a JSON null.
            request_params["NegativePrompt"] = negative_prompt
        # Configured params are applied last, as before, so they may override
        # the entries above.
        request_params.update(self.params or {})

        req = models.TextToImageLiteRequest()
        req.from_json_string(json.dumps(request_params))
        resp = client.TextToImageLite(req)

        # Bound the download time and reject HTTP error responses so an error
        # body is never uploaded as if it were the generated image.
        image_response = requests.get(resp.ResultImage, timeout=60)
        image_response.raise_for_status()

        file = bytes_to_uploaded_file(image_response.content, 'generated_image.png')
        meta = {'debug': True}
        return [FileSerializer(data={'file': file, 'meta': meta}).upload()]
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for ZhiPu text-to-image models."""

    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the model type, the required keys, and finally the credential itself.

        Returns:
            True when every check passes, False when a check fails and
            ``raise_exception`` is false.

        Raises:
            AppApiException: when a check fails and ``raise_exception`` is
                true. An AppApiException raised while building or probing the
                model is re-raised unchanged regardless of the flag, as before.
        """
        # Honour raise_exception here too; previously an unsupported model
        # type raised even when the caller asked for a boolean result.
        if not any(mt.get('value') == model_type for mt in provider.get_model_type_list()):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except AppApiException:
            # Project errors already carry a code and message; pass them through.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``api_key`` passed through ``encryption``."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return None: no parameter form is defined for these models."""
        return None
def check_auth(self):
    """Verify the API key by sending one short chat message through ChatZhipuAI.

    Any exception raised by the chat call propagates to the caller.
    """
    # NOTE(review): the probe uses self.model as the chat model name, and for
    # this class that is a text-to-image model; confirm the chat endpoint
    # accepts it.
    chat = ChatZhipuAI(
        zhipuai_api_key=self.api_key,
        model_name=self.model,
    )
    chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])


def generate_image(self, prompt: str, negative_prompt: str = None):
    """Generate images for ``prompt`` and return the uploaded file URLs.

    Each image URL in the ZhipuAI response is downloaded and uploaded
    through ``FileSerializer``. ``negative_prompt`` is accepted for
    interface compatibility and is not sent to the API.

    Returns:
        A list with one entry per generated image, each being whatever
        ``FileSerializer.upload()`` returns.

    Raises:
        requests.RequestException: if a download fails, times out or
            returns an HTTP error status.
    """
    import posixpath
    from urllib.parse import urlparse

    client = ZhipuAI(api_key=self.api_key)
    response = client.images.generations(
        model=self.model,
        prompt=prompt,
        # Guard against params being None so the call never fails on unpacking.
        **(self.params or {})
    )

    file_urls = []
    for content in response.data:
        # The SDK documents the URL as an attribute of each result item;
        # mapping access is kept for dict-shaped items.
        url = content['url'] if isinstance(content, dict) else content.url
        # Take the name from the URL path only, so a query string (for
        # example a signature) never ends up in the stored file name.
        file_name = posixpath.basename(urlparse(url).path) or 'generated_image.png'

        # Bound the download time and reject HTTP error responses so an
        # error body is never uploaded as if it were an image.
        image_response = requests.get(url, timeout=60)
        image_response.raise_for_status()

        file = bytes_to_uploaded_file(image_response.content, file_name)
        meta = {'debug': True}
        file_urls.append(FileSerializer(data={'file': file, 'meta': meta}).upload())

    return file_urls
/**
 * Fetch the text-to-image (TTI) models available to an application.
 *
 * NOTE(review): the generic arguments below were missing from the extracted
 * text and are reconstructed from what remained of it; confirm against the
 * neighbouring get*Model helpers.
 *
 * @param application_id id of the application
 * @param loading optional loading flag passed through to `get`
 */
const getApplicationTTIModel: (
  application_id: string,
  loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
  return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading)
}
b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index 5c55f1bf379..39a4f2617cc 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -32,6 +32,7 @@ item.type === WorkflowType.Question || item.type === WorkflowType.AiChat || item.type === WorkflowType.ImageUnderstandNode || + item.type === WorkflowType.ImageGenerateNode || item.type === WorkflowType.Application " >{{ item?.message_tokens + item?.answer_tokens }} tokens + +