Skip to content

feat: Support image generate model #1812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 13, 2024
Merged

feat: Support image generate model #1812

merged 5 commits into from
Dec 13, 2024

Conversation

liuruibin
Copy link
Member

What this PR does / why we need it?

Summary of your change

Please indicate you've done the following:

  • Made sure tests are passing and test coverage is added if needed.
  • Made sure commit message follow the rule of Conventional Commits specification.
  • Considered the docs impact and opened a new docs issue or PR with docs changes if needed.

Copy link

f2c-ci-robot bot commented Dec 11, 2024

Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it.

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

Copy link

f2c-ci-robot bot commented Dec 11, 2024

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@liuruibin liuruibin changed the title Feat tti model feat: Support image generate model Dec 12, 2024
return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body)

def is_cache_model(self):
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该代码有一些潜在的问题和改进空间:

  1. 缺少必要的错误处理:

    • requests.post() 操作抛出异常后没有正确的捕获或记录。
  2. 转义问题:

    • 在请求字符串中使用了字符 \u0026 替换原始的 &,这通常不是推荐的做法。可以考虑在构建 URL 之前进行适当的编码。
  3. 参数传递不清晰:

    • 参数传递过程中存在不必要的复杂性和重复,需要简化逻辑。
  4. 环境依赖性:

    • 引入了多个外部库(如 websockets, langchain_openai, etc.),但这些库的具体用途未说明。
  5. 部分函数缺少文档注释:

    • 函数签名和参数列表缺少详细的注释,难以理解其功能。

以下是改进建议:

改进点

  1. 确保异常被捕获

    try:
        r = requests.post(request_url, headers=headers, data=req_body)
    except Exception as err:
        print(f'Error occurred: {err}')
        raise
  2. 避免手动转义字符
    当然,如果你确实需要替换特殊字符,可以使用更安全的方式来进行编码,例如 urllib.parse.quote_plus() 来处理 URIs。

  3. 简化参数传递
    将参数处理逻辑整合到一个统一的位置,并确保所有模型实现类都遵循相似的参数约定。

class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
    access_key: str
    secret_key: str
    model_version: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.initialize(**kwargs)

    @staticmethod
    def initialize(access_key=None, secret_key=None, model_version=None, params={}):
        instance = VolcanicEngineTextToImage()
        instance.access_key = access_key or instance.access_key
        instance.secret_key = secret_key or instance.secret_key
        instance.model_version = model_version or instance.model_version
        instance.params = params.copy()  # Use copy to ensure the original dictionary is not modified
        return instance

    @staticmethod
    def create_from_config(config_data):
        defaults = {
            'access_key': 'your_access_key_here',
            'secret_key': 'your_secret_key_here',
            'model_version': 'general_v1.4',  # Default version to use
        }
        
        merged_config = defaults | config_data
        
        return VolcanicEngineTextToImage.initialize(
            **merged_config
        )

    def check_auth(self):
        res = self.generate_image('生成一张小猫图片')
        print(res)

    def generate_image(self, prompt: str, negative_prompt: str = None):
        if not all([prompt]):
            raise ValueError("Prompt must be provided.")
        
        query_params = {}
        body_params = {}

        for param in vars(self).values():
            if isinstance(param, (dict, list)):
                body_params.update(param)
            elif callable(getattr(self.__class__, param) and func:=getattr(self.__class__, param)) == self.create_from_config:
                continue  # Skip this since it's a class itself
            else:
                query_params[param] = getattr(self.__class__, param)()

        for attr, value in self.config.items():  # Assuming a configuration attribute exists
            query_params[attr] = getattr(value, attr, value)

        final_query_params = self.format_query(query_params)
        formatted_body = self.build_payload(body_params)
        
        return self.sign_v4_request(final_query_params, formatted_body)

    def format_query(self, params: dict) -> dict:
        ordered_query_params = sorted(params.items())
        formated_qery_str = '&'.join(f'{k}={v}' for k,v in ordered_query_params)
        return f'?{formated_qery_str}'

    def build_payload(self, params: Dict[str, object]) -> str:
        return json.dumps(params)

    def sign_v4_request(self, canonic_query: str, payload: str):
        access_key = self.access_key
        secret_key = self.secret_key
        
        t = datetime.datetime.utcnow()
        current_date = t.strftime('%Y%m%dT%H%M%SZ')
        datestamp = t.strftime('%Y%m%d')

        canonical_uri = '/cv/2022-08-31'
        signed_headers = 'content-type;host;x-amz-content-sha256;x-amz-date'
        algorithm = "AWS4-HMAC-SHA256"

        string_to_stringified = (algorithm +
                                "\n" +
                                current_date +
                                "\n" +
                                datestamp + "/" + REGION_NAME + "/service_name/" +
                                AWS_REQUEST_SUFIX_SIGNATURE +
                                calculate_canonical_request(canonical_uri,
                                                          canonic_query,
                                                          canonical_headers,
                                                          signed_headers))

        signing_key = aws_signing_algorithm(secret_key, 
                                             datestamp, 
                                             REGION_NAME, 
                                             SERVICE_NAME).sign(algorithm=algorithm)


        request_signature = hmac.new(signing_key.encode(),
                                  message=(string_with_sings).encode(),
                                  digestmod=hashlib.sha256)

        auth_header_value = '{Algorithm} Credential={AccessKey}/{Scope}, SignedHeaders={SignedHeaders}, Signature={Signature}'.format(
                                    Algorithm = algorithm.upper(), 
                                    AccessKey = access_key, 
                                    Scope = f"{datestamp}/{REGION_NAME}/service_name/{AWS_REQUEST_SUFIX_SIGNATURE}",
                                    SignedHeaders = signed_headers,
                                    Signature = encoded_signature.hex())

        headers = {"Authorization": auth_header_value}

        url = base_url + canonic_query

上述修改包括了一些关键点:规范化参数读取流程,减少代码冗余,并增加了对异常处理的支持。同时,为了提高可读性和安全性,在某些地方还做了进一步优化。当然根据具体的业务需求与技术栈,可能还需要做更多的调整。

@@ -377,6 +536,8 @@
ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding))
.append_model_info_list(rerank_list)
.append_model_info_list(image_model_info)
.append_model_info_list(tti_model_info)
.append_default_model_info(rerank_list[0])
.build())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段代码在结构上和功能上都有改进,但有一些优化点:

1. 简化导入语句

可以将所有模型相关文件夹下的 credentialmodel 包的依赖合并到一个顶级导入中。

from setting.models_provider.impl.xinference_model_provider.credential import \
    XinferenceLLMModelCredential, XinferenceSTTModelCredential, ...

并使用通配符导入:

from setting.models_provider.impl.xinference_model_provider.credential.* import *

这会使代码更加简洁明了。

2. 避免重复导入相同类型不同名称的模型类

例如:

  • 重复导出 XInferenceLLMModelCredential, XInferenceSTTModelCredential, ..., 可以直接在一行内完成导入。

例如:

xinference_model_credentials = {
    "LLM": XinferenceLLMModelCredential(),
    "STT": XinferenceSTTModelCredential,
}

这样可以减少冗余,并且更易于管理多个相同的凭证类实例。

3. 增加注释说明

为每一对模型提供详细的注释,以便其他开发者更容易理解这些信息。

4. 利用枚举类型

考虑到模型类型可能需要进一步扩展 (如 TEXT_TO_IMAGE),推荐使用枚举类型来替代固定字符串。

class ModelType(enum.Enum):
    LLM = auto()
    STT = auto()
    TTS = auto()
    RERANKER = auto()
    IMAGE = auto()  # 新添加的模型类别
    TTI = auto()   # 新添加的模型类别

# 更改构造方法参数:
ModelInfo(
    name="...",
    provider_type=provider_type,
    type=models.TypeEnum.LLM, if models else None,
    credential=credentials_map.get(type.value),
    class_=models.ModelBase.__subclass_with_name__(str(name), f"{type.name.capitalize()}Model") or BaseClass
)

上述修改使代码更为模块化、更具可维护性。

