diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py
index 0a46cbfa41f..1e587bba246 100644
--- a/apps/setting/models_provider/constants/model_provider_constants.py
+++ b/apps/setting/models_provider/constants/model_provider_constants.py
@@ -16,6 +16,7 @@
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
+from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
class ModelProvideConstants(Enum):
@@ -27,3 +28,4 @@ class ModelProvideConstants(Enum):
model_qwen_provider = QwenModelProvider()
model_zhipu_provider = ZhiPuModelProvider()
model_xf_provider = XunFeiModelProvider()
+ model_deepseek_provider = DeepSeekModelProvider()
diff --git a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py
index 1f04a268a42..5731d7e38d1 100644
--- a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py
+++ b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py
@@ -10,7 +10,6 @@
from typing import Dict
from langchain.schema import HumanMessage
-from langchain_community.chat_models.azure_openai import AzureChatOpenAI
from common import forms
from common.exception.app_exception import AppApiException
@@ -22,7 +21,7 @@
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
from smartdoc.conf import PROJECT_DIR
-
+"""
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
@@ -52,11 +51,12 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
- api_base = forms.TextInputField('API 域名', required=True)
+ api_base = forms.TextInputField('API 域名 (api_base)', required=True)
- api_key = forms.PasswordInputField("API Key", required=True)
+ api_key = forms.PasswordInputField("API Key(API 密钥)", required=True)
- deployment_name = forms.TextInputField("部署名", required=True)
+ deployment_name = forms.TextInputField("部署名(deployment_name)", required=True)
+"""
class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):
@@ -88,28 +88,23 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
- api_version = forms.TextInputField("api_version", required=True)
+ api_version = forms.TextInputField("API 版本 (api_version)", required=True)
- api_base = forms.TextInputField('API 域名', required=True)
+ api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
- api_key = forms.PasswordInputField("API Key", required=True)
+ api_key = forms.PasswordInputField("API Key (api_key)", required=True)
- deployment_name = forms.TextInputField("部署名", required=True)
+ deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
-azure_llm_model_credential = AzureLLMModelCredential()
+# azure_llm_model_credential: AzureLLMModelCredential = AzureLLMModelCredential()
base_azure_llm_model_credential = DefaultAzureLLMModelCredential()
model_dict = {
- 'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
- api_version='2023-07-01-preview'),
- 'gpt-3.5-turbo-0301': ModelInfo('gpt-3.5-turbo-0301', '', ModelTypeConst.LLM, azure_llm_model_credential,
- api_version='2023-07-01-preview'),
- 'gpt-3.5-turbo-16k-0613': ModelInfo('gpt-3.5-turbo-16k-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
- api_version='2023-07-01-preview'),
- 'gpt-4-0613': ModelInfo('gpt-4-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
- api_version='2023-07-01-preview'),
+ 'deployment_name': ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
+ base_azure_llm_model_credential, api_version='2024-02-15-preview'
+ )
}
@@ -118,12 +113,11 @@ class AzureModelProvider(IModelProvider):
def get_dialogue_number(self):
return 3
- def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatOpenAI:
+ def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatModel:
model_info: ModelInfo = model_dict.get(model_name)
azure_chat_open_ai = AzureChatModel(
azure_endpoint=model_credential.get('api_base'),
- openai_api_version=model_info.api_version if model_name in model_dict else model_credential.get(
- 'api_version'),
+ openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
deployment_name=model_credential.get('deployment_name'),
openai_api_key=model_credential.get('api_key'),
openai_api_type="azure"
diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/__init__.py b/apps/setting/models_provider/impl/deepseek_model_provider/__init__.py
new file mode 100644
index 00000000000..ee456da1ffe
--- /dev/null
+++ b/apps/setting/models_provider/impl/deepseek_model_provider/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :MaxKB
+@File :__init__.py
+@Author :Brian Yang
+@Date :5/12/24 7:38 AM
+"""
diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/deepseek_model_provider.py b/apps/setting/models_provider/impl/deepseek_model_provider/deepseek_model_provider.py
new file mode 100644
index 00000000000..3baa5f04ad7
--- /dev/null
+++ b/apps/setting/models_provider/impl/deepseek_model_provider/deepseek_model_provider.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :MaxKB
+@File :deepseek_model_provider.py
+@Author :Brian Yang
+@Date :5/12/24 7:40 AM
+"""
+import os
+from typing import Dict
+
+from langchain.schema import HumanMessage
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
+ ModelInfo, ModelTypeConst, ValidCode
+from setting.models_provider.impl.deepseek_model_provider.model.deepseek_chat_model import DeepSeekChatModel
+from smartdoc.conf import PROJECT_DIR
+
+
+class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
+
+ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
+ model_type_list = DeepSeekModelProvider().get_model_type_list()
+ if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+ raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+ for key in ['api_key']:
+ if key not in model_credential:
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+ else:
+ return False
+ try:
+ model = DeepSeekModelProvider().get_model(model_type, model_name, model_credential)
+ model.invoke([HumanMessage(content='你好')])
+ except Exception as e:
+ if isinstance(e, AppApiException):
+ raise e
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+ else:
+ return False
+ return True
+
+ def encryption_dict(self, model: Dict[str, object]):
+ return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+ api_key = forms.PasswordInputField('API Key', required=True)
+
+
+deepseek_llm_model_credential = DeepSeekLLMModelCredential()
+
+model_dict = {
+ 'deepseek-chat': ModelInfo('deepseek-chat', '擅长通用对话任务,支持 32K 上下文', ModelTypeConst.LLM,
+ deepseek_llm_model_credential,
+ ),
+ 'deepseek-coder': ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM,
+ deepseek_llm_model_credential,
+ ),
+}
+
+
+class DeepSeekModelProvider(IModelProvider):
+
+ def get_dialogue_number(self):
+ return 3
+
+ def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> DeepSeekChatModel:
+ deepseek_chat_open_ai = DeepSeekChatModel(
+ model=model_name,
+ openai_api_base='https://api.deepseek.com',
+ openai_api_key=model_credential.get('api_key')
+ )
+ return deepseek_chat_open_ai
+
+ def get_model_credential(self, model_type, model_name):
+ if model_name in model_dict:
+ return model_dict.get(model_name).model_credential
+ return deepseek_llm_model_credential
+
+ def get_model_provide_info(self):
+ return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'deepseek_model_provider', 'icon',
+ 'deepseek_icon_svg')))
+
+ def get_model_list(self, model_type: str):
+ if model_type is None:
+ raise AppApiException(500, '模型类型不能为空')
+ return [model_dict.get(key).to_dict() for key in
+ list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
+
+ def get_model_type_list(self):
+ return [{'key': "大语言模型", 'value': "LLM"}]
diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/icon/deepseek_icon_svg b/apps/setting/models_provider/impl/deepseek_model_provider/icon/deepseek_icon_svg
new file mode 100644
index 00000000000..6ace8911a62
--- /dev/null
+++ b/apps/setting/models_provider/impl/deepseek_model_provider/icon/deepseek_icon_svg
@@ -0,0 +1,6 @@
+
\ No newline at end of file
diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/model/deepseek_chat_model.py b/apps/setting/models_provider/impl/deepseek_model_provider/model/deepseek_chat_model.py
new file mode 100644
index 00000000000..b7a54b302d9
--- /dev/null
+++ b/apps/setting/models_provider/impl/deepseek_model_provider/model/deepseek_chat_model.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :MaxKB
+@File :deepseek_chat_model.py
+@Author :Brian Yang
+@Date :5/12/24 7:44 AM
+"""
+from typing import List
+
+from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_openai import ChatOpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class DeepSeekChatModel(ChatOpenAI):
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+ try:
+ return super().get_num_tokens_from_messages(messages)
+ except Exception as e:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+ def get_num_tokens(self, text: str) -> int:
+ try:
+ return super().get_num_tokens(text)
+ except Exception as e:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return len(tokenizer.encode(text))
diff --git a/ui/src/views/template/index.vue b/ui/src/views/template/index.vue
index eb8f6c51300..00ca3d8ffd5 100644
--- a/ui/src/views/template/index.vue
+++ b/ui/src/views/template/index.vue
@@ -70,6 +70,7 @@
ref="createModelRef"
@submit="list_model"
@change="openCreateModel($event)"
+ :key="dialogState.createModelDialogKey"
>
import { ElMessage } from 'element-plus'
-import { onMounted, ref, computed, watch } from 'vue'
+import { onMounted, ref, computed, reactive } from 'vue'
import ModelApi from '@/api/model'
import type { Provider, Model } from '@/api/type/model'
import AppIcon from '@/components/icons/AppIcon.vue'
@@ -128,6 +129,7 @@ const openCreateModel = (provider?: Provider) => {
createModelRef.value?.open(provider)
} else {
selectProviderRef.value?.open()
+ refreshCreateModelDialogKey() // 更新key
}
}
@@ -138,6 +140,16 @@ const list_model = () => {
})
}
+// 添加一个响应式的state来存储dialog的key
+const dialogState = reactive({
+ createModelDialogKey: Date.now() // 初始值为当前的时间戳
+})
+
+// 更新dialogState.createModelDialogKey的函数
+const refreshCreateModelDialogKey = () => {
+ dialogState.createModelDialogKey = Date.now() // 更新为新的时间戳
+}
+
onMounted(() => {
ModelApi.getProvider(loading).then((ok) => {
active_provider.value = allObj
@@ -154,6 +166,7 @@ onMounted(() => {
width: var(--setting-left-width);
min-width: var(--setting-left-width);
}
+
.model-list-height {
height: calc(var(--create-dataset-height) - 70px);
}