-
Notifications
You must be signed in to change notification settings - Fork 2.1k
feat: BaiLian Image Model #1844
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: MaxKB | ||
@Author:虎 | ||
@file: llm.py | ||
@date:2024/7/11 18:41 | ||
@desc: | ||
""" | ||
import base64 | ||
import os | ||
from typing import Dict | ||
|
||
from langchain_core.messages import HumanMessage | ||
|
||
from common import forms | ||
from common.exception.app_exception import AppApiException | ||
from common.forms import BaseForm, TooltipLabel | ||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
||
|
||
class QwenModelParams(BaseForm): | ||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), | ||
required=True, default_value=1.0, | ||
_min=0.1, | ||
_max=1.9, | ||
_step=0.01, | ||
precision=2) | ||
|
||
max_tokens = forms.SliderField( | ||
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), | ||
required=True, default_value=800, | ||
_min=1, | ||
_max=100000, | ||
_step=1, | ||
precision=0) | ||
|
||
|
||
class QwenVLModelCredential(BaseForm, BaseModelCredential): | ||
|
||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, | ||
raise_exception=False): | ||
model_type_list = provider.get_model_type_list() | ||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): | ||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') | ||
for key in ['api_key']: | ||
if key not in model_credential: | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') | ||
else: | ||
return False | ||
try: | ||
model = provider.get_model(model_type, model_name, model_credential) | ||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) | ||
for chunk in res: | ||
print(chunk) | ||
except Exception as e: | ||
if isinstance(e, AppApiException): | ||
raise e | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') | ||
else: | ||
return False | ||
return True | ||
|
||
def encryption_dict(self, model: Dict[str, object]): | ||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))} | ||
|
||
api_key = forms.PasswordInputField('API Key', required=True) | ||
|
||
def get_model_params_setting_form(self, model_name): | ||
return QwenModelParams() | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: MaxKB | ||
@Author:虎 | ||
@file: llm.py | ||
@date:2024/7/11 18:41 | ||
@desc: | ||
""" | ||
import base64 | ||
import os | ||
from typing import Dict | ||
|
||
from langchain_core.messages import HumanMessage | ||
|
||
from common import forms | ||
from common.exception.app_exception import AppApiException | ||
from common.forms import BaseForm, TooltipLabel | ||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
||
|
||
class QwenModelParams(BaseForm): | ||
size = forms.SingleSelect( | ||
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'), | ||
required=True, | ||
default_value='1024*1024', | ||
option_list=[ | ||
{'value': '1024*1024', 'label': '1024*1024'}, | ||
{'value': '720*1280', 'label': '720*1280'}, | ||
{'value': '768*1152', 'label': '768*1152'}, | ||
{'value': '1280*720', 'label': '1280*720'}, | ||
], | ||
text_field='label', | ||
value_field='value') | ||
n = forms.SliderField( | ||
TooltipLabel('图片数量', '指定生成图片的数量'), | ||
required=True, default_value=1, | ||
_min=1, | ||
_max=4, | ||
_step=1, | ||
precision=0) | ||
style = forms.SingleSelect( | ||
TooltipLabel('风格', '指定生成图片的风格'), | ||
required=True, | ||
default_value='<auto>', | ||
option_list=[ | ||
{'value': '<auto>', 'label': '默认值,由模型随机输出图像风格'}, | ||
{'value': '<photography>', 'label': '摄影'}, | ||
{'value': '<portrait>', 'label': '人像写真'}, | ||
{'value': '<3d cartoon>', 'label': '3D卡通'}, | ||
{'value': '<anime>', 'label': '动画'}, | ||
{'value': '<oil painting>', 'label': '油画'}, | ||
{'value': '<watercolor>', 'label': '水彩'}, | ||
{'value': '<sketch>', 'label': '素描'}, | ||
{'value': '<chinese painting>', 'label': '中国画'}, | ||
{'value': '<flat illustration>', 'label': '扁平插画'}, | ||
], | ||
text_field='label', | ||
value_field='value' | ||
) | ||
|
||
|
||
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): | ||
|
||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, | ||
raise_exception=False): | ||
model_type_list = provider.get_model_type_list() | ||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): | ||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') | ||
for key in ['api_key']: | ||
if key not in model_credential: | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') | ||
else: | ||
return False | ||
try: | ||
model = provider.get_model(model_type, model_name, model_credential) | ||
res = model.check_auth() | ||
print(res) | ||
except Exception as e: | ||
if isinstance(e, AppApiException): | ||
raise e | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') | ||
else: | ||
return False | ||
return True | ||
|
||
def encryption_dict(self, model: Dict[str, object]): | ||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))} | ||
|
||
api_key = forms.PasswordInputField('API Key', required=True) | ||
|
||
def get_model_params_setting_form(self, model_name): | ||
return QwenModelParams() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 代码中存在以下不规范和潜在问题:
优化建议:
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# coding=utf-8 | ||
|
||
from typing import Dict | ||
|
||
from langchain_community.chat_models import ChatOpenAI | ||
|
||
from setting.models_provider.base_model_provider import MaxKBBaseModel | ||
|
||
|
||
class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI): | ||
|
||
@staticmethod | ||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): | ||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) | ||
chat_tong_yi = QwenVLChatModel( | ||
model_name=model_name, | ||
openai_api_key=model_credential.get('api_key'), | ||
openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', | ||
# stream_options={"include_usage": True}, | ||
streaming=True, | ||
model_kwargs=optional_params, | ||
) | ||
return chat_tong_yi |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# coding=utf-8 | ||
from http import HTTPStatus | ||
from typing import Dict | ||
|
||
from dashscope import ImageSynthesis | ||
from langchain_community.chat_models import ChatTongyi | ||
from langchain_core.messages import HumanMessage | ||
|
||
from setting.models_provider.base_model_provider import MaxKBBaseModel | ||
from setting.models_provider.impl.base_tti import BaseTextToImage | ||
|
||
|
||
class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage): | ||
api_key: str | ||
model_name: str | ||
params: dict | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self.api_key = kwargs.get('api_key') | ||
self.model_name = kwargs.get('model_name') | ||
self.params = kwargs.get('params') | ||
|
||
@staticmethod | ||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): | ||
optional_params = {'params': {'size': '1024*1024', 'style': '<auto>', 'n': 1}} | ||
for key, value in model_kwargs.items(): | ||
if key not in ['model_id', 'use_local', 'streaming']: | ||
optional_params['params'][key] = value | ||
chat_tong_yi = QwenTextToImageModel( | ||
model_name=model_name, | ||
api_key=model_credential.get('api_key'), | ||
**optional_params, | ||
) | ||
return chat_tong_yi | ||
|
||
def is_cache_model(self): | ||
return False | ||
|
||
def check_auth(self): | ||
chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max') | ||
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])]) | ||
|
||
def generate_image(self, prompt: str, negative_prompt: str = None): | ||
# api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', | ||
rsp = ImageSynthesis.call(api_key=self.api_key, | ||
model=self.model_name, | ||
prompt=prompt, | ||
negative_prompt=negative_prompt, | ||
**self.params) | ||
file_urls = [] | ||
if rsp.status_code == HTTPStatus.OK: | ||
for result in rsp.output.results: | ||
file_urls.append(result.url) | ||
else: | ||
print('sync_call Failed, status_code: %s, code: %s, message: %s' % | ||
(rsp.status_code, rsp.code, rsp.message)) | ||
return file_urls | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看起来这段代码是一个从阿里云迁移到DashScope的文本到图像模型实现。以下是存在的一些问题和改进建议:
@@ -0,0 +1,58 @@
+# coding=utf-8
+from http import HTTPStatus
+from typing import Dict
+
+from dashscope.dashscope_client import DashScopeClient # 修正缺失导入
+from dashscope import ImageSynthesis
+from langchain_community.chat_models import ChatTongyi
+from langchain.core.messages import HumanMessage
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_tti import BaseTextToImage
+
+
+class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
+ api_key: str
+ model_name: str
+ params: dict
以上是主要的问题和改进建议,可以根据实际情况进一步调整其他部分。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
代码整体结构良好,但存在一些需要关注的问题和优化点:
导入顺序问题:
应该先导入
HumanMessage
类。API 错误处理:
在捕获到其他异常时,建议只打印日志或记录错误信息,而不是重新抛出异常,除非有明确的需求。
加密逻辑简化:
如果
BaseModelCredential
类已经实现了一个安全的加密方法,并且这个加密逻辑适用于QwenVLModelCredential
,可以直接调用而不需要重复编写加密。模型测试功能:
需要添加调试信息以确保模型流获取成功,或者减少打印过多的日志以便于后续维护。
类成员顺序:
推荐遵循驼峰命名法(PascalCase)来定义类成员名称,例如使用
apiKey
而不是api_key
。注释文档:
通过上述改进,可以提高代码质量,避免潜在的安全风险和性能问题。