'err_message': self.err_message,
'image_list': self.context.get('image_list'),
'dialogue_type': self.context.get('dialogue_type')
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该代码存在以下不足和改进点:

  1. 编码格式:文件使用了UTF-8编码,但缺少BOM(Byte Order Mark)。可以在头部添加 # -*- coding: utf-8 -*-

  2. 函数命名:一些方法名如 generate_prompt_question, get_details 可以考虑是否与现有API或库中的函数冲突。

  3. 错误处理:没有对所有可能的异常进行适当的捕获和处理,可能会导致程序崩溃或数据丢失。

  4. 日志记录:在某些地方打印了调试信息,但在生产环境中应该移除这些打印语句或替换为更安全的日志机制。

  5. 依赖项管理:导入的部分模块(如 reduce)来自内置函数和标准库,无需额外安装。可以将这部分移除。

  6. 变量命名:部分变量取名为 detailscontext,可能导致混淆。推荐使用更具描述性的名称,例如 query_detailscurrent_context 等。

  7. 文档字符串:虽然有一些注释解释了每一步的作用,但整体上的文档编写还有待提高,特别是在类层次结构、参数说明等方面。

  8. 可读性:代码逻辑清晰,但对于大型项目来说,良好的分隔和缩进仍然很重要。

以下是修改后的一些示例:

修改后的版本

def generate_prompt(self, query):
    """根据查询生成提示问题"""
    return self.workflow_manage.generate_prompt(query)

@staticmethod
def process_generated_images(image_urls, app_id=None):
    """处理生成的图像并上传到服务器"""
    result = []
    
    for url in image_urls:
        response = requests.get(url)
        image_data = response.content
        
        file_name = f"generated_image_{len(result)}.png"
        uploaded_file = bytes_to_uploaded_file(image_data, file_name)
        
        metadata = {
            'debug': bool(app_id),
            'app_id': str(app_id) if app_id else None,
        }
        upload_response = FileSerializer(data={'file': uploaded_file, 'meta': metadata}).upload()
        result.append(upload_response.url)
    
    return result

def save_result(self, task_detail, run_time):
    """保存任务结果"""
    return NodeResult({
        'answer': task_detail.answer,
        'chat_model': task_detail.chat_model,
        'message_list': task_detail.message_list,
        'image': [{'file_id': path.rsplit('/', 1)[-1], 'file_url': path} for path in task_detail.image_urls],
        'history_message': task_detail.history_message,
        'question': task_detail.question
    }, '')

def handle_workflow_messages(task_detail):
    """处理流程消息"""
    for val in task_detail.details.values():
        if val['node_id'] == self.node.id and 'image_list' in val:
            if val['dialogue_type'] == 'WORKFLOW':
                return task_detail.ai_message
            elif not (val['dialogue_type'] == 'IMAGE_GENERATION'):
                return AIMessage(content=val['answer'].strip())

def build_history_messages(chat_records, count):
    """构建历史消息列表"""
    history_messages = []
    for record in chat_records[-count:]:
        history_messages.extend([
            HumanMessage(content=record.problem_text),
            self.handle_workflow_messages(record)
        ])
    return history_messages

class ImageGenerationWorkflow(BaseStepNode, iImageGenerateNode):
    """图像生成流程节点"""

    def execute(
            self,
            model_id,
            prompt,
            negative_prompt='',
            dialogue_number=1,
            dialect='DEFAULT',
            history=None,
            dialog_type='CHAT',
            flow_params_serializer=None,
            **kwargs
    ) -> TaskDetailResponseModel_v3_0:

        try:
            task_details = {}
            task_results = {}

            # 获取模型实例,并获取上下文。
            ... (原实现保持不变generated_images = task_details.get('generated_images')

            if generated_images:
                image_upload_urls = ImageGenerationWorkflow.process_generated_images(generated_images, app_id)

                node_result = self.save_result(
                    {'id': task_details.get('task_id', ''), 'question': task_details.get('prompt', ''),
                     'generated_answers': task_details.get('response', ''), 'image_list': image_upload_urls},
                    datetime.datetime.now()
                )

            return node_result

        except Exception as e:
            self.log_error(f"Exception during execution: {e}")
            raise

通过上述改进建议,可以使代码更加健壮、易于维护和扩展。

return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body)

