From 509b064f48d59f60ea1d151c43da33badc85b39d Mon Sep 17 00:00:00 2001 From: arvinxx Date: Wed, 31 Jan 2024 00:14:56 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20wip:=20wip=20for=20model=20selec?= =?UTF-8?q?t?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 6 +++ src/app/api/chat/[provider]/route.ts | 2 +- src/app/api/config/route.ts | 3 +- .../chat/(desktop)/features/ChatHeader.tsx | 34 +++++++++----- .../ChatHeader/ShareButton/Preview.tsx | 31 +++++++----- .../SessionListContent/List/Item/index.tsx | 47 ++++++++++++------- src/app/settings/llm/Zhipu/index.tsx | 15 ++---- src/components/ModelProviderIcons/index.ts | 5 -- src/components/ModelProviderIcons/index.tsx | 47 +++++++++++++++++++ src/components/ModelTag/index.tsx | 44 ++--------------- src/config/modelProviders.ts | 28 ++--------- src/config/server.ts | 2 + src/const/fetch.ts | 4 +- src/const/settings.ts | 5 +- src/database/models/message.ts | 4 +- src/database/schemas/message.ts | 1 + .../ChatInput/ActionBar/ModelSwitch.tsx | 30 ++++++++---- .../Conversation/Error/ApiKeyForm.tsx | 2 +- .../Conversation/Extras/Assistant.tsx | 2 +- src/libs/agent-runtime/utils/debugStream.ts | 18 +++++++ src/libs/agent-runtime/zhipu/chat.ts | 12 ++++- src/libs/agent-runtime/zhipu/createZhipu.ts | 6 +-- src/store/chat/slices/message/action.ts | 11 +++-- src/store/global/slices/common/action.ts | 10 +++- .../settings/selectors/modelProvider.ts | 33 +++++++++---- src/store/session/slices/agent/selectors.ts | 7 +++ src/types/agent/index.ts | 8 +++- src/types/message/index.ts | 1 + src/types/openai/chat.ts | 3 +- 29 files changed, 259 insertions(+), 162 deletions(-) delete mode 100644 src/components/ModelProviderIcons/index.ts create mode 100644 src/components/ModelProviderIcons/index.tsx create mode 100644 src/libs/agent-runtime/utils/debugStream.ts diff --git a/.env.example b/.env.example index f41f96004f66f..92be99cf7198d 100644 --- a/.env.example +++ b/.env.example @@ -33,6 +33,12 @@ OPENAI_API_KEY=sk-xxxxxxxxx # Azure's API version, follows the YYYY-MM-DD format # AZURE_API_VERSION=2023-08-01-preview +######################################## +############ ZhiPu AI Service ########## +######################################## + +# ZHIPU_API_KEY=xxxxxxxxxxxxxxxxxxx.xxxxxxxxxxxxx + ######################################## ############ Market Service ############ ######################################## diff --git a/src/app/api/chat/[provider]/route.ts b/src/app/api/chat/[provider]/route.ts index 9b41afc519272..2e5dc16545a0d 100644 --- a/src/app/api/chat/[provider]/route.ts +++ b/src/app/api/chat/[provider]/route.ts @@ -1,5 +1,5 @@ import { getPreferredRegion } from '@/app/api/config'; -import { createErrorResponse } from '@/app/api/openai/errorResponse'; +import { createErrorResponse } from '@/app/api/errorResponse'; import { CompletionError, LobeOpenAI, diff --git a/src/app/api/config/route.ts b/src/app/api/config/route.ts index c19a9c9686670..6a3c1466d3f46 100644 --- a/src/app/api/config/route.ts +++ b/src/app/api/config/route.ts @@ -7,10 +7,11 @@ export const runtime = 'edge'; * get Server config to client */ export const GET = async () => { - const { CUSTOM_MODELS } = getServerConfig(); + const { CUSTOM_MODELS, ENABLED_ZHIPU } = getServerConfig(); const config: GlobalServerConfig = { customModelName: CUSTOM_MODELS, + languageModel: { zhipu: { enabled: ENABLED_ZHIPU } }, }; return new Response(JSON.stringify(config)); diff --git a/src/app/chat/(desktop)/features/ChatHeader.tsx b/src/app/chat/(desktop)/features/ChatHeader.tsx index 1d78e08c8838d..3c68a0b113181 100644 --- a/src/app/chat/(desktop)/features/ChatHeader.tsx +++ b/src/app/chat/(desktop)/features/ChatHeader.tsx @@ -22,17 +22,27 @@ const Left = memo(() => { const router = useRouter(); - const [init, isInbox, title, description, avatar, backgroundColor, model, plugins] = - useSessionStore((s) => [ - sessionSelectors.isSomeSessionActive(s), - sessionSelectors.isInboxSession(s), - agentSelectors.currentAgentTitle(s), - agentSelectors.currentAgentDescription(s), - agentSelectors.currentAgentAvatar(s), - agentSelectors.currentAgentBackgroundColor(s), - agentSelectors.currentAgentModel(s), - agentSelectors.currentAgentPlugins(s), - ]); + const [ + init, + isInbox, + title, + description, + avatar, + backgroundColor, + model, + modelProvider, + plugins, + ] = useSessionStore((s) => [ + sessionSelectors.isSomeSessionActive(s), + sessionSelectors.isInboxSession(s), + agentSelectors.currentAgentTitle(s), + agentSelectors.currentAgentDescription(s), + agentSelectors.currentAgentAvatar(s), + agentSelectors.currentAgentBackgroundColor(s), + agentSelectors.currentAgentModel(s), + agentSelectors.currentAgentModelProvider(s), + agentSelectors.currentAgentPlugins(s), + ]); const displayTitle = isInbox ? t('inbox.title') : title; const displayDesc = isInbox ? t('inbox.desc') : description; @@ -63,7 +73,7 @@ const Left = memo(() => { desc={displayDesc} tag={ <> - + {plugins?.length > 0 && } } diff --git a/src/app/chat/features/ChatHeader/ShareButton/Preview.tsx b/src/app/chat/features/ChatHeader/ShareButton/Preview.tsx index 27076237adc17..8d623c9f2767d 100644 --- a/src/app/chat/features/ChatHeader/ShareButton/Preview.tsx +++ b/src/app/chat/features/ChatHeader/ShareButton/Preview.tsx @@ -15,16 +15,25 @@ import { FieldType } from './type'; const Preview = memo( ({ title, withSystemRole, withBackground, withFooter }) => { - const [isInbox, description, avatar, backgroundColor, model, plugins, systemRole] = - useSessionStore((s) => [ - sessionSelectors.isInboxSession(s), - agentSelectors.currentAgentDescription(s), - agentSelectors.currentAgentAvatar(s), - agentSelectors.currentAgentBackgroundColor(s), - agentSelectors.currentAgentModel(s), - agentSelectors.currentAgentPlugins(s), - agentSelectors.currentAgentSystemRole(s), - ]); + const [ + isInbox, + description, + avatar, + backgroundColor, + model, + modelProvider, + plugins, + systemRole, + ] = useSessionStore((s) => [ + sessionSelectors.isInboxSession(s), + agentSelectors.currentAgentDescription(s), + agentSelectors.currentAgentAvatar(s), + agentSelectors.currentAgentBackgroundColor(s), + agentSelectors.currentAgentModel(s), + agentSelectors.currentAgentModelProvider(s), + agentSelectors.currentAgentPlugins(s), + agentSelectors.currentAgentSystemRole(s), + ]); const { t } = useTranslation('chat'); const { styles } = useStyles(withBackground); @@ -42,7 +51,7 @@ const Preview = memo( desc={displayDesc} tag={ <> - + {plugins?.length > 0 && } } diff --git a/src/app/chat/features/SessionListContent/List/Item/index.tsx b/src/app/chat/features/SessionListContent/List/Item/index.tsx index 9151f288f8726..bf2e3e61b9b75 100644 --- a/src/app/chat/features/SessionListContent/List/Item/index.tsx +++ b/src/app/chat/features/SessionListContent/List/Item/index.tsx @@ -26,24 +26,35 @@ const SessionItem = memo(({ id }) => { const [active] = useSessionStore((s) => [s.activeId === id]); const [loading] = useChatStore((s) => [!!s.chatLoadingId && id === s.activeId]); - const [pin, title, description, systemRole, avatar, avatarBackground, updateAt, model, group] = - useSessionStore((s) => { - const session = sessionSelectors.getSessionById(id)(s); - const meta = session.meta; - const systemRole = session.config.systemRole; + const [ + pin, + title, + description, + systemRole, + avatar, + avatarBackground, + updateAt, + model, + provider, + group, + ] = useSessionStore((s) => { + const session = sessionSelectors.getSessionById(id)(s); + const meta = session.meta; + const systemRole = session.config.systemRole; - return [ - sessionHelpers.getSessionPinned(session), - agentSelectors.getTitle(meta), - agentSelectors.getDescription(meta), - systemRole, - agentSelectors.getAvatar(meta), - meta.backgroundColor, - session?.updatedAt, - session.config.model, - session?.group, - ]; - }); + return [ + sessionHelpers.getSessionPinned(session), + agentSelectors.getTitle(meta), + agentSelectors.getDescription(meta), + systemRole, + agentSelectors.getAvatar(meta), + meta.backgroundColor, + session?.updatedAt, + session.config.model, + session.config.provider, + session?.group, + ]; + }); const showModel = model !== defaultModel; @@ -63,7 +74,7 @@ const SessionItem = memo(({ id }) => { () => !showModel ? undefined : ( - {showModel && } + {showModel && } ), [showModel, model], diff --git a/src/app/settings/llm/Zhipu/index.tsx b/src/app/settings/llm/Zhipu/index.tsx index 286d07b768282..9abc673ade679 100644 --- a/src/app/settings/llm/Zhipu/index.tsx +++ b/src/app/settings/llm/Zhipu/index.tsx @@ -18,7 +18,7 @@ const LLM = memo(() => { const [form] = AntForm.useForm(); const [enabledZhipu, setSettings] = useGlobalStore((s) => [ - modelProviderSelectors.enabledZhipu(s), + modelProviderSelectors.enableZhipu(s), s.setSettings, ]); @@ -28,7 +28,7 @@ const LLM = memo(() => { form.setFieldsValue(settings); }, []); - const openAI: ItemGroup = { + const model: ItemGroup = { children: [ { children: ( @@ -51,8 +51,8 @@ const LLM = memo(() => { defaultActive: enabledZhipu, extra: ( { - console.log(e); + onChange={(enabled) => { + setSettings({ languageModel: { zhipu: { enabled } } }); }} value={enabledZhipu} /> @@ -62,12 +62,7 @@ const LLM = memo(() => { }; return ( -
+ ); }); diff --git a/src/components/ModelProviderIcons/index.ts b/src/components/ModelProviderIcons/index.ts deleted file mode 100644 index 9e074b2d0625c..0000000000000 --- a/src/components/ModelProviderIcons/index.ts +++ /dev/null @@ -1,5 +0,0 @@ -export { Anthropic } from './Anthropic'; -export { ChatGLM } from './ChatGLM'; -export { GoogleDeepMind } from './GoogleDeepMind'; -export { Mistral } from './Mistral'; -export { Tongyi } from './Tongyi'; diff --git a/src/components/ModelProviderIcons/index.tsx b/src/components/ModelProviderIcons/index.tsx new file mode 100644 index 0000000000000..1c1c09621375c --- /dev/null +++ b/src/components/ModelProviderIcons/index.tsx @@ -0,0 +1,47 @@ +import { SiAmazonaws, SiOpenai } from '@icons-pack/react-simple-icons'; +import { memo } from 'react'; + +import { ModelProvider } from '@/libs/agent-runtime'; + +import { Anthropic } from './Anthropic'; +import { ChatGLM } from './ChatGLM'; +import { GoogleDeepMind } from './GoogleDeepMind'; +import { Mistral } from './Mistral'; +import { Tongyi } from './Tongyi'; + +interface ModelProviderIconProps { + provider?: string; +} + +const ModelProviderIcon = memo(({ provider }) => { + switch (provider) { + case ModelProvider.Anthropic: { + return ; + } + case ModelProvider.Tongyi: { + return ; + } + case ModelProvider.Mistral: { + return ; + } + case ModelProvider.Bedrock: { + return ; + } + + case 'zhipu': + case ModelProvider.ChatGLM: { + return ; + } + + case ModelProvider.Google: { + return ; + } + + default: + case ModelProvider.OpenAI: { + return ; + } + } +}); + +export default ModelProviderIcon; diff --git a/src/components/ModelTag/index.tsx b/src/components/ModelTag/index.tsx index 04ba72365c969..6e6baa9f478f8 100644 --- a/src/components/ModelTag/index.tsx +++ b/src/components/ModelTag/index.tsx @@ -1,50 +1,14 @@ -import { SiAmazonaws, SiOpenai } from '@icons-pack/react-simple-icons'; import { Tag } from '@lobehub/ui'; -import { memo, useMemo } from 'react'; +import { memo } from 'react'; -import { - Anthropic, - ChatGLM, - GoogleDeepMind, - Mistral, - Tongyi, -} from '@/components/ModelProviderIcons'; -import { ModelProvider } from '@/libs/agent-runtime'; +import ModelProviderIcon from '@/components/ModelProviderIcons'; interface ModelTagProps { name: string; - provider?: ModelProvider; + provider?: string; } const ModelTag = memo(({ provider, name }) => { - const icon = useMemo(() => { - switch (provider) { - case ModelProvider.Anthropic: { - return ; - } - case ModelProvider.Tongyi: { - return ; - } - case ModelProvider.Mistral: { - return ; - } - case ModelProvider.Bedrock: { - return ; - } - case ModelProvider.ChatGLM: { - return ; - } - case ModelProvider.Google: { - return ; - } - - default: - case ModelProvider.OpenAI: { - return ; - } - } - }, [provider]); - - return {name}; + return }>{name}; }); export default ModelTag; diff --git a/src/config/modelProviders.ts b/src/config/modelProviders.ts index cfbbe326b7fa5..c63d9851dd787 100644 --- a/src/config/modelProviders.ts +++ b/src/config/modelProviders.ts @@ -1,24 +1,6 @@ -interface ChatModelCard { - description?: string; - displayName?: string; - /** - * 是否支持 Function Call - */ - functionCall?: boolean; - id: string; - tokens?: number; - /** - * 是否支持视觉识别 - */ - vision?: boolean; -} +import { ModelProviderCard } from '@/types/llm'; -interface ModelProvider { - chatModels: ChatModelCard[]; - name: string; -} - -export const ZhiPuModelCard: ModelProvider = { +export const ZhiPuModelCard: ModelProviderCard = { chatModels: [ { description: '最新的 GLM-4 、最大支持 128k 上下文、支持 Function Call 、Retreival', @@ -41,10 +23,10 @@ export const ZhiPuModelCard: ModelProvider = { tokens: 128_000, }, ], - name: 'zhipu', + id: 'zhipu', }; -export const OpenAIModelCard: ModelProvider = { +export const OpenAIModelCard: ModelProviderCard = { chatModels: [ { description: 'GPT 3.5 Turbo,适用于各种文本生成和理解任务', @@ -86,5 +68,5 @@ export const OpenAIModelCard: ModelProvider = { vision: true, // 支持视觉任务 }, ], - name: 'openai', + id: 'openai', }; diff --git a/src/config/server.ts b/src/config/server.ts index e56a8f12f50b7..13e18ad36c9f9 100644 --- a/src/config/server.ts +++ b/src/config/server.ts @@ -89,5 +89,7 @@ export const getServerConfig = () => { : 'https://chat-plugins.lobehub.com', PLUGIN_SETTINGS: process.env.PLUGIN_SETTINGS, + + DEBUG_CHAT_COMPLETION: process.env.DEBUG_CHAT_COMPLETION === '1', }; }; diff --git a/src/const/fetch.ts b/src/const/fetch.ts index cc40ad0f62c88..2964b97d880f6 100644 --- a/src/const/fetch.ts +++ b/src/const/fetch.ts @@ -8,7 +8,6 @@ export const AZURE_OPENAI_API_VERSION = 'X-azure-openai-api-version'; export const LOBE_CHAT_ACCESS_CODE = 'X-lobe-chat-access-code'; export const ZHIPU_API_KEY_HEADER_KEY = 'X-zhipu-api-key'; -export const ZHIPU_PROXY_URL_HEADER_KEY = 'X-zhipu-proxy-url'; export const getOpenAIAuthFromRequest = (req: Request) => { const apiKey = req.headers.get(OPENAI_API_KEY_HEADER_KEY); @@ -17,9 +16,8 @@ export const getOpenAIAuthFromRequest = (req: Request) => { const useAzureStr = req.headers.get(USE_AZURE_OPENAI); const apiVersion = req.headers.get(AZURE_OPENAI_API_VERSION); const zhipuApiKey = req.headers.get(ZHIPU_API_KEY_HEADER_KEY); - const zhipuProxyUrl = req.headers.get(ZHIPU_PROXY_URL_HEADER_KEY); const useAzure = !!useAzureStr; - return { accessCode, apiKey, apiVersion, endpoint, useAzure, zhipuApiKey, zhipuProxyUrl }; + return { accessCode, apiKey, apiVersion, endpoint, useAzure, zhipuApiKey }; }; diff --git a/src/const/settings.ts b/src/const/settings.ts index 45f00222f44b2..1c4ae753af8b8 100644 --- a/src/const/settings.ts +++ b/src/const/settings.ts @@ -1,7 +1,7 @@ import { DEFAULT_OPENAI_MODEL_LIST } from '@/const/llm'; import { DEFAULT_AGENT_META } from '@/const/meta'; +import { ModelProvider } from '@/libs/agent-runtime'; import { LobeAgentConfig, LobeAgentTTSConfig } from '@/types/agent'; -import { LanguageModel } from '@/types/llm'; import { GlobalBaseSettings, GlobalDefaultAgent, @@ -33,7 +33,7 @@ export const DEFAULT_AGENT_CONFIG: LobeAgentConfig = { displayMode: 'chat', enableAutoCreateTopic: true, historyCount: 1, - model: LanguageModel.GPT3_5, + model: 'gpt-3.5-turbo', params: { frequency_penalty: 0, presence_penalty: 0, @@ -41,6 +41,7 @@ export const DEFAULT_AGENT_CONFIG: LobeAgentConfig = { top_p: 1, }, plugins: [], + provider: ModelProvider.OpenAI, systemRole: '', tts: DEFAUTT_AGENT_TTS_CONFIG, }; diff --git a/src/database/models/message.ts b/src/database/models/message.ts index e88121e1d5556..a7ec2391bd8de 100644 --- a/src/database/models/message.ts +++ b/src/database/models/message.ts @@ -10,6 +10,7 @@ export interface CreateMessageParams extends Partial>, Pick { fromModel?: string; + fromProvider?: string; sessionId: string; } @@ -215,13 +216,14 @@ class _MessageModel extends BaseModel { private mapToChatMessage = ({ fromModel, + fromProvider, translate, tts, ...item }: DBModel): ChatMessage => { return { ...item, - extra: { fromModel, translate, tts }, + extra: { fromModel, fromProvider, translate, tts }, meta: {}, topicId: item.topicId ?? undefined, }; diff --git a/src/database/schemas/message.ts b/src/database/schemas/message.ts index 13b0fddd76460..5f00ec512a14d 100644 --- a/src/database/schemas/message.ts +++ b/src/database/schemas/message.ts @@ -24,6 +24,7 @@ export const DB_MessageSchema = z.object({ plugin: PluginSchema.optional(), pluginState: z.any().optional(), fromModel: z.string().optional(), + fromProvider: z.string().optional(), translate: TranslateSchema.optional().or(z.literal(false)), tts: z.any().optional(), diff --git a/src/features/ChatInput/ActionBar/ModelSwitch.tsx b/src/features/ChatInput/ActionBar/ModelSwitch.tsx index 275c87fbd207c..ddd8c489eff63 100644 --- a/src/features/ChatInput/ActionBar/ModelSwitch.tsx +++ b/src/features/ChatInput/ActionBar/ModelSwitch.tsx @@ -4,12 +4,13 @@ import isEqual from 'fast-deep-equal'; import { BrainCog } from 'lucide-react'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; +import { Flexbox } from 'react-layout-kit'; +import ModelProviderIcon from '@/components/ModelProviderIcons'; import { useGlobalStore } from '@/store/global'; import { modelProviderSelectors } from '@/store/global/selectors'; import { useSessionStore } from '@/store/session'; import { agentSelectors } from '@/store/session/selectors'; -import { LanguageModel } from '@/types/llm'; const ModelSwitch = memo(() => { const { t } = useTranslation('setting'); @@ -18,19 +19,32 @@ const ModelSwitch = memo(() => { return [agentSelectors.currentAgentModel(s), s.updateAgentConfig]; }); - const modelList = useGlobalStore(modelProviderSelectors.modelList, isEqual); + const select = useGlobalStore(modelProviderSelectors.modelSelectList, isEqual); return ( ({ key: name, label: displayName })), - onClick: (e) => { - updateAgentConfig({ model: e.key as LanguageModel }); - }, + items: select.map((provider) => ({ + children: provider.chatModels.map((model) => ({ + key: model.id, + label: model.id, + onClick: () => { + updateAgentConfig({ model: model.id, provider: provider?.id }); + }, + })), + key: provider.id, + label: ( + + + {provider.id} + + ), + type: 'group', + })), style: { - maxHeight: 400, - overflow: 'scroll', + maxHeight: 700, + overflowY: 'scroll', }, }} trigger={['click']} diff --git a/src/features/Conversation/Error/ApiKeyForm.tsx b/src/features/Conversation/Error/ApiKeyForm.tsx index 2c4aef6622420..101864d9de208 100644 --- a/src/features/Conversation/Error/ApiKeyForm.tsx +++ b/src/features/Conversation/Error/ApiKeyForm.tsx @@ -16,7 +16,7 @@ const APIKeyForm = memo<{ id: string }>(({ id }) => { const [showProxy, setShow] = useState(false); const [apiKey, proxyUrl, setConfig] = useGlobalStore((s) => [ - modelProviderSelectors.openAIAPI(s), + modelProviderSelectors.openAIAPIKey(s), modelProviderSelectors.openAIProxyUrl(s), s.setOpenAIConfig, ]); diff --git a/src/features/Conversation/Extras/Assistant.tsx b/src/features/Conversation/Extras/Assistant.tsx index fe5009fb66dd2..4f67c79953db8 100644 --- a/src/features/Conversation/Extras/Assistant.tsx +++ b/src/features/Conversation/Extras/Assistant.tsx @@ -29,7 +29,7 @@ export const AssistantMessageExtra: RenderMessageExtra = memo( {showModelTag && (
- +
)} <> diff --git a/src/libs/agent-runtime/utils/debugStream.ts b/src/libs/agent-runtime/utils/debugStream.ts new file mode 100644 index 0000000000000..43e9fbe6b2c44 --- /dev/null +++ b/src/libs/agent-runtime/utils/debugStream.ts @@ -0,0 +1,18 @@ +export const debugStream = async (stream: ReadableStream) => { + let done = false; + let chunk = 0; + const decoder = new TextDecoder(); + + const reader = stream.getReader(); + while (!done) { + const { value, done: _done } = await reader.read(); + const chunkValue = decoder.decode(value, { stream: true }); + if (!_done) { + console.log(`chunk ${chunk}:`); + console.log(chunkValue); + } + + done = _done; + chunk++; + } +}; diff --git a/src/libs/agent-runtime/zhipu/chat.ts b/src/libs/agent-runtime/zhipu/chat.ts index 585ec0a58a188..a7dc6873176f2 100644 --- a/src/libs/agent-runtime/zhipu/chat.ts +++ b/src/libs/agent-runtime/zhipu/chat.ts @@ -1,17 +1,20 @@ import { OpenAIStream, StreamingTextResponse } from 'ai'; +import { consola } from 'consola'; import OpenAI from 'openai'; +import { getServerConfig } from '@/config/server'; import { ChatErrorType } from '@/types/fetch'; import { CreateChatCompletionOptions, ModelProvider } from '../type'; +import { debugStream } from '../utils/debugStream'; +const { DEBUG_CHAT_COMPLETION } = getServerConfig(); export const createChatCompletion = async ({ payload, chatModel }: CreateChatCompletionOptions) => { // ============ 1. preprocess messages ============ // const { messages, top_p, ...params } = payload; // ============ 2. send api ============ // - console.log(top_p) try { const response = await chatModel.chat.completions.create( { @@ -23,8 +26,13 @@ export const createChatCompletion = async ({ payload, chatModel }: CreateChatCom } as unknown as OpenAI.ChatCompletionCreateParamsStreaming, { headers: { Accept: '*/*' } }, ); + const [debugResponseClone, returnResponse] = response.tee(); - const stream = OpenAIStream(response); + if (DEBUG_CHAT_COMPLETION) { + debugStream(debugResponseClone.toReadableStream()).catch(consola.error); + } + + const stream = OpenAIStream(returnResponse); return new StreamingTextResponse(stream); } catch (error) { let errorType: any = ChatErrorType.OpenAIBizError; diff --git a/src/libs/agent-runtime/zhipu/createZhipu.ts b/src/libs/agent-runtime/zhipu/createZhipu.ts index 134a71169711e..ddf219a9148dc 100644 --- a/src/libs/agent-runtime/zhipu/createZhipu.ts +++ b/src/libs/agent-runtime/zhipu/createZhipu.ts @@ -12,7 +12,7 @@ import { generateApiToken } from './authToken'; * if auth not pass ,just throw an error of {type: } */ export const createZhipu = async (req: Request): Promise => { - const { accessCode, zhipuApiKey, zhipuProxyUrl } = getOpenAIAuthFromRequest(req); + const { accessCode, zhipuApiKey } = getOpenAIAuthFromRequest(req); const result = checkAuth({ accessCode, apiKey: zhipuApiKey }); @@ -20,9 +20,9 @@ export const createZhipu = async (req: Request): Promise => { throw new TypeError(JSON.stringify({ type: result.error })); } - const { ZHIPU_API_KEY, ZHIPU_PROXY_URL } = getServerConfig(); + const { ZHIPU_API_KEY } = getServerConfig(); - const baseURL = zhipuProxyUrl || ZHIPU_PROXY_URL || 'https://open.bigmodel.cn/api/paas/v4'; + const baseURL = 'https://open.bigmodel.cn/api/paas/v4'; const apiKey = !zhipuApiKey ? ZHIPU_API_KEY : zhipuApiKey; diff --git a/src/store/chat/slices/message/action.ts b/src/store/chat/slices/message/action.ts index fb70891196018..75a25c25055c7 100644 --- a/src/store/chat/slices/message/action.ts +++ b/src/store/chat/slices/message/action.ts @@ -255,13 +255,14 @@ export const chatMessage: StateCreator< coreProcessMessage: async (messages, userMessageId) => { const { fetchAIChatMessage, triggerFunctionCall, refreshMessages, activeTopicId } = get(); - const { model } = getAgentConfig(); + const { model, provider } = getAgentConfig(); // 1. Add an empty message to place the AI response const assistantMessage: CreateMessageParams = { role: 'assistant', content: LOADING_FLAT, fromModel: model, + fromProvider: provider, parentId: userMessageId, sessionId: get().activeId, @@ -288,13 +289,14 @@ export const chatMessage: StateCreator< const functionMessage: CreateMessageParams = { role: 'function', content: functionCallContent, - extra: { - fromModel: model, - }, + fromModel: model, + fromProvider: provider, + parentId: userMessageId, sessionId: get().activeId, topicId: activeTopicId, }; + functionId = await messageService.create(functionMessage); } @@ -376,6 +378,7 @@ export const chatMessage: StateCreator< { messages: preprocessMsgs, model: config.model, + provider: config.provider, ...config.params, plugins: config.plugins, }, diff --git a/src/store/global/slices/common/action.ts b/src/store/global/slices/common/action.ts index e5d7d58e983e3..0e50e4c807b92 100644 --- a/src/store/global/slices/common/action.ts +++ b/src/store/global/slices/common/action.ts @@ -1,5 +1,6 @@ import { gt } from 'semver'; import useSWR, { SWRResponse, mutate } from 'swr'; +import { DeepPartial } from 'utility-types'; import type { StateCreator } from 'zustand/vanilla'; import { INBOX_SESSION_ID } from '@/const/session'; @@ -8,7 +9,7 @@ import { CURRENT_VERSION } from '@/const/version'; import { globalService } from '@/services/global'; import { UserConfig, userService } from '@/services/user'; import type { GlobalStore } from '@/store/global'; -import type { GlobalServerConfig } from '@/types/settings'; +import type { GlobalServerConfig, GlobalSettings } from '@/types/settings'; import { merge } from '@/utils/merge'; import { setNamespace } from '@/utils/storeDebug'; @@ -63,7 +64,12 @@ export const createCommonSlice: StateCreator< useSWR('fetchGlobalConfig', globalService.getGlobalConfig, { onSuccess: (data) => { if (data) { - const defaultSettings = merge(get().defaultSettings, { defaultAgent: data.defaultAgent }); + const serverSettings: DeepPartial = { + defaultAgent: data.defaultAgent, + languageModel: data.languageModel, + }; + + const defaultSettings = merge(get().defaultSettings, serverSettings); set({ defaultSettings, serverConfig: data }, false, n('initGlobalConfig')); } }, diff --git a/src/store/global/slices/settings/selectors/modelProvider.ts b/src/store/global/slices/settings/selectors/modelProvider.ts index 6db844761b966..1e5cd49923337 100644 --- a/src/store/global/slices/settings/selectors/modelProvider.ts +++ b/src/store/global/slices/settings/selectors/modelProvider.ts @@ -1,4 +1,6 @@ +import { OpenAIModelCard, ZhiPuModelCard } from '@/config/modelProviders'; import { DEFAULT_OPENAI_MODEL_LIST } from '@/const/llm'; +import { ModelProviderCard } from '@/types/llm'; import { CustomModels } from '@/types/settings'; import { GlobalStore } from '../../../store'; @@ -6,15 +8,14 @@ import { currentSettings } from './settings'; const openAIConfig = (s: GlobalStore) => currentSettings(s).languageModel.openAI; -const openAIAPIKeySelectors = (s: GlobalStore) => - currentSettings(s).languageModel.openAI.OPENAI_API_KEY; +const openAIAPIKey = (s: GlobalStore) => openAIConfig(s).OPENAI_API_KEY; +const enableAzure = (s: GlobalStore) => openAIConfig(s).useAzure; +const openAIProxyUrl = (s: GlobalStore) => openAIConfig(s).endpoint; +const zhipuAPIKey = (s: GlobalStore) => currentSettings(s).languageModel.zhipu.ZHIPU_API_KEY; -const enableAzure = (s: GlobalStore) => currentSettings(s).languageModel.openAI.useAzure; +const enableZhipu = (s: GlobalStore) => currentSettings(s).languageModel.zhipu.enabled; -const openAIProxyUrlSelectors = (s: GlobalStore) => - currentSettings(s).languageModel.openAI.endpoint; - -const modelListSelectors = (s: GlobalStore) => { +const customModelList = (s: GlobalStore) => { let models: CustomModels = []; const removedModels: string[] = []; @@ -55,10 +56,22 @@ const modelListSelectors = (s: GlobalStore) => { return models.filter((m) => !removedModels.includes(m.name)); }; +const modelSelectList = (s: GlobalStore): ModelProviderCard[] => { + customModelList(s); + + return [OpenAIModelCard, ZhiPuModelCard]; +}; + +/* eslint-disable sort-keys-fix/sort-keys-fix, */ export const modelProviderSelectors = { + modelList: customModelList, + modelSelectList, + // OpenAI enableAzure, - modelList: modelListSelectors, - openAIAPI: openAIAPIKeySelectors, + openAIAPIKey, openAIConfig, - openAIProxyUrl: openAIProxyUrlSelectors, + openAIProxyUrl, + // Zhipu + enableZhipu, + zhipuAPIKey, }; diff --git a/src/store/session/slices/agent/selectors.ts b/src/store/session/slices/agent/selectors.ts index 811033447d772..6b457de6513be 100644 --- a/src/store/session/slices/agent/selectors.ts +++ b/src/store/session/slices/agent/selectors.ts @@ -36,6 +36,12 @@ const currentAgentModel = (s: SessionStore): LanguageModel | string => { return config?.model || LanguageModel.GPT3_5; }; +const currentAgentModelProvider = (s: SessionStore) => { + const config = currentAgentConfig(s); + + return config?.provider; +}; + const currentAgentPlugins = (s: SessionStore) => { const config = currentAgentConfig(s); @@ -121,6 +127,7 @@ export const agentSelectors = { currentAgentDescription, currentAgentMeta, currentAgentModel, + currentAgentModelProvider, currentAgentPlugins, currentAgentSystemRole, currentAgentTTS, diff --git a/src/types/agent/index.ts b/src/types/agent/index.ts index 6cb46fddff41a..7a20b2ce7f63b 100644 --- a/src/types/agent/index.ts +++ b/src/types/agent/index.ts @@ -1,4 +1,4 @@ -import { FewShots, LLMParams, LanguageModel } from '@/types/llm'; +import { FewShots, LLMParams } from '@/types/llm'; export type TTSServer = 'openai' | 'edge' | 'microsoft'; @@ -37,7 +37,7 @@ export interface LobeAgentConfig { * 角色所使用的语言模型 * @default gpt-3.5-turbo */ - model: LanguageModel | string; + model: string; /** * 语言模型参数 */ @@ -46,6 +46,10 @@ export interface LobeAgentConfig { * 启用的插件 */ plugins?: string[]; + /** + * 模型供应商 + */ + provider?: string; /** * 系统角色 */ diff --git a/src/types/message/index.ts b/src/types/message/index.ts index f69ea6aead4dd..c8f1d7da03a64 100644 --- a/src/types/message/index.ts +++ b/src/types/message/index.ts @@ -32,6 +32,7 @@ export interface ChatMessage extends BaseDataModel { // 扩展字段 extra?: { fromModel?: string; + fromProvider?: string; // 翻译 translate?: ChatTranslate | false; // TTS diff --git a/src/types/openai/chat.ts b/src/types/openai/chat.ts index b8e34adeffbd3..e56d9ad48c0d6 100644 --- a/src/types/openai/chat.ts +++ b/src/types/openai/chat.ts @@ -1,4 +1,3 @@ -import { ModelProvider } from '@/libs/agent-runtime'; import { LLMRoleType } from '@/types/llm'; import { OpenAIFunctionCall } from './functionCall'; @@ -70,7 +69,7 @@ export interface ChatStreamPayload { /** * @default openai */ - provider?: ModelProvider; + provider?: string; /** * @title 是否开启流式请求 * @default true