feat: OpenAI image model #1545

Merged
merged 2 commits into from
Nov 5, 2024
1 change: 1 addition & 0 deletions apps/setting/models_provider/base_model_provider.py
@@ -149,6 +149,7 @@ class ModelTypeConst(Enum):
     EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}  # "Embedding model"
     STT = {'code': 'STT', 'message': '语音识别'}  # "Speech recognition"
     TTS = {'code': 'TTS', 'message': '语音合成'}  # "Speech synthesis"
+    IMAGE = {'code': 'IMAGE', 'message': '图片理解'}  # "Image understanding"
     RERANKER = {'code': 'RERANKER', 'message': '重排模型'}  # "Reranker model"



No obvious problems found. This code defines a set of model type constants. For a more detailed analysis, please provide more context and requirements.
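As a rough illustration of how a dict-valued enum like `ModelTypeConst` can feed the `mt.get('value') == model_type` check used by the credential classes in this PR, here is a minimal sketch. The helpers `model_type_options` and `is_supported` are hypothetical names introduced for illustration, and the English messages are translations rather than the project's strings.

```python
from enum import Enum


class ModelTypeConst(Enum):
    # Trimmed copy of the enum in the diff, with translated messages (assumption).
    TTS = {'code': 'TTS', 'message': 'Speech synthesis'}
    IMAGE = {'code': 'IMAGE', 'message': 'Image understanding'}
    RERANKER = {'code': 'RERANKER', 'message': 'Reranker model'}


def model_type_options(types):
    """Hypothetical helper: turn enum members into {'key', 'value'} option dicts."""
    return [{'key': t.value['message'], 'value': t.value['code']} for t in types]


def is_supported(model_type, options):
    # Same membership test the credential classes perform on the provider's type list.
    return any(option.get('value') == model_type for option in options)


# A provider that registers only TTS and IMAGE models.
options = model_type_options([ModelTypeConst.TTS, ModelTypeConst.IMAGE])
print(is_supported('IMAGE', options))
```

Under these assumptions, a provider only accepts `IMAGE` once at least one model of that type has been registered with it.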

14 changes: 14 additions & 0 deletions apps/setting/models_provider/impl/base_image.py
@@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod

from pydantic import BaseModel


class BaseImage(BaseModel):
    @abstractmethod
    def check_auth(self):
        pass

    @abstractmethod
    def image_understand(self, image_file, text):
        pass

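This Python code appears to need some structural and functional improvements to be clear and readable. The main problems are:

  • check_auth(self) and image_understand(self, ...) are set up with @staticmethod rather than @property, which may cause performance problems.
    • Inherit the functions from the abstract class as attributes (@abstractmethod), and define an instance method that calls them to improve efficiency.
from abc import ABC, abstractmethod

from pydantic import BaseModel


# Define a custom base class and add abstract methods
class Image(BaseModel, ABC):
    @classmethod
    @abstractmethod
    def check_auth(cls):
        ...

    # Add the actual behavior
    @classmethod
    @abstractmethod
    def image_understand(cls, file_path: str, text: str) -> bool:
        ...

This avoids unnecessary static members and keeps implementation details encapsulated, while using Pydantic's features to improve data safety.

In addition, use the comment style appropriate to the latest Python version: follow each line or block with a standard space rather than a newline, and make sure to follow the PEP 8 guidelines on indentation.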
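To make the contract of `BaseImage` concrete, the sketch below shows a provider implementing both abstract methods. It is only an illustration: `BaseImageSketch` uses `abc.ABC` as a dependency-free stand-in for the pydantic-based class in the diff, and `EchoImage` and `can_instantiate` are hypothetical names.

```python
from abc import ABC, abstractmethod
import io


class BaseImageSketch(ABC):
    """Stand-in for BaseImage; uses abc.ABC instead of pydantic (assumption)."""

    @abstractmethod
    def check_auth(self):
        ...

    @abstractmethod
    def image_understand(self, image_file, text):
        ...


class EchoImage(BaseImageSketch):
    """Hypothetical provider that reports the image size instead of calling an API."""

    def check_auth(self):
        return True

    def image_understand(self, image_file, text):
        data = image_file.read()
        return f'{text}: {len(data)} bytes'


def can_instantiate(cls):
    # Classes with unimplemented abstract methods raise TypeError on construction.
    try:
        cls()
        return True
    except TypeError:
        return False


result = EchoImage().image_understand(io.BytesIO(b'\x89PNG'), 'size')
print(result)
```

A subclass that omits either method cannot be instantiated, which is how the abstract base forces every provider to supply both an auth check and an image call.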

@@ -0,0 +1,42 @@
# coding=utf-8
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
    api_base = forms.TextInputField('API 域名', required=True)  # label: "API domain"
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            # Message: "model type {model_type} is not supported"
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    # Message: "field {key} is required"
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                # Message: "validation failed, please check that the parameters are correct"
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        pass

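This code is already fairly well structured. The main issues and suggested optimizations are:

  1. Add an _error_mapping_ attribute to the OpenAIImageModelCredential class that maps API errors to error messages. This can make debugging from these errors easier for developers:
ERROR_MAPPING = {
    ValidCode.valid_error.value: "Invalid parameter",
}
  2. Remove unnecessary comments: some comments may be outdated or redundant and do not follow the code conventions.
  3. Where possible, use a context manager (the with keyword) instead of importing modules directly (unless really needed), to avoid importing all members.

In short, while keeping the code concise, some useless notes and redundant content in the code blocks can be cleaned up further. This programming style is more natural and easier to maintain.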
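The control flow of `is_valid` above (required-field check, then a live auth probe, with `raise_exception` deciding between returning `False` and raising) can be sketched in isolation as follows. `ValidationError`, `validate_credential`, and `failing_probe` are hypothetical stand-ins, not the project's `AppApiException` or provider API.

```python
class ValidationError(Exception):
    """Stand-in for AppApiException (hypothetical, not the project's class)."""


def validate_credential(credential, required_keys, check_auth, raise_exception=False):
    # Step 1: every required field must be present.
    for key in required_keys:
        if key not in credential:
            if raise_exception:
                raise ValidationError(f'{key} is required')
            return False
    # Step 2: a live authentication probe. Domain errors propagate unchanged;
    # anything else becomes False or a wrapped ValidationError.
    try:
        check_auth(credential)
    except ValidationError:
        raise
    except Exception as e:
        if raise_exception:
            raise ValidationError(f'Validation failed, check the parameters: {e}') from e
        return False
    return True


def failing_probe(credential):
    # Simulates an endpoint that cannot be reached.
    raise ConnectionError('unreachable')


KEYS = ['api_base', 'api_key']
FULL = {'api_base': 'https://example.invalid/v1', 'api_key': 'sk-test'}
print(validate_credential(FULL, KEYS, lambda c: None))
print(validate_credential({'api_base': 'x'}, KEYS, lambda c: None))
print(validate_credential(FULL, KEYS, failing_probe))
```

As in the diff, an error of the domain type is always re-raised, even when `raise_exception` is false.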

@@ -0,0 +1,79 @@
import base64
import os
from typing import Dict

from openai import OpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_image import BaseImage


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class OpenAIImage(MaxKBBaseModel, BaseImage):
    api_base: str
    api_key: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return OpenAIImage(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        response_list = client.models.with_raw_response.list()
        # print(response_list)
        # cwd = os.path.dirname(os.path.abspath(__file__))
        # with open(f'{cwd}/img_1.png', 'rb') as f:
        #     self.image_understand(f, "一句话概述这个图片")  # "Summarize this image in one sentence"

    def image_understand(self, image_file, text):
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": text,
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            },
                        },
                    ],
                }
            ],
        )
        return response.choices[0].message.content

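This code file may have the following problems:

  1. The code does not state which version it targets, which may cause compatibility problems when interacting with other models or environments.
  2. The file header lacks basic elements such as copyright information, a module comment, and type declarations.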
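`image_understand` above always labels the payload as `image/jpeg`, even though the commented-out test in `check_auth` reads a `.png` file. One possible refinement, sketched here under assumptions, is to sniff the format from the leading bytes before building the data URL. `sniff_mime` and `build_messages` are hypothetical helpers, and the signature table is deliberately incomplete.

```python
import base64
import io

# Magic-number prefixes for a few common formats (deliberately incomplete).
_SIGNATURES = {
    b'\x89PNG\r\n\x1a\n': 'image/png',
    b'\xff\xd8\xff': 'image/jpeg',
    b'GIF8': 'image/gif',
}


def sniff_mime(data, default='image/jpeg'):
    # Fall back to the MIME type the diff hardcodes when the format is unknown.
    for prefix, mime in _SIGNATURES.items():
        if data.startswith(prefix):
            return mime
    return default


def build_messages(image_file, text):
    """Hypothetical helper returning the `messages` argument used in image_understand."""
    data = image_file.read()
    encoded = base64.b64encode(data).decode('utf-8')
    url = f'data:{sniff_mime(data)};base64,{encoded}'
    return [{
        'role': 'user',
        'content': [
            {'type': 'text', 'text': text},
            {'type': 'image_url', 'image_url': {'url': url}},
        ],
    }]


PNG_BYTES = b'\x89PNG\r\n\x1a\n' + b'\x00' * 4
messages = build_messages(io.BytesIO(PNG_BYTES), 'Describe this image')
print(messages[0]['content'][1]['image_url']['url'][:22])
```

The returned list has the same shape as the `messages` value in the diff, so it could be passed to `client.chat.completions.create` unchanged.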

@@ -12,10 +12,12 @@
 from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
     ModelTypeConst, ModelInfoManage
 from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
+from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
 from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
 from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
 from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
 from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
+from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage
 from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
 from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
 from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
@@ -24,6 +26,7 @@
 openai_llm_model_credential = OpenAILLMModelCredential()
 openai_stt_model_credential = OpenAISTTModelCredential()
 openai_tts_model_credential = OpenAITTSModelCredential()
+openai_image_model_credential = OpenAIImageModelCredential()
 model_info_list = [
     # Description: "the latest gpt-3.5-turbo, updated as OpenAI adjusts it"
     ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
               openai_llm_model_credential, OpenAIChatModel
@@ -88,11 +91,20 @@
                   OpenAIEmbeddingModel)
 ]

+model_info_image_list = [
+    ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新',  # "latest GPT-4o, cheaper and faster than gpt-4-turbo"
+              ModelTypeConst.IMAGE, openai_image_model_credential,
+              OpenAIImage),
+    ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新',  # "latest gpt-4o-mini, cheaper and faster than gpt-4o"
+              ModelTypeConst.IMAGE, openai_image_model_credential,
+              OpenAIImage),
+]
+
 model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
     ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
               openai_llm_model_credential, OpenAIChatModel
               )).append_model_info_list(model_info_embedding_list).append_default_model_info(
-    model_info_embedding_list[0]).build()
+    model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build()


 class OpenAIModelProvider(IModelProvider):

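The model information definitions in this code (ModelInfo, ModelProvideInfo) are duplicated and should be managed in one place in IModelProvider. The rest of the logic shows no obvious defects or inconsistencies.

In addition, you may want to use features such as @dataclass(auto_attribs=True) to improve programming efficiency and maintainability. However, these changes are tied to a particular point in time, because they may be affected by changes in the current version of the OpenAI API. If necessary, check the API documentation for future development change policies.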
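The chained `ModelInfoManage.builder()` call in the diff relies on each `append_*` method returning the builder. The real `ModelInfoManage` is not shown in this PR, so the following is only a guessed re-creation of that pattern; `ModelInfoSketch`, `ModelInfoManageSketch`, and their attributes are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelInfoSketch:
    name: str
    model_type: str


class ModelInfoManageSketch:
    """Hypothetical re-creation of the builder; the real class is not in this diff."""

    def __init__(self):
        self.by_type = {}   # model_type -> {name: info}
        self.defaults = {}  # model_type -> default info

    def append_model_info_list(self, infos):
        for info in infos:
            self.by_type.setdefault(info.model_type, {})[info.name] = info
        return self  # returning self is what allows the chained calls

    def append_default_model_info(self, info):
        self.defaults[info.model_type] = info
        return self

    def build(self):
        return self

    def get_model_type_list(self):
        return sorted(self.by_type)


image_models = [ModelInfoSketch('gpt-4o', 'IMAGE'), ModelInfoSketch('gpt-4o-mini', 'IMAGE')]
llm_default = ModelInfoSketch('gpt-3.5-turbo', 'LLM')
manage = (ModelInfoManageSketch()
          .append_model_info_list([llm_default])
          .append_default_model_info(llm_default)
          .append_model_info_list(image_models)
          .build())
print(manage.get_model_type_list())
```

Mirroring the diff, the sketch registers the image models without a default entry for `IMAGE`. Whether the real class needs a default per type is not something this PR shows.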

@@ -0,0 +1,46 @@
# coding=utf-8

from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
    # Label: "API domain"
    spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image')
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            # Message: "model type {model_type} is not supported"
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
            if key not in model_credential:
                if raise_exception:
                    # Message: "field {key} is required"
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                # Message: "validation failed, please check that the parameters are correct"
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}


    def get_model_params_setting_form(self, model_name):
        pass

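Because I am a bot, I cannot view the file contents. If you can provide the specific source code or describe the question you want answered, I will do my best to help.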
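The `encryption_dict` method in the file above uses dict unpacking so that one field is overridden while the rest are copied. The sketch below shows that pattern on its own. The `mask` rule is an assumption, since `BaseModelCredential.encryption` does not appear in this diff, and `secret_keys` is a hypothetical parameter.

```python
def mask(value, visible=3):
    """Hypothetical masking rule; the project's encryption() is not shown in this diff."""
    text = str(value)
    if len(text) <= visible * 2:
        return '*' * len(text)
    return text[:visible] + '*' * (len(text) - visible * 2) + text[-visible:]


def encryption_dict(model, secret_keys=('spark_api_secret',)):
    # Later keys win in a dict literal, so the masked values replace the originals
    # while every other field is copied unchanged. The input dict is not mutated.
    return {**model, **{key: mask(model.get(key, '')) for key in secret_keys}}


credential = {
    'spark_app_id': 'app-1',
    'spark_api_key': 'key-abcdef',
    'spark_api_secret': 'secret-123456',
}
masked = encryption_dict(credential)
print(masked['spark_api_secret'])
```

In the diff only `spark_api_secret` goes through `encryption`, although `spark_api_key` is also declared as a `PasswordInputField`. With a parameter like `secret_keys`, covering both would be a one-line change, if that is the intended behavior.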