def is_cache_model(self):
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该代码存在以下几个问题和建议:

  1. 缩进不一致:代码中的缩进需要保持一致性。

  2. 注释错误:注释部分拼写了一些额外的空格。

  3. 参数传递方式:在 generate_imagenew_instance 中,参数传递的方式比较混乱。使用 self.params 直接操作可能会导致意外的行为。应该尽量避免直接修改实例变量或类字段,而是通过方法处理。

  4. 没有测试覆盖率:代码缺少必要的单元测试来验证其功能是否正确。

  5. 字符串格式化:在某些情况下,字符串格式化(如 \u0026) 可能会导致输出乱码。可以考虑使用 json 库处理 JSON 数据以确保字节编码一致。

  6. 错误处理:虽然有一些基本的错误处理逻辑,但可以在请求过程中添加更多的异常捕获和日志记录来提高鲁棒性。

以下是一些具体的改进建议:

@@ -87,9 +87,10 @@ class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         optional_params = {'params': {}}
         for key, value in model_kwargs.items():
-            if key not in ['model_id', 'use_local', 'streaming']:
+            if key.lower() not in ['model_id', 'use_local', 'streaming']:  # 添加 .lower() 进行大小写转换
                 optional_params['params'][key] = value
-
+        return VolcanicEngineTextToImage(
+            access_key=kwargs.get('access_key'),                      # 移除重复的 self.access_key 和模型版本设置
+            secret_key=kwargs.get('secret_key'),
+            **optional_params
+    )

@@ -377,6 +536,8 @@
ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding))
.append_model_info_list(rerank_list)
.append_model_info_list(image_model_info)
.append_model_info_list(tti_model_info)
.append_default_model_info(rerank_list[0])
.build())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段代码有一些不规范的地方和优化建议:

  1. 重复导入:在多个地方都引入了相同的模块(如XInferenceLLMModelCredential, XInferenceSTTModelCredential, 等等),可以将其合并到一个列表中进行统一导入,然后逐一实例化。

    module_credentials = [
        XinferenceLLMModelCredential,
        XinferenceSTTModelCredential,
        # ... 其他credential类
    ]
    
    credentials_map = {cls.__name__: cls() for cls in module_credentials}
  2. 避免硬编码模型名称

    • 目前在image_model_infotti_model_info数组内直接使用模型名称作为字符串,应该考虑将它们存储在一个更安全的方式下,比如从配置文件或数据库读取。
  3. 冗余代码

    • 在构造多组model信息时,部分逻辑相似,可以在函数中封装起来以减少重复代码。
  4. 文档和注释

    • 对于每一个新增的ModuleInfosBuilder#append_model_info_list调用,添加一些解释性注释,说明每个参数的作用。
  5. 异常处理

    • 添加适当的错误处理机制,以防某些模块无法成功加载或实例化。

以下是改进后的版本示例:

from smartdoc.conf import PROJECT_DIR

module_credentials = [
    XinferenceLLMModelCredential,
    XinferenceSTTModelCredential,
    # other credential classes can be added here
]

credentials_map = {cls.__name__: cls for cls in module_credentials}

def build_model_infos(credentials):
    builder = ModuleInfosBuilder()

    def append_models(model_type_name, models_list):
        if not models_list:
            return
        model_group_name = f"{model_type_name.lower()}_chat"
        builder.append_default_model_info(models[model_group_name][0])  # using an assumption: all default models follow this pattern
        return builder.append_model_info_list(models[model_group_name])

    append_models("IMAGE", image_model_info)  # Assuming a structure where image related models are categorized under "IMAGE"_chat
    append_models("TTI", tti_model_info)

    return builder.build()


