feat: Support image generate model #1812
Conversation
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
[APPROVALNOTIFIER] This PR is NOT APPROVED
This pull-request has been approved by:
The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing `/approve` in a comment.
        return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body)

    def is_cache_model(self):
        return False
This code has some potential issues and room for improvement:

1. Missing error handling: exceptions raised by requests.post() are neither caught nor logged.

2. Escaping issue: the request string substitutes the character \u0026 for the original &, which is generally not recommended. Consider applying proper encoding before building the URL.

3. Unclear parameter passing: the parameter-passing logic is unnecessarily complex and repetitive and should be simplified.

4. Environment dependencies: several external libraries are imported (such as websockets, langchain_openai, etc.), but their specific purpose is not explained.

5. Missing docstrings: some functions lack comments on their signatures and parameter lists, which makes them hard to understand.
Suggested improvements:

1. Make sure exceptions are caught:

try:
    r = requests.post(request_url, headers=headers, data=req_body)
except Exception as err:
    print(f'Error occurred: {err}')
    raise

2. Avoid escaping characters by hand. If you really need to handle special characters, use a safer encoding helper such as urllib.parse.quote_plus() when building URIs (a short urlencode sketch appears at the end of this comment).

3. Simplify parameter passing: consolidate the parameter handling in one place and make sure all model implementation classes follow the same parameter conventions:
import datetime
import hashlib
import hmac
import json
from typing import Dict
from urllib.parse import urlencode

import requests

# MaxKBBaseModel and BaseTextToImage come from the project's model-provider base modules.
# The constants below mirror the values used in this file; verify them against the
# provider's signing specification.
HOST = 'visual.volcengineapi.com'
ENDPOINT = 'https://visual.volcengineapi.com'
REGION_NAME = 'cn-north-1'
SERVICE_NAME = 'cv'
CANONICAL_URI = '/cv/2022-08-31'


class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
    access_key: str
    secret_key: str
    model_version: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.access_key = kwargs.get('access_key')
        self.secret_key = kwargs.get('secret_key')
        self.model_version = kwargs.get('model_version', 'general_v1.4')
        # Copy so the caller's dictionary is not modified.
        self.params = dict(kwargs.get('params') or {})

    @classmethod
    def create_from_config(cls, config_data: Dict[str, object]):
        defaults = {
            'access_key': 'your_access_key_here',
            'secret_key': 'your_secret_key_here',
            'model_version': 'general_v1.4',  # default version to use
        }
        return cls(**(defaults | config_data))

    def check_auth(self):
        # A simple smoke test: generate one image and let any auth error surface.
        res = self.generate_image('生成一张小猫图片')
        print(res)

    def generate_image(self, prompt: str, negative_prompt: str = None):
        if not prompt:
            raise ValueError("Prompt must be provided.")
        # The query/body split below is illustrative; follow the provider's API
        # reference for the real parameter names.
        query_params = {'Action': 'TextToImage', 'Version': '2022-08-31'}
        body_params = {
            'req_key': self.model_version,
            'prompt': prompt,
            **self.params,
        }
        if negative_prompt:
            body_params['negative_prompt'] = negative_prompt
        formatted_query = self.format_query(query_params)
        formatted_body = self.build_payload(body_params)
        return self.sign_v4_request(formatted_query, formatted_body)

    @staticmethod
    def format_query(params: dict) -> str:
        # Sort parameters for a stable ordering and let urlencode handle escaping.
        return urlencode(sorted(params.items()))

    @staticmethod
    def build_payload(params: Dict[str, object]) -> str:
        return json.dumps(params)

    def sign_v4_request(self, canonical_query: str, payload: str):
        t = datetime.datetime.utcnow()
        current_date = t.strftime('%Y%m%dT%H%M%SZ')
        datestamp = t.strftime('%Y%m%d')

        payload_hash = hashlib.sha256(payload.encode('utf-8')).hexdigest()
        canonical_headers = (f'content-type:application/json\n'
                             f'host:{HOST}\n'
                             f'x-amz-content-sha256:{payload_hash}\n'
                             f'x-amz-date:{current_date}\n')
        signed_headers = 'content-type;host;x-amz-content-sha256;x-amz-date'
        canonical_request = '\n'.join([
            'POST', CANONICAL_URI, canonical_query,
            canonical_headers, signed_headers, payload_hash,
        ])

        algorithm = 'HMAC-SHA256'
        credential_scope = f'{datestamp}/{REGION_NAME}/{SERVICE_NAME}/request'
        string_to_sign = '\n'.join([
            algorithm, current_date, credential_scope,
            hashlib.sha256(canonical_request.encode('utf-8')).hexdigest(),
        ])

        def hmac_sha256(key: bytes, msg: str) -> bytes:
            return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()

        # Derive the signing key: date -> region -> service -> 'request'.
        k_date = hmac_sha256(self.secret_key.encode('utf-8'), datestamp)
        k_region = hmac_sha256(k_date, REGION_NAME)
        k_service = hmac_sha256(k_region, SERVICE_NAME)
        k_signing = hmac_sha256(k_service, 'request')
        signature = hmac.new(k_signing, string_to_sign.encode('utf-8'),
                             hashlib.sha256).hexdigest()

        authorization = (f'{algorithm} Credential={self.access_key}/{credential_scope}, '
                         f'SignedHeaders={signed_headers}, Signature={signature}')
        headers = {
            'Content-Type': 'application/json',
            'Host': HOST,
            'X-Amz-Content-Sha256': payload_hash,
            'X-Amz-Date': current_date,
            'Authorization': authorization,
        }
        request_url = f'{ENDPOINT}{CANONICAL_URI}?{canonical_query}'
        try:
            response = requests.post(request_url, headers=headers, data=payload)
            response.raise_for_status()
        except Exception as err:
            print(f'Error occurred: {err}')
            raise
        return response.json()
The changes above cover a few key points: they normalize how parameters are read, reduce duplicated code, and add exception handling, with some further adjustments for readability and safety. Depending on the concrete business requirements and technology stack, more tuning may still be needed.
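For the escaping point, here is a minimal sketch of building the query string with the standard library instead of joining with '&' or substituting '\u0026' by hand; the parameter names are purely illustrative, not the provider's real API fields:

from urllib.parse import quote, urlencode

params = {'Version': '2022-08-31', 'prompt': 'a picture of a kitten'}
# Sort for a stable, signature-friendly ordering and let urlencode do the escaping.
query_string = urlencode(sorted(params.items()), quote_via=quote)
request_url = 'https://visual.volcengineapi.com/?' + query_string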
@@ -377,6 +536,8 @@
                      ModelTypeConst.EMBEDDING,
                      xinference_embedding_model_credential, XinferenceEmbedding))
         .append_model_info_list(rerank_list)
+        .append_model_info_list(image_model_info)
+        .append_model_info_list(tti_model_info)
         .append_default_model_info(rerank_list[0])
         .build())
This code improves both structure and functionality, but a few things could still be optimized:

1. Simplify the import statements

The dependencies on the credential and model packages under each model-provider folder can be consolidated into a single top-level import:

from setting.models_provider.impl.xinference_model_provider.credential import \
    XinferenceLLMModelCredential, XinferenceSTTModelCredential, ...

Alternatively, a star import from the package works if the package's __init__.py re-exports the classes (a sketch of such an __init__.py follows):

from setting.models_provider.impl.xinference_model_provider.credential import *

This makes the code more concise.
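A minimal sketch of that re-exporting __init__.py; the submodule file names (llm.py, stt.py, tts.py) are assumptions about the layout, not taken from this PR:

# setting/models_provider/impl/xinference_model_provider/credential/__init__.py
from .llm import XinferenceLLMModelCredential   # assumed file names
from .stt import XinferenceSTTModelCredential
from .tts import XinferenceTTSModelCredential

__all__ = [
    'XinferenceLLMModelCredential',
    'XinferenceSTTModelCredential',
    'XinferenceTTSModelCredential',
]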
2. Avoid importing the same kind of model class repeatedly under different names

For example, instead of re-exporting XinferenceLLMModelCredential, XinferenceSTTModelCredential, ... one by one, they can be collected in a single mapping:

xinference_model_credentials = {
    'LLM': XinferenceLLMModelCredential(),
    'STT': XinferenceSTTModelCredential(),
}

This reduces redundancy and makes it easier to manage several instances of the same credential classes.
3. Add explanatory comments

Give each model/credential pair a short comment so other developers can understand this information more easily.
4. Use an enum type

Since the model types may need to be extended further (e.g. TEXT_TO_IMAGE), an enum is preferable to fixed strings:

import enum
from enum import auto

class ModelType(enum.Enum):
    LLM = auto()
    STT = auto()
    TTS = auto()
    RERANKER = auto()
    IMAGE = auto()  # newly added model category
    TTI = auto()    # newly added model category

The ModelInfo construction could then key off the enum, for example:

ModelInfo(
    name='...',
    provider_type=provider_type,
    type=ModelType.LLM,
    credential=xinference_model_credentials.get('LLM'),
    class_=model_class,  # the implementation class for the given model type
)

These changes make the code more modular and easier to maintain.
            'err_message': self.err_message,
            'image_list': self.context.get('image_list'),
            'dialogue_type': self.context.get('dialogue_type')
        }
This code has the following shortcomings and possible improvements:

1. Encoding declaration: the file is UTF-8 encoded but carries no explicit encoding declaration. A # -*- coding: utf-8 -*- header could be added (optional in Python 3, where UTF-8 is already the default).

2. Function naming: check whether method names such as generate_prompt_question and get_details clash with functions in existing APIs or libraries.

3. Error handling: not every possible exception is caught and handled appropriately, which could crash the program or lose data.

4. Logging: debug information is printed in several places; in production these print statements should be removed or replaced with a safer logging mechanism.

5. Dependency management: some imported modules (such as reduce) come from built-ins and the standard library and need no extra installation; if unused, these imports can be removed.

6. Variable naming: names such as details and context can be confusing. More descriptive names such as query_details or current_context are recommended.

7. Docstrings: although some comments explain individual steps, the overall documentation still needs work, especially around the class hierarchy and parameter descriptions.

8. Readability: the logic is clear, but for a large project good separation and indentation still matter.

Here are some examples after the suggested changes:
Revised version
import datetime

import requests
from langchain_core.messages import AIMessage, HumanMessage

# BaseStepNode, iImageGenerateNode, NodeResult, FileSerializer and
# bytes_to_uploaded_file come from the project's existing modules.


class ImageGenerationWorkflow(BaseStepNode, iImageGenerateNode):
    """Image generation workflow node."""

    def generate_prompt(self, query):
        """Generate the prompt question from the query."""
        return self.workflow_manage.generate_prompt(query)

    @staticmethod
    def process_generated_images(image_urls, app_id=None):
        """Download the generated images and upload them to the server."""
        result = []
        for url in image_urls:
            response = requests.get(url)
            image_data = response.content
            file_name = f'generated_image_{len(result)}.png'
            uploaded_file = bytes_to_uploaded_file(image_data, file_name)
            metadata = {
                'debug': bool(app_id),
                'app_id': str(app_id) if app_id else None,
            }
            upload_response = FileSerializer(data={'file': uploaded_file, 'meta': metadata}).upload()
            result.append(upload_response.url)
        return result

    def save_result(self, task_detail: dict, run_time):
        """Save the task result."""
        return NodeResult({
            'answer': task_detail.get('answer'),
            'chat_model': task_detail.get('chat_model'),
            'message_list': task_detail.get('message_list'),
            'image': [{'file_id': path.rsplit('/', 1)[-1], 'file_url': path}
                      for path in task_detail.get('image_list', [])],
            'history_message': task_detail.get('history_message'),
            'question': task_detail.get('question')
        }, '')

    def handle_workflow_messages(self, task_detail):
        """Handle workflow messages."""
        for val in task_detail.details.values():
            if val['node_id'] == self.node.id and 'image_list' in val:
                if val['dialogue_type'] == 'WORKFLOW':
                    return task_detail.ai_message
                if val['dialogue_type'] != 'IMAGE_GENERATION':
                    return AIMessage(content=val['answer'].strip())

    def build_history_messages(self, chat_records, count):
        """Build the history message list."""
        history_messages = []
        for record in chat_records[-count:]:
            history_messages.extend([
                HumanMessage(content=record.problem_text),
                self.handle_workflow_messages(record)
            ])
        return history_messages

    def execute(self, model_id, prompt, negative_prompt='', dialogue_number=1,
                dialect='DEFAULT', history=None, dialog_type='CHAT',
                flow_params_serializer=None, **kwargs) -> NodeResult:
        try:
            task_details = {}
            # Fetch the model instance and gather the context.
            # ... (original implementation unchanged)
            app_id = kwargs.get('app_id')  # wherever the caller provides it
            generated_images = task_details.get('generated_images')
            if generated_images:
                image_upload_urls = self.process_generated_images(generated_images, app_id)
                return self.save_result(
                    {
                        'answer': task_details.get('response', ''),
                        'question': task_details.get('prompt', ''),
                        'image_list': image_upload_urls,
                    },
                    datetime.datetime.now()
                )
        except Exception as e:
            self.log_error(f"Exception during execution: {e}")
            raise
With the improvement suggestions above, the code becomes more robust and easier to maintain and extend.
        return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body)

    def is_cache_model(self):
        return False
This code has the following problems and suggestions:

1. Inconsistent indentation: indentation needs to be kept consistent throughout the code.

2. Comment errors: some comments contain stray extra spaces.

3. Parameter passing: in generate_image and new_instance, the way parameters are passed around is confusing. Operating on self.params directly may lead to unexpected behaviour; avoid mutating instance variables or class fields directly and go through methods instead.

4. No test coverage: there are no unit tests to verify that the code works correctly (a small test sketch follows the diff below).

5. String formatting: manual formatting (such as \u0026) can produce garbled output in some cases. Consider using the json library to keep the byte encoding consistent for JSON data.

6. Error handling: there is some basic error handling, but adding more exception catching and logging around the request would improve robustness.

Here are some concrete suggestions:
@@ -87,9 +87,10 @@ class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         optional_params = {'params': {}}
         for key, value in model_kwargs.items():
-            if key not in ['model_id', 'use_local', 'streaming']:
+            if key.lower() not in ['model_id', 'use_local', 'streaming']:  # normalise case before filtering
                 optional_params['params'][key] = value
-
+        return VolcanicEngineTextToImage(
+            access_key=model_credential.get('access_key'),  # read credentials from model_credential
+            secret_key=model_credential.get('secret_key'),
+            **optional_params
+        )
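As a hedged sketch of the missing test coverage mentioned above (the credential keys and extra kwargs are made-up fixtures, not the project's real test data, and the assertion assumes the constructor stores the collected params on the instance):

def test_new_instance_collects_optional_params():
    credential = {'access_key': 'ak', 'secret_key': 'sk'}
    model = VolcanicEngineTextToImage.new_instance(
        'TTI', 'general_v1.4', credential,
        streaming=True,           # should be filtered out
        width=512, height=512,    # should end up in params
    )
    assert 'streaming' not in model.params
    assert model.params == {'width': 512, 'height': 512}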
@@ -377,6 +536,8 @@
                      ModelTypeConst.EMBEDDING,
                      xinference_embedding_model_credential, XinferenceEmbedding))
         .append_model_info_list(rerank_list)
+        .append_model_info_list(image_model_info)
+        .append_model_info_list(tti_model_info)
         .append_default_model_info(rerank_list[0])
         .build())
There are a few non-standard spots and optimization suggestions in this code:

1. Repeated imports: the same modules (XinferenceLLMModelCredential, XinferenceSTTModelCredential, and so on) are imported in several places. They could be collected in one list and instantiated from it:

module_credentials = [
    XinferenceLLMModelCredential,
    XinferenceSTTModelCredential,
    # ... other credential classes
]
credentials_map = {cls.__name__: cls() for cls in module_credentials}

2. Avoid hard-coding model names: image_model_info and tti_model_info currently embed the model names as plain strings. Consider storing them somewhere safer, for example loading them from a configuration file or a database.

3. Redundant code: when constructing the several groups of model info, parts of the logic are similar and can be wrapped in a function to reduce duplication.

4. Documentation and comments: for each new append_model_info_list call, add a short comment explaining what the parameters are.

5. Exception handling: add appropriate error handling in case some module fails to load or instantiate.

Here is an improved version as an example:
module_credentials = [
    XinferenceLLMModelCredential,
    XinferenceSTTModelCredential,
    # other credential classes can be added here
]
credentials_map = {cls.__name__: cls() for cls in module_credentials}


def append_model_infos(builder, model_info_list, set_default=False):
    """Append one group of model infos, optionally registering its first entry as the default."""
    if not model_info_list:
        return builder
    builder.append_model_info_list(model_info_list)
    if set_default:
        builder.append_default_model_info(model_info_list[0])
    return builder


def build_model_info_manage(builder):
    # The same helper covers the image and text-to-image groups added in this PR.
    append_model_infos(builder, image_model_info)
    append_model_infos(builder, tti_model_info)
    append_model_infos(builder, rerank_list, set_default=True)
    return builder.build()
This revision focuses on code organization: it reduces duplication and global state, makes the structure clearer, and leaves room for extending the model groups dynamically.
            'err_message': self.err_message,
            'image_list': self.context.get('image_list'),
            'dialogue_type': self.context.get('dialogue_type')
        }
There are a few problems and possible optimizations in this code:

1. The self.answer_text assignment in save_context

execute assigns self.answer_text = details.get('answer'), and save_context assigns it again, which is unnecessary duplication.

self.save_context(details, workflow_manage)

After calling save_context, the duplicated assignment line can simply be removed so that answer_text is only set in one place.
2. Image handling in execute

Downloading and uploading the images uses a loop that issues one request per URL. This can become a performance problem, especially with large batches or frequent calls to several services.

Consider using multiple threads or processes to download the images in parallel (a small thread-pool sketch follows this section).

Also consider setting request headers that allow CORS, or going through a proxy server, to avoid restrictions when accessing external resources directly.
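A minimal sketch of the thread-pool idea; download_and_upload is a hypothetical helper that stands in for the per-URL body of process_generated_images, not an existing function in this PR:

from concurrent.futures import ThreadPoolExecutor

import requests


def download_and_upload(url: str) -> bytes:
    # Fetch one generated image; the bytes would then go through the existing upload logic.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return response.content


def fetch_images_concurrently(image_urls, max_workers=4):
    # map() keeps the results in the same order as image_urls.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(download_and_upload, image_urls))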
3. History concatenation in generate_message_list

When building the message list, the history messages are concatenated with the newly generated ones, which can push the message count past what is expected. For example:
self.context['message_list'] = [
    *self.context['history_message'],
    question
]
If there are too many history messages, this can hurt the UI rendering or the API response time. A guard on the message count helps:
if len(self.context['message_list']) + 1 <= max_messages_per_request:
    self.context['message_list'].append(question)
else:
    # Handle the case where the maximum number of messages is exceeded
    pass
4. reset_message_list

This method is meant to reset the message list, but it builds a new list from the old message contents instead of modifying it in place:

result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
          for message in message_list]
result.append({'role': 'ai', 'content': answer_text})
return result

A better approach is to build the new list in one expression and assign it back, rather than appending to and overwriting the old reference:
self.context['message_list'] = [
    {'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
    for message in message_list[:-1] + [AIMessage(content=answer_text)]
]
These suggestions should help you manage and improve the quality of the code.
        return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body)

    def is_cache_model(self):
        return False
Here are some suggestions and optimization hints for the code you provided:

1. Empty file header: the top of the file contains several comments that do not seem necessary. Remove the superfluous blank lines or keep only the essential part, e.g.:

# coding=utf-8

2. Encoding: make sure all strings are UTF-8 encoded and avoid garbled characters when processing them.
method = 'POST'
host = 'visual.volcengineapi.com'
region = 'cn-north-1'
endpoint = 'https://visual.volcengineapi.com'
service = 'cv'
req_key_dict = {
    'general_v1.4': 'high_aes_general_v14',
    'general_v2.0': 'high_aes_general_v20',
    'general_v2.0_L': 'high_aes_general_v20_L',
    'anime_v1.3': 'high_aes',
    'anime_v1.3.1': 'high_aes',
}
3. Date formatting: current_date can be generated directly from the current timestamp rather than being specified by hand, which removes a source of date-related errors and uncertainty:
current_date = datetime.datetime.utcnow().strftime('%Y%m%dT%H%M%SZ')
4. Key order in the signing function: when computing the signing key, make sure the keys in signKeyDict are sorted lexicographically. The signing-key list is not used directly here, but other places may rely on it, so staying consistent is still worthwhile.
kDate = sign(key.encode('utf-8'), dateStamp)
kRegion = sign(kDate, regionName)
kService = sign(kRegion, serviceName)
kSigning = sign(kService, 'request')
return kSigning
5. Parameter concatenation: make sure all query parameters are sorted alphabetically before they are joined, and that no parameter is accidentally left out.

6. Response handling: when the request fails, catch the HTTP error and give the user friendlier feedback:
try:
    r = requests.post(request_url, headers=headers, data=req_body)
except Exception as err:
    print(f'error occurred: {err}')
    raise
else:
    ...
7. Method signature: generate_image should declare its return type explicitly (for example -> str | list).

8. Class attribute initialization: in __init__, initialize the fields with default parameters and warn about, or default, any missing keys:
def __init__(self, model_version=None, access_key=None, secret_key=None, params=None):
    super().__init__()
    self.access_key = access_key
    self.secret_key = secret_key
    self.model_version = model_version or ""
    self.params = dict(params or {})  # copy so the caller's dict is not modified
9. Cache support: if a caching mechanism is needed, add the logic in the relevant methods and make sure the caching policy matches the business requirements.

10. Async improvements: Python's asyncio library can improve concurrency and responsiveness, especially for data-transfer-heavy work:
import asyncio

# Inside VolcanicEngineTextToImage:

async def async_generate_image(self, prompt: str, negative_prompt: str = None):
    # Same logic as the synchronous generate_image
    ...

async def async_check_auth(self):
    ...


async def run_async_methods(model):
    await asyncio.gather(
        model.async_generate_image('生成一张小猫图片'),
        model.async_check_auth(),
    )


if __name__ == '__main__':
    model = VolcanicEngineTextToImage(...)  # construct with real credentials
    asyncio.run(run_async_methods(model))
These are some suggestions that may help improve the quality and maintainability of the code; you can adjust and refine them further according to your concrete requirements.
@@ -377,6 +536,8 @@
                      ModelTypeConst.EMBEDDING,
                      xinference_embedding_model_credential, XinferenceEmbedding))
         .append_model_info_list(rerank_list)
+        .append_model_info_list(image_model_info)
+        .append_model_info_list(tti_model_info)
         .append_default_model_info(rerank_list[0])
         .build())
This change mainly adds the xinference model and model-management classes plus the related configuration, to support the new image processing and text-to-image (TTI) features. Some possible issues and suggestions:

1. Redundant import: setting/models_provider/impl/xinference_model_provider/model/tts.py pulls in an unnecessary file (the xx_informer import). Suggestion: remove the unneeded import.

2. Duplicate definitions: model_info_list, xinference_llm_model_credential, xinference_stt_model_credential, xinference_tts_model_credential and xinference_image_model_credential may already have identical definitions elsewhere. Rename or merge these variables as appropriate.

3. Logic error: make sure xinference_embedding_model_credential is created correctly before it is added to the embedding_model_info list, so that the credential can be used for the embedding operations later.

4. Spelling: some comments contain misspellings, such as "xinference_embedding", which should be corrected to the proper name.

5. Missing parameters: when initializing ModelInfo objects, make sure every required field is supplied, otherwise exceptions may be raised.

6. Performance: appending a large number of entries while building model_info_mgr can become a performance issue; consider loading the data in segments or caching part of it (a small caching sketch follows).

In short, these issues stem from the newly added functionality making the code less concise, with some potential logic problems and reduced readability.
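A minimal sketch of the caching idea from point 6; the builder call below is only a placeholder for the existing builder chain in this file, not a new API:

from functools import lru_cache


def _build_model_info_manage():
    # Placeholder for the existing builder chain
    # (.append_model_info_list(...) ... .append_default_model_info(...) .build()).
    return object()


@lru_cache(maxsize=1)
def get_model_info_manage():
    # Build the model-info structure once, on first access, and reuse it afterwards.
    return _build_model_info_manage()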
            'err_message': self.err_message,
            'image_list': self.context.get('image_list'),
            'dialogue_type': self.context.get('dialogue_type')
        }
No obvious style violations, potential problems, or optimization suggestions.
What this PR does / why we need it?
Summary of your change
Please indicate you've done the following: