diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py
index c4722c9f59d..c86068c68a8 100644
--- a/apps/setting/models_provider/base_model_provider.py
+++ b/apps/setting/models_provider/base_model_provider.py
@@ -149,6 +149,7 @@ class ModelTypeConst(Enum):
     EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
     STT = {'code': 'STT', 'message': '语音识别'}
     TTS = {'code': 'TTS', 'message': '语音合成'}
+    IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
     RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
 
 
diff --git a/apps/setting/models_provider/impl/base_image.py b/apps/setting/models_provider/impl/base_image.py
new file mode 100644
index 00000000000..70bc99595d1
--- /dev/null
+++ b/apps/setting/models_provider/impl/base_image.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+from abc import abstractmethod
+
+from pydantic import BaseModel
+
+
+class BaseImage(BaseModel):
+    @abstractmethod
+    def check_auth(self):
+        pass
+
+    @abstractmethod
+    def image_understand(self, image_file, text):
+        pass
diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py
new file mode 100644
index 00000000000..30404022251
--- /dev/null
+++ b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py
@@ -0,0 +1,42 @@
+# coding=utf-8
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
+    api_base = forms.TextInputField('API 域名', required=True)
+    api_key = forms.PasswordInputField('API Key', required=True)
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=False):
+        model_type_list = provider.get_model_type_list()
+        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['api_base', 'api_key']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                else:
+                    return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            model.check_auth()
+        except Exception as e:
+            if isinstance(e, AppApiException):
+                raise e
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            else:
+                return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+    def get_model_params_setting_form(self, model_name):
+        pass
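For orientation, the sketch below shows how this credential form might be exercised. It relies only on calls that appear in this diff (is_valid, encryption_dict, provider.get_model, model.check_auth); the provider instance, endpoint, and key are placeholders, and is_valid triggers a live check_auth request, so it is illustrative rather than a test.

```python
# Illustrative sketch only; endpoint and key are placeholders, not working credentials.
from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider

credential_form = OpenAIImageModelCredential()
provider = OpenAIModelProvider()
raw_credential = {'api_base': 'https://api.openai.com/v1', 'api_key': 'sk-xxxx'}

# is_valid checks the model type and required keys, then builds the model via
# provider.get_model(...) and calls check_auth(), which issues a real API request.
ok = credential_form.is_valid('IMAGE', 'gpt-4o', raw_credential, provider, raise_exception=False)

# encryption_dict masks the api_key before the credential is persisted.
stored = credential_form.encryption_dict(raw_credential)
```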
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/image.py b/apps/setting/models_provider/impl/openai_model_provider/model/image.py
new file mode 100644
index 00000000000..556b4a0ee81
--- /dev/null
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/image.py
@@ -0,0 +1,79 @@
+import base64
+import os
+from typing import Dict
+
+from openai import OpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_image import BaseImage
+
+
+def custom_get_token_ids(text: str):
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class OpenAIImage(MaxKBBaseModel, BaseImage):
+    api_base: str
+    api_key: str
+    model: str
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.api_key = kwargs.get('api_key')
+        self.api_base = kwargs.get('api_base')
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        optional_params = {}
+        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
+            optional_params['max_tokens'] = model_kwargs['max_tokens']
+        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
+            optional_params['temperature'] = model_kwargs['temperature']
+        return OpenAIImage(
+            model=model_name,
+            api_base=model_credential.get('api_base'),
+            api_key=model_credential.get('api_key'),
+            **optional_params,
+        )
+
+    def check_auth(self):
+        client = OpenAI(
+            base_url=self.api_base,
+            api_key=self.api_key
+        )
+        response_list = client.models.with_raw_response.list()
+        # print(response_list)
+        # cwd = os.path.dirname(os.path.abspath(__file__))
+        # with open(f'{cwd}/img_1.png', 'rb') as f:
+        #     self.image_understand(f, "一句话概述这个图片")
+
+    def image_understand(self, image_file, text):
+        client = OpenAI(
+            base_url=self.api_base,
+            api_key=self.api_key
+        )
+        base64_image = base64.b64encode(image_file.read()).decode('utf-8')
+
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": text,
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{base64_image}"
+                            },
+                        },
+                    ],
+                }
+            ],
+        )
+        return response.choices[0].message.content
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/img_1.png b/apps/setting/models_provider/impl/openai_model_provider/model/img_1.png
new file mode 100644
index 00000000000..ccb9d3b2035
Binary files /dev/null and b/apps/setting/models_provider/impl/openai_model_provider/model/img_1.png differ
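A minimal usage sketch of the new OpenAIImage class, assuming an OpenAI-compatible endpoint, a valid key, and a local test image; the endpoint, key, and file path below are placeholders and not part of the patch.

```python
# Sketch under the assumptions above; not part of the patch.
from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage

model = OpenAIImage.new_instance('IMAGE', 'gpt-4o', {
    'api_base': 'https://api.openai.com/v1',  # placeholder endpoint
    'api_key': 'sk-xxxx',                     # placeholder key
})
model.check_auth()  # lists models to verify that the endpoint and key are usable

with open('img_1.png', 'rb') as f:
    # The image is sent as a base64 data URL alongside the text prompt in one chat completion.
    print(model.image_understand(f, 'Describe this image in one sentence.'))
```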
diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
index f9221388f65..974599cc89c 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
@@ -12,10 +12,12 @@
 from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
     ModelTypeConst, ModelInfoManage
 from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
+from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
 from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
 from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
 from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
 from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
+from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage
 from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
 from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
 from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
@@ -24,6 +26,7 @@
 openai_llm_model_credential = OpenAILLMModelCredential()
 openai_stt_model_credential = OpenAISTTModelCredential()
 openai_tts_model_credential = OpenAITTSModelCredential()
+openai_image_model_credential = OpenAIImageModelCredential()
 model_info_list = [
     ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
               openai_llm_model_credential, OpenAIChatModel
@@ -88,11 +91,20 @@
               OpenAIEmbeddingModel)
 ]
 
+model_info_image_list = [
+    ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新',
+              ModelTypeConst.IMAGE, openai_image_model_credential,
+              OpenAIImage),
+    ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新',
+              ModelTypeConst.IMAGE, openai_image_model_credential,
+              OpenAIImage),
+]
+
 model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
     ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
               openai_llm_model_credential, OpenAIChatModel
               )).append_model_info_list(model_info_embedding_list).append_default_model_info(
-    model_info_embedding_list[0]).build()
+    model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build()
 
 
 class OpenAIModelProvider(IModelProvider):
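The registration above makes the gpt-4o entries resolvable through the provider. A hedged sketch of that lookup path, using only calls already visible in this diff and placeholder credential values:

```python
# Sketch only; credential values are placeholders.
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider

provider = OpenAIModelProvider()

# The IMAGE type should now be reported alongside LLM, EMBEDDING, STT and TTS.
print(provider.get_model_type_list())

# Same call OpenAIImageModelCredential.is_valid uses; for ('IMAGE', 'gpt-4o') it is
# expected to build an OpenAIImage instance via OpenAIImage.new_instance.
model = provider.get_model('IMAGE', 'gpt-4o',
                           {'api_base': 'https://api.openai.com/v1', 'api_key': 'sk-xxxx'})
```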
diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/image.py b/apps/setting/models_provider/impl/xf_model_provider/credential/image.py
new file mode 100644
index 00000000000..88345a545ac
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/credential/image.py
@@ -0,0 +1,46 @@
+# coding=utf-8
+
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
+    spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image')
+    spark_app_id = forms.TextInputField('APP ID', required=True)
+    spark_api_key = forms.PasswordInputField("API Key", required=True)
+    spark_api_secret = forms.PasswordInputField('API Secret', required=True)
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=False):
+        model_type_list = provider.get_model_type_list()
+        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                else:
+                    return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            model.check_auth()
+        except Exception as e:
+            if isinstance(e, AppApiException):
+                raise e
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            else:
+                return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
+
+
+    def get_model_params_setting_form(self, model_name):
+        pass
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/image.py b/apps/setting/models_provider/impl/xf_model_provider/model/image.py
new file mode 100644
index 00000000000..91e8fa0cc57
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/image.py
@@ -0,0 +1,170 @@
+# coding=utf-8
+
+import asyncio
+import base64
+import datetime
+import hashlib
+import hmac
+import json
+import os
+import ssl
+from datetime import datetime, UTC
+from typing import Dict
+from urllib.parse import urlencode
+from urllib.parse import urlparse
+
+import websockets
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_image import BaseImage
+
+ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ssl_context.check_hostname = False
+ssl_context.verify_mode = ssl.CERT_NONE
+
+
+class XFSparkImage(MaxKBBaseModel, BaseImage):
+    spark_app_id: str
+    spark_api_key: str
+    spark_api_secret: str
+    spark_api_url: str
+    params: dict
+
+    # 初始化
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.spark_api_url = kwargs.get('spark_api_url')
+        self.spark_app_id = kwargs.get('spark_app_id')
+        self.spark_api_key = kwargs.get('spark_api_key')
+        self.spark_api_secret = kwargs.get('spark_api_secret')
+        self.params = kwargs.get('params')
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        optional_params = {'params': {}}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params['params'][key] = value
+        return XFSparkImage(
+            spark_app_id=model_credential.get('spark_app_id'),
+            spark_api_key=model_credential.get('spark_api_key'),
+            spark_api_secret=model_credential.get('spark_api_secret'),
+            spark_api_url=model_credential.get('spark_api_url'),
+            **optional_params
+        )
+
+    def create_url(self):
+        url = self.spark_api_url
+        host = urlparse(url).hostname
+        # 生成RFC1123格式的时间戳
+        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
+        date = datetime.now(UTC).strftime(gmt_format)
+
+        # 拼接字符串
+        signature_origin = "host: " + host + "\n"
+        signature_origin += "date: " + date + "\n"
+        signature_origin += "GET " + "/v2.1/image " + "HTTP/1.1"
+        # 进行hmac-sha256进行加密
+        signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
+                                 digestmod=hashlib.sha256).digest()
+        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
+
+        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
+            self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
+        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
+        # 将请求的鉴权参数组合为字典
+        v = {
+            "authorization": authorization,
+            "date": date,
+            "host": host
+        }
+        # 拼接鉴权参数,生成url
+        url = url + '?' + urlencode(v)
+        # print("date: ",date)
+        # print("v: ",v)
+        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
+        # print('websocket url :', url)
+        return url
+
+    def check_auth(self):
+        cwd = os.path.dirname(os.path.abspath(__file__))
+        with open(f'{cwd}/img_1.png', 'rb') as f:
+            self.image_understand(f, "一句话概述这个图片")
+
+    def image_understand(self, image_file, question):
+        async def handle():
+            async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
+                # 发送 full client request
+                await self.send(ws, image_file, question)
+                return await self.handle_message(ws)
+
+        return asyncio.run(handle())
+
+    # 收到websocket消息的处理
+    @staticmethod
+    async def handle_message(ws):
+        # print(message)
+        answer = ''
+        while True:
+            res = await ws.recv()
+            data = json.loads(res)
+            code = data['header']['code']
+            if code != 0:
+                return f'请求错误: {code}, {data}'
+            else:
+                choices = data["payload"]["choices"]
+                status = choices["status"]
+                content = choices["text"][0]["content"]
+                # print(content, end="")
+                answer += content
+                # print(1)
+                if status == 2:
+                    break
+        return answer
+
+    async def send(self, ws, image_file, question):
+        text = [
+            {"role": "user", "content": str(base64.b64encode(image_file.read()), 'utf-8'), "content_type": "image"},
+            {"role": "user", "content": question}
+        ]
+
+        data = {
+            "header": {
+                "app_id": self.spark_app_id
+            },
+            "parameter": {
+                "chat": {
+                    "domain": "image",
+                    "temperature": 0.5,
+                    "top_k": 4,
+                    "max_tokens": 2028,
+                    "auditing": "default"
+                }
+            },
+            "payload": {
+                "message": {
+                    "text": text
+                }
+            }
+        }
+
+        d = json.dumps(data)
+        await ws.send(d)
+
+    def is_cache_model(self):
+        return False
+
+    @staticmethod
+    def get_len(text):
+        length = 0
+        for content in text:
+            temp = content["content"]
+            leng = len(temp)
+            length += leng
+        return length
+
+    def check_len(self, text):
+        print("text-content-tokens:", self.get_len(text[1:]))
+        while (self.get_len(text[1:]) > 8000):
+            del text[1]
+        return text
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/img_1.png b/apps/setting/models_provider/impl/xf_model_provider/model/img_1.png
new file mode 100644
index 00000000000..ccb9d3b2035
Binary files /dev/null and b/apps/setting/models_provider/impl/xf_model_provider/model/img_1.png differ
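For reference, a condensed standalone version of the HMAC-SHA256 handshake that XFSparkImage.create_url performs: the host, date, and GET request line are signed with the API secret and appended as query parameters on the wss:// URL. The credentials below are placeholders, and the request-line path is taken from the URL here rather than hard-coded to /v2.1/image as in the patch.

```python
# Standalone sketch of the Spark WebSocket auth URL; placeholder credentials.
import base64
import hashlib
import hmac
from datetime import datetime, UTC
from urllib.parse import urlencode, urlparse


def build_auth_url(api_url: str, api_key: str, api_secret: str) -> str:
    parsed = urlparse(api_url)
    date = datetime.now(UTC).strftime('%a, %d %b %Y %H:%M:%S GMT')
    # Sign "host", "date" and the GET request line with the API secret.
    origin = f"host: {parsed.hostname}\ndate: {date}\nGET {parsed.path} HTTP/1.1"
    signature = base64.b64encode(
        hmac.new(api_secret.encode('utf-8'), origin.encode('utf-8'), hashlib.sha256).digest()
    ).decode('utf-8')
    authorization = base64.b64encode(
        (f'api_key="{api_key}", algorithm="hmac-sha256", '
         f'headers="host date request-line", signature="{signature}"').encode('utf-8')
    ).decode('utf-8')
    # The signed parameters ride along as a query string on the WebSocket URL.
    return api_url + '?' + urlencode({'authorization': authorization, 'date': date, 'host': parsed.hostname})


url = build_auth_url('wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image', 'my-api-key', 'my-api-secret')
```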
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py
index f400473ed84..d36bcdb9feb 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py
@@ -10,7 +10,7 @@
 import json
 import logging
 import os
-from datetime import datetime
+from datetime import datetime, UTC
 from typing import Dict
 from urllib.parse import urlencode, urlparse
 import ssl
@@ -63,7 +63,7 @@ def create_url(self):
         host = urlparse(url).hostname
         # 生成RFC1123格式的时间戳
         gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
-        date = datetime.utcnow().strftime(gmt_format)
+        date = datetime.now(UTC).strftime(gmt_format)
 
         # 拼接字符串
         signature_origin = "host: " + host + "\n"
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py
index 3a575ed28b2..1d677b7c32a 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py
@@ -12,7 +12,7 @@
 import json
 import logging
 import os
-from datetime import datetime
+from datetime import datetime, UTC
 from typing import Dict
 from urllib.parse import urlencode, urlparse
 import ssl
@@ -67,7 +67,7 @@ def create_url(self):
         host = urlparse(url).hostname
         # 生成RFC1123格式的时间戳
         gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
-        date = datetime.utcnow().strftime(gmt_format)
+        date = datetime.now(UTC).strftime(gmt_format)
 
         # 拼接字符串
         signature_origin = "host: " + host + "\n"
diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py
index 04fd2d439d4..37bd3e43594 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py
@@ -13,10 +13,12 @@
 from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
     ModelInfoManage
 from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential
+from setting.models_provider.impl.xf_model_provider.credential.image import XunFeiImageModelCredential
 from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
 from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
 from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
 from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
+from setting.models_provider.impl.xf_model_provider.model.image import XFSparkImage
 from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
 from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
 from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
@@ -26,6 +28,7 @@
 
 qwen_model_credential = XunFeiLLMModelCredential()
 stt_model_credential = XunFeiSTTModelCredential()
+image_model_credential = XunFeiImageModelCredential()
 tts_model_credential = XunFeiTTSModelCredential()
 embedding_model_credential = XFEmbeddingCredential()
 model_info_list = [
@@ -34,6 +37,7 @@
     ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
     ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
     ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
+    ModelInfo('image', '', ModelTypeConst.IMAGE, image_model_credential, XFSparkImage),
     ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
 ]
 
diff --git a/ui/src/views/template/index.vue b/ui/src/views/template/index.vue
index 0a33c8ea9b8..20c9a62e542 100644
--- a/ui/src/views/template/index.vue
+++ b/ui/src/views/template/index.vue
@@ -132,6 +132,7 @@
+