def main():
    xinference_embedding_model_credential = XinferenceEmbeddingModelCredential()
    
    global_module_info = GlobalModuleInfos(xinference_embedding_model_credential).append_default_model_info(None)
    sub_module_info = ModuleSubInfos(
        xinference_embedded_text_tokenizer="default",
        xinference_llm_chat=None,
        embeded_sentence_chunk_size=1,
        xinference_rerankers=[],  # assuming rerank strategies aren't loaded dynamically at this point
        xinference_stt=None,
        embedding_models=[],  # these will be populated after fetch_from_global_config is called
        llms=[],  # these will be populated after fetch_from_global_config is called
        tts=[]
    ).append_global(module_global_info)
    combined_info_builder = CombinedInfosBuilder().append_model_subinfos(sub_module_info)
    final_combined_info = combined_info_builder.build()

if __name__ == "__main__":
    main()

这个改进主要集中在代码组织方面,并尽量减少全局变量的使用,提高结构清晰度。同时增加了对模块动态扩展的支持,以及对于非默认模式的一些假设(如如何获取子模组的所有模型)。

'err_message': self.err_message,
'image_list': self.context.get('image_list'),
'dialogue_type': self.context.get('dialogue_type')
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段代码存在一些问题和可以优化的地方:

1. save_context 方法中的 self.answer_text 赋值

execute 方法中赋值了 self.answer_text = details.get('answer'),但在 save_context 中又重新赋值了一次,这可能造成不必要的重复。

self.save_context(details, workflow_manage)

可以在调用 save_context 后删除 self.answer_text 这一行:

self.save_context(details, workflow_manage)
del self.answer_text

2. execute 方法中的图片处理逻辑

图片下载和上传操作使用了一个循环来遍历图像列表,并对每个 URL 进行请求。这种方法可能会导致性能问题,尤其是在大规模生成或需要频繁访问多个服务时。

推荐使用多线程或多进程来并行处理图像下载任务,以提高效率。

此外,考虑将请求头设置为允许 CORS 或使用代理服务器来避免直接访问外部资源造成的限制。

3. generate_message_list 方法中的历史消息拼接

在生成 message 列表时,会将 history 消息与新生成的消息相加。这可能导致 message 数量超过预期。

示例:

self.context['message_list'] = [
    *self.context['history_message'], 
    question
]

如果历史消息数量过多,可能会影响界面展示效果或 API 请求的响应时间。

可以通过增加条件判断来防止 message 数量过大:

if len(self.context['message_list']) + 1 <= max_messages_per_request:
    self.context['message_list'].append(question)
else:
    # 处理超过最大数量的情况
    pass

4. reset_message_list 方法

此方法用于重置消息列表,但它返回一个包含旧消息内容的新列表,而不是原地修改。

return [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for message in message_list]
result.append({'role': 'ai', 'content': answer_text})
return result

更好的做法是返回一个新的 list 并替换原来的 list 而不是覆盖其引用:

self.context['message_list'] = [
    {'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
    for message in message_list[:-1] + [AIMessage(content=answer_text)]
]

这些改进建议可以帮助你更好地管理和优化代码的质量。

return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body)

def is_cache_model(self):
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

以下是针对你提供的代码的一些建议和优化提示:

  1. 空文件头
    • 文件开头部分包含了多个注释,但这些注释似乎并不是必要的。
    • 建议去掉不必要的空行或保留关键部分的注释。
# coding=utf-8
  1. 编码问题
    • 确保所有字符串都使用 UTF-8 编码,并且在处理时避免乱码情况。
method = 'POST'
host = 'visual.volcengineapi.com'
region = 'cn-north-1'
endpoint = 'https://visual.volcengineapi.com'
service = 'cv'

req_key_dict = {
    'general_v1.4': 'high_aes_general_v14',
    'general_v2.0': 'high_aes_general_v20',
    'general_v2.0_L': 'high_aes_general_v20_L',
    'anime_v1.3': 'high_aes',
    'anime_v1.3.1': 'high_aes',
}
  1. 日期格式化
    • current_date 变量可以直接使用当前时间戳来生成,而不是手动指定的。
    • 这样可以减少日期相关的错误和不确定性。
