diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py
index c4722c9f59d..c86068c68a8 100644
--- a/apps/setting/models_provider/base_model_provider.py
+++ b/apps/setting/models_provider/base_model_provider.py
@@ -149,6 +149,7 @@ class ModelTypeConst(Enum):
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
STT = {'code': 'STT', 'message': '语音识别'}
TTS = {'code': 'TTS', 'message': '语音合成'}
+ IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
diff --git a/apps/setting/models_provider/impl/base_image.py b/apps/setting/models_provider/impl/base_image.py
new file mode 100644
index 00000000000..70bc99595d1
--- /dev/null
+++ b/apps/setting/models_provider/impl/base_image.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+from abc import abstractmethod
+
+from pydantic import BaseModel
+
+
+class BaseImage(BaseModel):
+ """Common interface for image-understanding (vision) models."""
+
+ @abstractmethod
+ def check_auth(self):
+ """Verify that the configured credentials work; raise on failure."""
+
+ @abstractmethod
+ def image_understand(self, image_file, text):
+ """Answer `text` about the given image file object and return the reply as a string."""
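+
+# Illustrative usage only (class and file names are placeholders; see the concrete
+# OpenAI and XunFei implementations added in this PR):
+# model = SomeProviderImage.new_instance('IMAGE', 'model-name', credential_dict)
+# model.check_auth()
+# with open('photo.png', 'rb') as f:
+#     answer = model.image_understand(f, 'Describe this picture in one sentence')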
diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py
new file mode 100644
index 00000000000..30404022251
--- /dev/null
+++ b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py
@@ -0,0 +1,42 @@
+# coding=utf-8
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
+ api_base = forms.TextInputField('API 域名', required=True)
+ api_key = forms.PasswordInputField('API Key', required=True)
+
+ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+ raise_exception=False):
+ model_type_list = provider.get_model_type_list()
+ if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+ raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+ for key in ['api_base', 'api_key']:
+ if key not in model_credential:
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+ else:
+ return False
+ try:
+ model = provider.get_model(model_type, model_name, model_credential)
+ model.check_auth()
+ except Exception as e:
+ if isinstance(e, AppApiException):
+ raise e
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+ else:
+ return False
+ return True
+
+ def encryption_dict(self, model: Dict[str, object]):
+ return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+ def get_model_params_setting_form(self, model_name):
+ pass
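+
+# Illustrative only: the credential dict this form is expected to validate
+# (placeholder values, not real endpoints or keys):
+# {'api_base': 'https://api.openai.com/v1', 'api_key': 'sk-xxxx'}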
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/image.py b/apps/setting/models_provider/impl/openai_model_provider/model/image.py
new file mode 100644
index 00000000000..556b4a0ee81
--- /dev/null
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/image.py
@@ -0,0 +1,79 @@
+import base64
+from typing import Dict
+
+from openai import OpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_image import BaseImage
+
+
+def custom_get_token_ids(text: str):
+ tokenizer = TokenizerManage.get_tokenizer()
+ return tokenizer.encode(text)
+
+
+class OpenAIImage(MaxKBBaseModel, BaseImage):
+ api_base: str
+ api_key: str
+ model: str
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.api_key = kwargs.get('api_key')
+ self.api_base = kwargs.get('api_base')
+
+ @staticmethod
+ def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+ optional_params = {}
+ if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
+ optional_params['max_tokens'] = model_kwargs['max_tokens']
+ if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
+ optional_params['temperature'] = model_kwargs['temperature']
+ return OpenAIImage(
+ model=model_name,
+ api_base=model_credential.get('api_base'),
+ api_key=model_credential.get('api_key'),
+ **optional_params,
+ )
+
+ def check_auth(self):
+ # Listing models is a lightweight call that fails fast on an invalid API key or base URL.
+ client = OpenAI(
+ base_url=self.api_base,
+ api_key=self.api_key
+ )
+ client.models.with_raw_response.list()
+
+ def image_understand(self, image_file, text):
+ client = OpenAI(
+ base_url=self.api_base,
+ api_key=self.api_key
+ )
+ base64_image = base64.b64encode(image_file.read()).decode('utf-8')
+
+ response = client.chat.completions.create(
+ model=self.model,
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": text,
+ },
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{base64_image}"
+ },
+ },
+ ],
+ }
+ ],
+ )
+ return response.choices[0].message.content
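+
+# Illustrative usage only (credential values are placeholders):
+# model = OpenAIImage.new_instance('IMAGE', 'gpt-4o',
+#     {'api_base': 'https://api.openai.com/v1', 'api_key': 'sk-xxxx'})
+# model.check_auth()
+# with open('photo.png', 'rb') as f:
+#     print(model.image_understand(f, 'Describe this picture in one sentence'))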
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/img_1.png b/apps/setting/models_provider/impl/openai_model_provider/model/img_1.png
new file mode 100644
index 00000000000..ccb9d3b2035
Binary files /dev/null and b/apps/setting/models_provider/impl/openai_model_provider/model/img_1.png differ
diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
index f9221388f65..974599cc89c 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
@@ -12,10 +12,12 @@
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
+from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
+from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
@@ -24,6 +26,7 @@
openai_llm_model_credential = OpenAILLMModelCredential()
openai_stt_model_credential = OpenAISTTModelCredential()
openai_tts_model_credential = OpenAITTSModelCredential()
+openai_image_model_credential = OpenAIImageModelCredential()
model_info_list = [
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
@@ -88,11 +91,20 @@
OpenAIEmbeddingModel)
]
+model_info_image_list = [
+ ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新',
+ ModelTypeConst.IMAGE, openai_image_model_credential,
+ OpenAIImage),
+ ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新',
+ ModelTypeConst.IMAGE, openai_image_model_credential,
+ OpenAIImage),
+]
+
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
- model_info_embedding_list[0]).build()
+ model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build()
class OpenAIModelProvider(IModelProvider):
diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/image.py b/apps/setting/models_provider/impl/xf_model_provider/credential/image.py
new file mode 100644
index 00000000000..88345a545ac
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/credential/image.py
@@ -0,0 +1,46 @@
+# coding=utf-8
+
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
+ spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image')
+ spark_app_id = forms.TextInputField('APP ID', required=True)
+ spark_api_key = forms.PasswordInputField("API Key", required=True)
+ spark_api_secret = forms.PasswordInputField('API Secret', required=True)
+
+ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+ raise_exception=False):
+ model_type_list = provider.get_model_type_list()
+ if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+ raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+ for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
+ if key not in model_credential:
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+ else:
+ return False
+ try:
+ model = provider.get_model(model_type, model_name, model_credential)
+ model.check_auth()
+ except Exception as e:
+ if isinstance(e, AppApiException):
+ raise e
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+ else:
+ return False
+ return True
+
+ def encryption_dict(self, model: Dict[str, object]):
+ return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
+
+ def get_model_params_setting_form(self, model_name):
+ pass
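+
+# Illustrative only: the credential dict this form is expected to validate
+# (placeholder values; the URL matches the form's default):
+# {'spark_api_url': 'wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image',
+#  'spark_app_id': 'your-app-id', 'spark_api_key': 'xxxx', 'spark_api_secret': 'xxxx'}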
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/image.py b/apps/setting/models_provider/impl/xf_model_provider/model/image.py
new file mode 100644
index 00000000000..91e8fa0cc57
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/image.py
@@ -0,0 +1,170 @@
+# coding=utf-8
+
+import asyncio
+import base64
+import hashlib
+import hmac
+import json
+import os
+import ssl
+from datetime import datetime, UTC
+from typing import Dict
+from urllib.parse import urlencode, urlparse
+
+import websockets
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_image import BaseImage
+
+ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ssl_context.check_hostname = False
+ssl_context.verify_mode = ssl.CERT_NONE
+
+
+class XFSparkImage(MaxKBBaseModel, BaseImage):
+ spark_app_id: str
+ spark_api_key: str
+ spark_api_secret: str
+ spark_api_url: str
+ params: dict
+
+ # Initialization: populate model fields from the credential keyword arguments
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.spark_api_url = kwargs.get('spark_api_url')
+ self.spark_app_id = kwargs.get('spark_app_id')
+ self.spark_api_key = kwargs.get('spark_api_key')
+ self.spark_api_secret = kwargs.get('spark_api_secret')
+ self.params = kwargs.get('params')
+
+ @staticmethod
+ def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+ optional_params = {'params': {}}
+ for key, value in model_kwargs.items():
+ if key not in ['model_id', 'use_local', 'streaming']:
+ optional_params['params'][key] = value
+ return XFSparkImage(
+ spark_app_id=model_credential.get('spark_app_id'),
+ spark_api_key=model_credential.get('spark_api_key'),
+ spark_api_secret=model_credential.get('spark_api_secret'),
+ spark_api_url=model_credential.get('spark_api_url'),
+ **optional_params
+ )
+
+ def create_url(self):
+ url = self.spark_api_url
+ parsed = urlparse(url)
+ host = parsed.hostname
+ # Generate an RFC 1123 formatted timestamp
+ gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
+ date = datetime.now(UTC).strftime(gmt_format)
+
+ # Build the string to sign; the request-line path must match the configured API URL
+ signature_origin = "host: " + host + "\n"
+ signature_origin += "date: " + date + "\n"
+ signature_origin += "GET " + parsed.path + " " + "HTTP/1.1"
+ # Sign it with HMAC-SHA256
+ signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
+ digestmod=hashlib.sha256).digest()
+ signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
+
+ authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
+ self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
+ authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
+ # Collect the authentication parameters into a dict
+ v = {
+ "authorization": authorization,
+ "date": date,
+ "host": host
+ }
+ # Append the authentication parameters as a query string to build the websocket URL
+ url = url + '?' + urlencode(v)
+ return url
+
+ def check_auth(self):
+ # Validate credentials by running one real request against the bundled sample image.
+ cwd = os.path.dirname(os.path.abspath(__file__))
+ with open(f'{cwd}/img_1.png', 'rb') as f:
+ self.image_understand(f, "一句话概述这个图片")
+
+ def image_understand(self, image_file, question):
+ async def handle():
+ async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
+ # Send the full client request
+ await self.send(ws, image_file, question)
+ return await self.handle_message(ws)
+
+ return asyncio.run(handle())
+
+ # Handle incoming websocket messages and accumulate the streamed answer
+ @staticmethod
+ async def handle_message(ws):
+ answer = ''
+ while True:
+ res = await ws.recv()
+ data = json.loads(res)
+ code = data['header']['code']
+ if code != 0:
+ # A non-zero code means the request failed; raise so check_auth / is_valid can report it
+ raise Exception(f'请求错误: {code}, {data}')
+ choices = data["payload"]["choices"]
+ status = choices["status"]
+ content = choices["text"][0]["content"]
+ answer += content
+ # status == 2 marks the final frame of the streamed response
+ if status == 2:
+ break
+ return answer
+
+ async def send(self, ws, image_file, question):
+ text = [
+ {"role": "user", "content": str(base64.b64encode(image_file.read()), 'utf-8'), "content_type": "image"},
+ {"role": "user", "content": question}
+ ]
+
+ data = {
+ "header": {
+ "app_id": self.spark_app_id
+ },
+ "parameter": {
+ "chat": {
+ "domain": "image",
+ "temperature": 0.5,
+ "top_k": 4,
+ "max_tokens": 2028,
+ "auditing": "default"
+ }
+ },
+ "payload": {
+ "message": {
+ "text": text
+ }
+ }
+ }
+
+ d = json.dumps(data)
+ await ws.send(d)
+
+ def is_cache_model(self):
+ return False
+
+ @staticmethod
+ def get_len(text):
+ # Total character length of all message contents
+ return sum(len(content["content"]) for content in text)
+
+ def check_len(self, text):
+ # Drop the oldest non-system messages until the total content length fits the limit
+ while self.get_len(text[1:]) > 8000:
+ del text[1]
+ return text
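+
+# Illustrative usage only (placeholder credentials):
+# model = XFSparkImage.new_instance('IMAGE', 'image', {
+#     'spark_api_url': 'wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image',
+#     'spark_app_id': 'your-app-id', 'spark_api_key': 'xxxx', 'spark_api_secret': 'xxxx'})
+# with open('photo.png', 'rb') as f:
+#     print(model.image_understand(f, '一句话概述这个图片'))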
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/img_1.png b/apps/setting/models_provider/impl/xf_model_provider/model/img_1.png
new file mode 100644
index 00000000000..ccb9d3b2035
Binary files /dev/null and b/apps/setting/models_provider/impl/xf_model_provider/model/img_1.png differ
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py
index f400473ed84..d36bcdb9feb 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py
@@ -10,7 +10,7 @@
import json
import logging
import os
-from datetime import datetime
+from datetime import datetime, UTC
from typing import Dict
from urllib.parse import urlencode, urlparse
import ssl
@@ -63,7 +63,7 @@ def create_url(self):
host = urlparse(url).hostname
# 生成RFC1123格式的时间戳
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
- date = datetime.utcnow().strftime(gmt_format)
+ date = datetime.now(UTC).strftime(gmt_format)
# 拼接字符串
signature_origin = "host: " + host + "\n"
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py
index 3a575ed28b2..1d677b7c32a 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py
@@ -12,7 +12,7 @@
import json
import logging
import os
-from datetime import datetime
+from datetime import datetime, UTC
from typing import Dict
from urllib.parse import urlencode, urlparse
import ssl
@@ -67,7 +67,7 @@ def create_url(self):
host = urlparse(url).hostname
# 生成RFC1123格式的时间戳
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
- date = datetime.utcnow().strftime(gmt_format)
+ date = datetime.now(UTC).strftime(gmt_format)
# 拼接字符串
signature_origin = "host: " + host + "\n"
diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py
index 04fd2d439d4..37bd3e43594 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py
@@ -13,10 +13,12 @@
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential
+from setting.models_provider.impl.xf_model_provider.credential.image import XunFeiImageModelCredential
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
+from setting.models_provider.impl.xf_model_provider.model.image import XFSparkImage
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
@@ -26,6 +28,7 @@
qwen_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
+image_model_credential = XunFeiImageModelCredential()
tts_model_credential = XunFeiTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()
model_info_list = [
@@ -34,6 +37,7 @@
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
+ ModelInfo('image', '', ModelTypeConst.IMAGE, image_model_credential, XFSparkImage),
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
]
diff --git a/ui/src/views/template/index.vue b/ui/src/views/template/index.vue
index 0a33c8ea9b8..20c9a62e542 100644
--- a/ui/src/views/template/index.vue
+++ b/ui/src/views/template/index.vue
@@ -132,6 +132,7 @@
+