diff --git a/src/libs/agent-runtime/AgentRuntime.ts b/src/libs/agent-runtime/AgentRuntime.ts
index 68ea3ade7cd7..d6efbd985615 100644
--- a/src/libs/agent-runtime/AgentRuntime.ts
+++ b/src/libs/agent-runtime/AgentRuntime.ts
@@ -174,7 +174,7 @@ class AgentRuntime {
       }
 
       case ModelProvider.ZhiPu: {
-        runtimeModel = await LobeZhipuAI.fromAPIKey(params.zhipu);
+        runtimeModel = new LobeZhipuAI(params.zhipu);
         break;
       }
diff --git a/src/libs/agent-runtime/google/index.ts b/src/libs/agent-runtime/google/index.ts
index ecc80c9d2855..e4602240b76b 100644
--- a/src/libs/agent-runtime/google/index.ts
+++ b/src/libs/agent-runtime/google/index.ts
@@ -27,7 +27,7 @@ import { ModelProvider } from '../types/type';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
 import { StreamingResponse } from '../utils/response';
-import { GoogleGenerativeAIStream, googleGenAIResultToStream } from '../utils/streams';
+import { GoogleGenerativeAIStream, convertIterableToStream } from '../utils/streams';
 import { parseDataUri } from '../utils/uriParser';
 
 enum HarmCategory {
@@ -97,7 +97,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
         tools: this.buildGoogleTools(payload.tools),
       });
 
-      const googleStream = googleGenAIResultToStream(geminiStreamResult);
+      const googleStream = convertIterableToStream(geminiStreamResult.stream);
       const [prod, useForDebug] = googleStream.tee();
 
       if (process.env.DEBUG_GOOGLE_CHAT_COMPLETION === '1') {
diff --git a/src/libs/agent-runtime/utils/streams/anthropic.ts b/src/libs/agent-runtime/utils/streams/anthropic.ts
index 0f667e2753a1..ebdeea557269 100644
--- a/src/libs/agent-runtime/utils/streams/anthropic.ts
+++ b/src/libs/agent-runtime/utils/streams/anthropic.ts
@@ -1,6 +1,5 @@
 import Anthropic from '@anthropic-ai/sdk';
 import type { Stream } from '@anthropic-ai/sdk/streaming';
-import { readableFromAsyncIterable } from 'ai';
 
 import { ChatStreamCallbacks } from '../../types';
 import {
@@ -8,6 +7,7 @@ import {
   StreamProtocolToolCallChunk,
   StreamStack,
   StreamToolCallChunkData,
+  convertIterableToStream,
   createCallbacksTransformer,
   createSSEProtocolTransformer,
 } from './protocol';
@@ -96,12 +96,6 @@ export const transformAnthropicStream = (
   }
 };
 
-const chatStreamable = async function* (stream: AsyncIterable<any>) {
-  for await (const response of stream) {
-    yield response;
-  }
-};
-
 export const AnthropicStream = (
   stream: Stream<Anthropic.MessageStreamEvent> | ReadableStream,
   callbacks?: ChatStreamCallbacks,
 ) => {
   const streamStack: StreamStack = { id: '' };
 
   const readableStream =
-    stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+    stream instanceof ReadableStream ? stream : convertIterableToStream(stream);
 
   return readableStream
     .pipeThrough(createSSEProtocolTransformer(transformAnthropicStream, streamStack))
diff --git a/src/libs/agent-runtime/utils/streams/azureOpenai.ts b/src/libs/agent-runtime/utils/streams/azureOpenai.ts
index a6f91f8e13b2..54d993a2c4d8 100644
--- a/src/libs/agent-runtime/utils/streams/azureOpenai.ts
+++ b/src/libs/agent-runtime/utils/streams/azureOpenai.ts
@@ -1,5 +1,4 @@
 import { ChatCompletions, ChatCompletionsFunctionToolCall } from '@azure/openai';
-import { readableFromAsyncIterable } from 'ai';
 import OpenAI from 'openai';
 import type { Stream } from 'openai/streaming';
 
@@ -9,6 +8,7 @@ import {
   StreamProtocolToolCallChunk,
   StreamStack,
   StreamToolCallChunkData,
+  convertIterableToStream,
   createCallbacksTransformer,
   createSSEProtocolTransformer,
 } from './protocol';
@@ -69,19 +69,13 @@ const transformOpenAIStream = (chunk: ChatCompletions, stack: StreamStack): Stre
   };
 };
 
-const chatStreamable = async function* (stream: AsyncIterable<any>) {
-  for await (const response of stream) {
-    yield response;
-  }
-};
-
 export const AzureOpenAIStream = (
   stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
   callbacks?: ChatStreamCallbacks,
 ) => {
   const stack: StreamStack = { id: '' };
 
   const readableStream =
-    stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+    stream instanceof ReadableStream ? stream : convertIterableToStream(stream);
 
   return readableStream
     .pipeThrough(createSSEProtocolTransformer(transformOpenAIStream, stack))
diff --git a/src/libs/agent-runtime/utils/streams/google-ai.ts b/src/libs/agent-runtime/utils/streams/google-ai.ts
index ff457c52be73..e669404e9223 100644
--- a/src/libs/agent-runtime/utils/streams/google-ai.ts
+++ b/src/libs/agent-runtime/utils/streams/google-ai.ts
@@ -1,8 +1,4 @@
-import {
-  EnhancedGenerateContentResponse,
-  GenerateContentStreamResult,
-} from '@google/generative-ai';
-import { readableFromAsyncIterable } from 'ai';
+import { EnhancedGenerateContentResponse } from '@google/generative-ai';
 
 import { nanoid } from '@/utils/uuid';
 
@@ -11,7 +7,6 @@ import {
   StreamProtocolChunk,
   StreamStack,
   StreamToolCallChunkData,
-  chatStreamable,
   createCallbacksTransformer,
   createSSEProtocolTransformer,
   generateToolCallId,
@@ -50,12 +45,6 @@ const transformGoogleGenerativeAIStream = (
   };
 };
 
-// only use for debug
-export const googleGenAIResultToStream = (stream: GenerateContentStreamResult) => {
-  // make the response to the streamable format
-  return readableFromAsyncIterable(chatStreamable(stream.stream));
-};
-
 export const GoogleGenerativeAIStream = (
   rawStream: ReadableStream,
   callbacks?: ChatStreamCallbacks,
diff --git a/src/libs/agent-runtime/utils/streams/ollama.ts b/src/libs/agent-runtime/utils/streams/ollama.ts
index 32d4c5197d8b..baf92e06d605 100644
--- a/src/libs/agent-runtime/utils/streams/ollama.ts
+++ b/src/libs/agent-runtime/utils/streams/ollama.ts
@@ -1,4 +1,3 @@
-import { readableFromAsyncIterable } from 'ai';
 import { ChatResponse } from 'ollama/browser';
 
 import { ChatStreamCallbacks } from '@/libs/agent-runtime';
@@ -7,6 +6,7 @@ import { nanoid } from '@/utils/uuid';
 import {
   StreamProtocolChunk,
   StreamStack,
+  convertIterableToStream,
   createCallbacksTransformer,
   createSSEProtocolTransformer,
 } from './protocol';
@@ -20,19 +20,13 @@ const transformOllamaStream = (chunk: ChatResponse, stack: StreamStack): StreamP
   return { data: chunk.message.content, id: stack.id, type: 'text' };
 };
 
-const chatStreamable = async function* (stream: AsyncIterable<any>) {
-  for await (const response of stream) {
-    yield response;
-  }
-};
-
 export const OllamaStream = (
   res: AsyncIterable<ChatResponse>,
   cb?: ChatStreamCallbacks,
 ): ReadableStream => {
   const streamStack: StreamStack = { id: 'chat_' + nanoid() };
 
-  return readableFromAsyncIterable(chatStreamable(res))
+  return convertIterableToStream(res)
     .pipeThrough(createSSEProtocolTransformer(transformOllamaStream, streamStack))
     .pipeThrough(createCallbacksTransformer(cb));
 };
diff --git a/src/libs/agent-runtime/utils/streams/openai.ts b/src/libs/agent-runtime/utils/streams/openai.ts
index 14660c7ca332..0f01f07e26b4 100644
--- a/src/libs/agent-runtime/utils/streams/openai.ts
+++ b/src/libs/agent-runtime/utils/streams/openai.ts
@@ -1,4 +1,3 @@
-import { readableFromAsyncIterable } from 'ai';
 import OpenAI from 'openai';
 import type { Stream } from 'openai/streaming';
 
@@ -10,6 +9,7 @@ import {
   StreamProtocolToolCallChunk,
   StreamStack,
   StreamToolCallChunkData,
+  convertIterableToStream,
   createCallbacksTransformer,
   createSSEProtocolTransformer,
   generateToolCallId,
@@ -105,12 +105,6 @@ export const transformOpenAIStream = (
   }
 };
 
-const chatStreamable = async function* (stream: AsyncIterable<any>) {
-  for await (const response of stream) {
-    yield response;
-  }
-};
-
 export const OpenAIStream = (
   stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
   callbacks?: ChatStreamCallbacks,
 ) => {
   const streamStack: StreamStack = { id: '' };
 
   const readableStream =
-    stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+    stream instanceof ReadableStream ? stream : convertIterableToStream(stream);
 
   return readableStream
     .pipeThrough(createSSEProtocolTransformer(transformOpenAIStream, streamStack))
diff --git a/src/libs/agent-runtime/utils/streams/protocol.ts b/src/libs/agent-runtime/utils/streams/protocol.ts
index d9cee1a08b9c..ecbe9a93c2ae 100644
--- a/src/libs/agent-runtime/utils/streams/protocol.ts
+++ b/src/libs/agent-runtime/utils/streams/protocol.ts
@@ -1,3 +1,5 @@
+import { readableFromAsyncIterable } from 'ai';
+
 import { ChatStreamCallbacks } from '@/libs/agent-runtime';
 
 export interface StreamStack {
@@ -42,6 +44,11 @@ export const chatStreamable = async function* (stream: AsyncIterable<any>) {
   }
 };
 
+// make the response to the streamable format
+export const convertIterableToStream = (stream: AsyncIterable<any>) => {
+  return readableFromAsyncIterable(chatStreamable(stream));
+};
+
 export const createSSEProtocolTransformer = (
   transformer: (chunk: any, stack: StreamStack) => StreamProtocolChunk,
   streamStack?: StreamStack,
diff --git a/src/libs/agent-runtime/utils/streams/qwen.ts b/src/libs/agent-runtime/utils/streams/qwen.ts
index 349ae824c041..f0cc613b25e2 100644
--- a/src/libs/agent-runtime/utils/streams/qwen.ts
+++ b/src/libs/agent-runtime/utils/streams/qwen.ts
@@ -1,4 +1,3 @@
-import { readableFromAsyncIterable } from 'ai';
 import { ChatCompletionContentPartText } from 'ai/prompts';
 import OpenAI from 'openai';
 import { ChatCompletionContentPart } from 'openai/resources/index.mjs';
@@ -9,7 +8,7 @@ import {
   StreamProtocolChunk,
   StreamProtocolToolCallChunk,
   StreamToolCallChunkData,
-  chatStreamable,
+  convertIterableToStream,
   createCallbacksTransformer,
   createSSEProtocolTransformer,
   generateToolCallId,
@@ -86,7 +85,7 @@ export const QwenAIStream = (
   callbacks?: ChatStreamCallbacks,
 ) => {
   const readableStream =
-    stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+    stream instanceof ReadableStream ? stream : convertIterableToStream(stream);
 
   return readableStream
     .pipeThrough(createSSEProtocolTransformer(transformQwenStream))
diff --git a/src/libs/agent-runtime/utils/streams/wenxin.test.ts b/src/libs/agent-runtime/utils/streams/wenxin.test.ts
index 6f1b30d9b449..da0072db8441 100644
--- a/src/libs/agent-runtime/utils/streams/wenxin.test.ts
+++ b/src/libs/agent-runtime/utils/streams/wenxin.test.ts
@@ -2,8 +2,9 @@
 import { describe, expect, it, vi } from 'vitest';
 
 import * as uuidModule from '@/utils/uuid';
 
+import { convertIterableToStream } from '../../utils/streams/protocol';
 import { ChatResp } from '../../wenxin/type';
-import { WenxinResultToStream, WenxinStream } from './wenxin';
+import { WenxinStream } from './wenxin';
 
 const dataStream = [
   {
@@ -95,7 +96,7 @@ describe('WenxinStream', () => {
       },
     };
 
-    const stream = WenxinResultToStream(mockWenxinStream);
+    const stream = convertIterableToStream(mockWenxinStream);
 
     const onStartMock = vi.fn();
     const onTextMock = vi.fn();
@@ -142,7 +143,10 @@ describe('WenxinStream', () => {
     expect(onStartMock).toHaveBeenCalledTimes(1);
     expect(onTextMock).toHaveBeenNthCalledWith(1, '"当然可以,"');
-    expect(onTextMock).toHaveBeenNthCalledWith(2, '"以下是一些建议的自驾游路线,它们涵盖了各种不同的风景和文化体验:\\n\\n1. **西安-敦煌历史文化之旅**:\\n\\n\\n\\t* 路线:西安"');
+    expect(onTextMock).toHaveBeenNthCalledWith(
+      2,
+      '"以下是一些建议的自驾游路线,它们涵盖了各种不同的风景和文化体验:\\n\\n1. **西安-敦煌历史文化之旅**:\\n\\n\\n\\t* 路线:西安"',
+    );
     expect(onTokenMock).toHaveBeenCalledTimes(6);
     expect(onCompletionMock).toHaveBeenCalledTimes(1);
   });
diff --git a/src/libs/agent-runtime/utils/streams/wenxin.ts b/src/libs/agent-runtime/utils/streams/wenxin.ts
index 23edd48c624b..418ede9e7f26 100644
--- a/src/libs/agent-runtime/utils/streams/wenxin.ts
+++ b/src/libs/agent-runtime/utils/streams/wenxin.ts
@@ -1,5 +1,3 @@
-import { readableFromAsyncIterable } from 'ai';
-
 import { ChatStreamCallbacks } from '@/libs/agent-runtime';
 import { nanoid } from '@/utils/uuid';
 
@@ -7,7 +5,6 @@ import { ChatResp } from '../../wenxin/type';
 import {
   StreamProtocolChunk,
   StreamStack,
-  chatStreamable,
   createCallbacksTransformer,
   createSSEProtocolTransformer,
 } from './protocol';
@@ -29,11 +26,6 @@ const transformERNIEBotStream = (chunk: ChatResp): StreamProtocolChunk => {
   };
 };
 
-export const WenxinResultToStream = (stream: AsyncIterable<ChatResp>) => {
-  // make the response to the streamable format
-  return readableFromAsyncIterable(chatStreamable(stream));
-};
-
 export const WenxinStream = (
   rawStream: ReadableStream,
   callbacks?: ChatStreamCallbacks,
diff --git a/src/libs/agent-runtime/wenxin/index.ts b/src/libs/agent-runtime/wenxin/index.ts
index f5ae769d5af3..f597fcf92d92 100644
--- a/src/libs/agent-runtime/wenxin/index.ts
+++ b/src/libs/agent-runtime/wenxin/index.ts
@@ -10,7 +10,8 @@ import { ChatCompetitionOptions, ChatStreamPayload } from '../types';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
 import { StreamingResponse } from '../utils/response';
-import { WenxinResultToStream, WenxinStream } from '../utils/streams/wenxin';
+import { convertIterableToStream } from '../utils/streams';
+import { WenxinStream } from '../utils/streams/wenxin';
 import { ChatResp } from './type';
 
 interface ChatErrorCode {
@@ -46,7 +47,7 @@ export class LobeWenxinAI implements LobeRuntimeAI {
         payload.model,
       );
 
-      const wenxinStream = WenxinResultToStream(result as AsyncIterable<ChatResp>);
+      const wenxinStream = convertIterableToStream(result as AsyncIterable<ChatResp>);
 
       const [prod, useForDebug] = wenxinStream.tee();
diff --git a/src/libs/agent-runtime/zhipu/index.test.ts b/src/libs/agent-runtime/zhipu/index.test.ts
index 433ae236e24c..e2f0ebf6f6c0 100644
--- a/src/libs/agent-runtime/zhipu/index.test.ts
+++ b/src/libs/agent-runtime/zhipu/index.test.ts
@@ -2,7 +2,7 @@
 import { OpenAI } from 'openai';
 import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
 
-import { ChatStreamCallbacks, LobeOpenAI } from '@/libs/agent-runtime';
+import { ChatStreamCallbacks, LobeOpenAI, LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
 import * as debugStreamModule from '@/libs/agent-runtime/utils/debugStream';
 
 import * as authTokenModule from './authToken';
@@ -24,28 +24,11 @@ describe('LobeZhipuAI', () => {
     vi.restoreAllMocks();
   });
 
-  describe('fromAPIKey', () => {
-    it('should correctly initialize with an API key', async () => {
-      const lobeZhipuAI = await LobeZhipuAI.fromAPIKey({ apiKey: 'test_api_key' });
-      expect(lobeZhipuAI).toBeInstanceOf(LobeZhipuAI);
-      expect(lobeZhipuAI.baseURL).toEqual('https://open.bigmodel.cn/api/paas/v4');
-    });
-
-    it('should throw an error if API key is invalid', async () => {
-      vi.spyOn(authTokenModule, 'generateApiToken').mockRejectedValue(new Error('Invalid API Key'));
-      try {
-        await LobeZhipuAI.fromAPIKey({ apiKey: 'asd' });
-      } catch (e) {
-        expect(e).toEqual({ errorType: invalidErrorType });
-      }
-    });
-  });
-
   describe('chat', () => {
-    let instance: LobeZhipuAI;
+    let instance: LobeOpenAICompatibleRuntime;
 
     beforeEach(async () => {
-      instance = await LobeZhipuAI.fromAPIKey({
+      instance = new LobeZhipuAI({
         apiKey: 'test_api_key',
       });
@@ -131,9 +114,9 @@ describe('LobeZhipuAI', () => {
      const calledWithParams = spyOn.mock.calls[0][0];
       expect(calledWithParams.messages[1].content).toEqual([{ type: 'text', text: 'Hello again' }]);
-      expect(calledWithParams.temperature).toBeUndefined(); // temperature 0 should be undefined
+      expect(calledWithParams.temperature).toBe(0); // temperature 0 is now passed through unchanged
       expect((calledWithParams as any).do_sample).toBeTruthy(); // temperature 0 should be undefined
-      expect(calledWithParams.top_p).toEqual(0.99); // top_p should be transformed correctly
+      expect(calledWithParams.top_p).toEqual(1); // top_p is no longer remapped
     });
 
     describe('Error', () => {
@@ -175,7 +158,7 @@ describe('LobeZhipuAI', () => {
 
      it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => {
         try {
-          await LobeZhipuAI.fromAPIKey({ apiKey: '' });
+          new LobeZhipuAI({ apiKey: '' });
         } catch (e) {
           expect(e).toEqual({ errorType: invalidErrorType });
         }
@@ -221,7 +204,7 @@ describe('LobeZhipuAI', () => {
       };
 
       const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
-        instance = await LobeZhipuAI.fromAPIKey({
+        instance = new LobeZhipuAI({
           apiKey: 'test',
           baseURL: 'https://abc.com/v2',
diff --git a/src/libs/agent-runtime/zhipu/index.ts b/src/libs/agent-runtime/zhipu/index.ts
index ff821a6245f2..5f92756e234c 100644
--- a/src/libs/agent-runtime/zhipu/index.ts
+++ b/src/libs/agent-runtime/zhipu/index.ts
@@ -1,99 +1,21 @@
-import OpenAI, { ClientOptions } from 'openai';
-
-import { LobeRuntimeAI } from '../BaseAI';
-import { AgentRuntimeErrorType } from '../error';
-import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
-import { AgentRuntimeError } from '../utils/createError';
-import { debugStream } from '../utils/debugStream';
-import { desensitizeUrl } from '../utils/desensitizeUrl';
-import { handleOpenAIError } from '../utils/handleOpenAIError';
-import { convertOpenAIMessages } from '../utils/openaiHelpers';
-import { StreamingResponse } from '../utils/response';
-import { OpenAIStream } from '../utils/streams';
-import { generateApiToken } from './authToken';
-
-const DEFAULT_BASE_URL = 'https://open.bigmodel.cn/api/paas/v4';
-
-export class LobeZhipuAI implements LobeRuntimeAI {
-  private client: OpenAI;
-
-  baseURL: string;
-
-  constructor(oai: OpenAI) {
-    this.client = oai;
-    this.baseURL = this.client.baseURL;
-  }
-
-  static async fromAPIKey({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions = {}) {
-    const invalidZhipuAPIKey = AgentRuntimeError.createError(
-      AgentRuntimeErrorType.InvalidProviderAPIKey,
-    );
-
-    if (!apiKey) throw invalidZhipuAPIKey;
-
-    let token: string;
-
-    try {
-      token = await generateApiToken(apiKey);
-    } catch {
-      throw invalidZhipuAPIKey;
-    }
-
-    const header = { Authorization: `Bearer ${token}` };
-    const llm = new OpenAI({ apiKey, baseURL, defaultHeaders: header, ...res });
-
-    return new LobeZhipuAI(llm);
-  }
-
-  async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
-    try {
-      const params = await this.buildCompletionsParams(payload);
-
-      const response = await this.client.chat.completions.create(
-        params as unknown as OpenAI.ChatCompletionCreateParamsStreaming,
-      );
-
-      const [prod, debug] = response.tee();
-
-      if (process.env.DEBUG_ZHIPU_CHAT_COMPLETION === '1') {
-        debugStream(debug.toReadableStream()).catch(console.error);
-      }
-
-      return StreamingResponse(OpenAIStream(prod, options?.callback), {
-        headers: options?.headers,
-      });
-    } catch (error) {
-      const { errorResult, RuntimeError } = handleOpenAIError(error);
-
-      const errorType = RuntimeError || AgentRuntimeErrorType.ProviderBizError;
-      let desensitizedEndpoint = this.baseURL;
-
-      if (this.baseURL !== DEFAULT_BASE_URL) {
-        desensitizedEndpoint = desensitizeUrl(this.baseURL);
-      }
-      throw AgentRuntimeError.chat({
-        endpoint: desensitizedEndpoint,
-        error: errorResult,
-        errorType,
-        provider: ModelProvider.ZhiPu,
-      });
-    }
-  }
-
-  private async buildCompletionsParams(payload: ChatStreamPayload) {
-    const { messages, temperature, top_p, ...params } = payload;
-
-    return {
-      messages: await convertOpenAIMessages(messages as any),
-      ...params,
-      do_sample: temperature === 0,
-      stream: true,
-      // The current model does not support top_p = 1 or temperature = 0
-      // refs: https://zhipu-ai.feishu.cn/wiki/TUo0w2LT7iswnckmfSEcqTD0ncd
-      temperature: temperature === 0 ? undefined : temperature,
-      top_p: top_p === 1 ? 0.99 : top_p,
-    };
-  }
-}
-
-export default LobeZhipuAI;
+import OpenAI from 'openai';
+
+import { ChatStreamPayload, ModelProvider } from '../types';
+import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
+
+export const LobeZhipuAI = LobeOpenAICompatibleFactory({
+  baseURL: 'https://open.bigmodel.cn/api/paas/v4',
+  chatCompletion: {
+    handlePayload: ({ temperature, ...payload }: ChatStreamPayload) =>
+      ({
+        ...payload,
+        do_sample: temperature === 0,
+        stream: true,
+        temperature,
+      }) as OpenAI.ChatCompletionCreateParamsStreaming,
+  },
+  debug: {
+    chatCompletion: () => process.env.DEBUG_ZHIPU_CHAT_COMPLETION === '1',
+  },
+  provider: ModelProvider.ZhiPu,
+});
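
Note: the net effect of this refactor is that every provider now funnels its SDK's async iterable through the single convertIterableToStream helper added to protocol.ts, instead of each stream file keeping a private chatStreamable copy; the instanceof ReadableStream guard in each *Stream wrapper still lets callers hand over an already-converted stream. A minimal sketch of the consolidated behavior (standalone TypeScript; the mockChunks generator is illustrative and not part of this diff):

import { convertIterableToStream } from '@/libs/agent-runtime/utils/streams/protocol';

// Hypothetical stand-in for an SDK response; any AsyncIterable works.
async function* mockChunks() {
  yield { message: { content: 'Hello' } };
  yield { message: { content: ' world' } };
}

// One ReadableStream out, tee-able for the debug path exactly as the providers do it.
const stream = convertIterableToStream(mockChunks());
const [prod, useForDebug] = stream.tee();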