current_date = datetime.datetime.utcnow().strftime('%Y%m%dT%H%M%SZ')
  1. 签名函数中的键顺序
    • 在计算签名密钥时,确保 signKeyDict 中的键是按字典序排序的。虽然这里没有直接使用签名密钥列表,但在其他地方可能会这样做,因此仍然保持一致性比较好。
kDate = sign(key.encode('utf-8'), dateStamp)
kRegion = sign(kDate, regionName)
kService = sign(kRegion, serviceName)
kSigning = sign(kService, 'request')
return kSigning
  1. 参数拼接

    • 确保所有的查询参数都是先按字母顺序排序后连接起来的。
    • 避免出现遗漏任何参数的情况。
  2. 响应处理

    • 在处理请求失败时,应该捕获 HTTP 错误并提供更友好的用户反馈。
try:
    r = requests.post(request_url, headers=headers, data=req_body)
except Exception as err:
    print(f'error occurred: {err}')
    raise
else:
    ...
  1. 方法签名

    • generate_image 方法中,应明确其返回值类型(例如:-> str | list).
  2. 类属性初始化

    • __init__ 方法中,初始化字段时应该使用默认参数,并对缺失的关键字进行警告或默认设置。
def __init__(self, model_version=None, access_key=None, secret_key=None, params={}):
    super().__init__()
    self.access_key = access_key or None
    self.secret_key = secret_key or None
    self.model_version = model_version or ""
    self.params = params.copy()  # 复制以防止外部修改原输入
  1. 缓存支持

    • 如果需要实现缓存机制,可以在相应的方法中添加逻辑,并确保缓存策略符合业务需求。
  2. 异步改进

    • 引入 Python 的 asyncio 库可以提高程序并发性和响应速度,特别是在数据传输密集型任务中。
import asyncio

async def async_generate_image(self, prompt: str, negative_prompt: str = None):
    # 同同步版本的 generate_image 实现
    ...

async def run_async_methods():
    await asyncio.gather(
        self.async_generate_image("生成一张小猫图片"),
        self.async_check_auth()
    )

if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    loop.run_until_complete(run_async_methods())

以上是一些可能有助于改善代码质量和可维护性的建议。你可以根据具体的需求进一步调整和完善代码。

@@ -377,6 +536,8 @@
ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding))
.append_model_info_list(rerank_list)
.append_model_info_list(image_model_info)
.append_model_info_list(tti_model_info)
.append_default_model_info(rerank_list[0])
.build())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段代码主要增加了xinference模型和模型管理的类以及相关的配置信息,以支持新的图像处理和文本转图片(Text to Image)等功能。以下是一些可能存在的问题及建议:

  1. 冗余导入:在setting/models_provider/impl/xinference_model_provider/model/tts.py中引入了不必要的文件import xx_informer.py
    建议:删除不必要的导入。

  2. 重复定义model_info_list, xinference_llm_model_credential, xinference_stt_model_credential, xinference_tts_model_credential, 和 xinqineering_image_model_credential
    可能已经存在相同的定义。建议根据实际情况重新命名或合并这些变量。

  3. 逻辑错误:确保xinference_embedding_model_credential被正确地创建后添加到embedding_model_info列表,以便后续使用该凭据进行嵌入操作。

  4. 拼写检查:一些注释中的拼写的错误,如“xinference_embedding”,应该修正为正确的名称。

  5. 参数缺失:在ModelInfo对象初始化时需要确保所有必填字段都已经提供,否则可能会产生异常。

  6. 性能考虑:在构建model_info_mgr过程中添加大量数据项会导致性能问题,因此可以尝试分段加载或缓存部分数据。

综上所述,以上问题是由于新增加的功能导致的代码不够简洁且可能存在逻辑上的问题和可读性不强的问题。

'err_message': self.err_message,
'image_list': self.context.get('image_list'),
'dialogue_type': self.context.get('dialogue_type')
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有明显的不规范、潜在问题或优化建议。

@liuruibin liuruibin merged commit 7bd791f into main Dec 13, 2024
3 of 4 checks passed
@liuruibin liuruibin deleted the feat_tti_model branch December 13, 2024 06:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants