-
Notifications
You must be signed in to change notification settings - Fork 2.2k
feat: OpenAI图片模型 #1545
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: OpenAI图片模型 #1545
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# coding=utf-8 | ||
from abc import abstractmethod | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class BaseImage(BaseModel): | ||
@abstractmethod | ||
def check_auth(self): | ||
pass | ||
|
||
@abstractmethod | ||
def image_understand(self, image_file, text): | ||
pass | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这段Python代码似乎需要一些结构性和功能性改进,以便能够清晰和可读。主要问题表现在很多方面:
from abc import ABC, abstractmethod
# 定义自定义基类,并添加抽象方法
class Image(BaseModel, ABC):
@classmethod
@abstractmethod
def check_auth(cls):
...
# 添加实际行为
@classmethod
@abstractmethod
def image_understand(cls, file_path : str, text:str) -> bool: 这样可以避免不必要的静态成员,在实现细节上保持封装性,同时利用 Pydantic 的功能增强数据安全性。 此外,请根据最新的Python版本使用正确的注释形式:在每一行或块尾部后跟制式空格而不是换行符,并确保遵循 PEP 8 格式指南中关于缩进的规定。 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# coding=utf-8 | ||
from typing import Dict | ||
|
||
from common import forms | ||
from common.exception.app_exception import AppApiException | ||
from common.forms import BaseForm | ||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
||
|
||
class OpenAIImageModelCredential(BaseForm, BaseModelCredential): | ||
api_base = forms.TextInputField('API 域名', required=True) | ||
api_key = forms.PasswordInputField('API Key', required=True) | ||
|
||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, | ||
raise_exception=False): | ||
model_type_list = provider.get_model_type_list() | ||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): | ||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') | ||
|
||
for key in ['api_base', 'api_key']: | ||
if key not in model_credential: | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') | ||
else: | ||
return False | ||
try: | ||
model = provider.get_model(model_type, model_name, model_credential) | ||
model.check_auth() | ||
except Exception as e: | ||
if isinstance(e, AppApiException): | ||
raise e | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') | ||
else: | ||
return False | ||
return True | ||
|
||
def encryption_dict(self, model: Dict[str, object]): | ||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))} | ||
|
||
def get_model_params_setting_form(self, model_name): | ||
pass | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这段代码已经相对规范,主要问题和优化建议是:
ERROR_MAPPING = {
AppApiException(ValidCode.valid_error.value): "参数无效",
}
总之,在保持简洁性的同时,可以进一步清理一些代码块中的无用说明和冗余内容。这种编程风格更加自然也更易维护。 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import base64 | ||
import os | ||
from typing import Dict | ||
|
||
from openai import OpenAI | ||
|
||
from common.config.tokenizer_manage_config import TokenizerManage | ||
from setting.models_provider.base_model_provider import MaxKBBaseModel | ||
from setting.models_provider.impl.base_image import BaseImage | ||
|
||
|
||
def custom_get_token_ids(text: str): | ||
tokenizer = TokenizerManage.get_tokenizer() | ||
return tokenizer.encode(text) | ||
|
||
|
||
class OpenAIImage(MaxKBBaseModel, BaseImage): | ||
api_base: str | ||
api_key: str | ||
model: str | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self.api_key = kwargs.get('api_key') | ||
self.api_base = kwargs.get('api_base') | ||
|
||
@staticmethod | ||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): | ||
optional_params = {} | ||
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: | ||
optional_params['max_tokens'] = model_kwargs['max_tokens'] | ||
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: | ||
optional_params['temperature'] = model_kwargs['temperature'] | ||
return OpenAIImage( | ||
model=model_name, | ||
api_base=model_credential.get('api_base'), | ||
api_key=model_credential.get('api_key'), | ||
**optional_params, | ||
) | ||
|
||
def check_auth(self): | ||
client = OpenAI( | ||
base_url=self.api_base, | ||
api_key=self.api_key | ||
) | ||
response_list = client.models.with_raw_response.list() | ||
# print(response_list) | ||
# cwd = os.path.dirname(os.path.abspath(__file__)) | ||
# with open(f'{cwd}/img_1.png', 'rb') as f: | ||
# self.image_understand(f, "一句话概述这个图片") | ||
|
||
def image_understand(self, image_file, text): | ||
client = OpenAI( | ||
base_url=self.api_base, | ||
api_key=self.api_key | ||
) | ||
base64_image = base64.b64encode(image_file.read()).decode('utf-8') | ||
|
||
response = client.chat.completions.create( | ||
model=self.model, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": [ | ||
{ | ||
"type": "text", | ||
"text": text, | ||
}, | ||
{ | ||
"type": "image_url", | ||
"image_url": { | ||
"url": f"data:image/jpeg;base64,{base64_image}" | ||
}, | ||
}, | ||
], | ||
} | ||
], | ||
) | ||
return response.choices[0].message.content | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 此代码文件中可能存在以下几点问题:
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,10 +12,12 @@ | |
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ | ||
ModelTypeConst, ModelInfoManage | ||
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential | ||
from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential | ||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential | ||
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential | ||
from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential | ||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel | ||
from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage | ||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel | ||
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText | ||
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech | ||
|
@@ -24,6 +26,7 @@ | |
openai_llm_model_credential = OpenAILLMModelCredential() | ||
openai_stt_model_credential = OpenAISTTModelCredential() | ||
openai_tts_model_credential = OpenAITTSModelCredential() | ||
openai_image_model_credential = OpenAIImageModelCredential() | ||
model_info_list = [ | ||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, | ||
openai_llm_model_credential, OpenAIChatModel | ||
|
@@ -88,11 +91,20 @@ | |
OpenAIEmbeddingModel) | ||
] | ||
|
||
model_info_image_list = [ | ||
ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新', | ||
ModelTypeConst.IMAGE, openai_image_model_credential, | ||
OpenAIImage), | ||
ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新', | ||
ModelTypeConst.IMAGE, openai_image_model_credential, | ||
OpenAIImage), | ||
] | ||
|
||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( | ||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, | ||
openai_llm_model_credential, OpenAIChatModel | ||
)).append_model_info_list(model_info_embedding_list).append_default_model_info( | ||
model_info_embedding_list[0]).build() | ||
model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build() | ||
|
||
|
||
class OpenAIModelProvider(IModelProvider): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 该代码中的模型信息(ModelInfo、ModelProvideInfo)定义有重复,应该统一在IModelProvider中进行管理。其余逻辑看起来没有明显的缺陷或不一致之处。 另外,您可能想使用 |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# coding=utf-8 | ||
|
||
from typing import Dict | ||
|
||
from common import forms | ||
from common.exception.app_exception import AppApiException | ||
from common.forms import BaseForm | ||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
||
|
||
class XunFeiImageModelCredential(BaseForm, BaseModelCredential): | ||
spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image') | ||
spark_app_id = forms.TextInputField('APP ID', required=True) | ||
spark_api_key = forms.PasswordInputField("API Key", required=True) | ||
spark_api_secret = forms.PasswordInputField('API Secret', required=True) | ||
|
||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, | ||
raise_exception=False): | ||
model_type_list = provider.get_model_type_list() | ||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): | ||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') | ||
|
||
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: | ||
if key not in model_credential: | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') | ||
else: | ||
return False | ||
try: | ||
model = provider.get_model(model_type, model_name, model_credential) | ||
model.check_auth() | ||
except Exception as e: | ||
if isinstance(e, AppApiException): | ||
raise e | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') | ||
else: | ||
return False | ||
return True | ||
|
||
def encryption_dict(self, model: Dict[str, object]): | ||
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} | ||
|
||
|
||
def get_model_params_setting_form(self, model_name): | ||
pass | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 由于我是机器人,我不能查看文件内容。如果您能提供具体的源代码或描述您想要了解的问题,我会尽力帮助您的需求。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有看到明显的区别,这个代码是用于定义一些模型类型常量。如果需要提供更详细的分析,请提供更多上下文信息和需求。