diff --git a/src/database/server/models/__tests__/file.test.ts b/src/database/server/models/__tests__/file.test.ts index b02a89bd4ebb..206b980a9033 100644 --- a/src/database/server/models/__tests__/file.test.ts +++ b/src/database/server/models/__tests__/file.test.ts @@ -35,6 +35,9 @@ vi.mock('@/config/db', async () => ({ DATABASE_DRIVER: 'node', }; }, + getServerDBConfig: vi.fn().mockReturnValue({ + NEXT_PUBLIC_ENABLED_SERVER_SERVICE: true, + }), })); const userId = 'file-model-test-user-id'; diff --git a/src/libs/agent-runtime/bedrock/index.test.ts b/src/libs/agent-runtime/bedrock/index.test.ts index 5949ab706c27..ab3fb5708ab7 100644 --- a/src/libs/agent-runtime/bedrock/index.test.ts +++ b/src/libs/agent-runtime/bedrock/index.test.ts @@ -2,7 +2,7 @@ import { InvokeModelWithResponseStreamCommand } from '@aws-sdk/client-bedrock-runtime'; import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import { AgentRuntimeErrorType, ModelProvider } from '@/libs/agent-runtime'; +import { AgentRuntimeError, AgentRuntimeErrorType, ModelProvider } from '@/libs/agent-runtime'; import * as debugStreamModule from '../utils/debugStream'; import { LobeBedrockAI } from './index'; @@ -393,18 +393,57 @@ describe('LobeBedrockAI', () => { }); describe('embeddings', () => { - it('should call invokeEmbeddingsModel with correct payload', async () => { - process.env.DEBUG_BEDROCK_CHAT_COMPLETION = '1'; - // @ts-ignore - const spy = vi.spyOn(instance, 'invokeEmbeddingModel'); - await instance.embeddings({ - dimensions: 1024, - input: 'Hello,World', - model: 'amazon.titan-embed-text-v2:0', + it('should return an array of EmbeddingItems', async () => { + // Arrange + const mockEmbeddingItem = { embedding: [0.1, 0.2], index: 0, object: 'embedding' }; + + const spy = vi + .spyOn(instance as any, 'invokeEmbeddingModel') + .mockResolvedValue(mockEmbeddingItem); + + // Act + const result = await instance.embeddings({ + input: ['Hello'], + dimensions: 128, + model: 'test-model', }); + // Assert - expect(spy).toHaveBeenCalled; - delete process.env.DEBUG_BEDROCK_CHAT_COMPLETION; + expect(spy).toHaveBeenCalled(); + expect(result).toEqual([mockEmbeddingItem]); + }); + + it('should call instance.embeddings with correct parameters', async () => { + // Arrange + const payload = { + dimensions: 1024, + index: 0, + input: 'Hello', + modelId: 'test-model', + model: 'test-model', // Add the missing model property + }; + + const apiError = AgentRuntimeError.chat({ + error: { + body: undefined, + message: 'Unexpected end of JSON input', + type: 'SyntaxError', + }, + errorType: AgentRuntimeErrorType.ProviderBizError, + provider: ModelProvider.Bedrock, + region: 'us-west-2', + }); + + // 使用 vi.spyOn 来模拟 instance.embeddings 方法 + const spy = vi.spyOn(instance as any, 'invokeEmbeddingModel').mockRejectedValue(apiError); + + try { + // Act + await instance['invokeEmbeddingModel'](payload); + } catch (e) { + expect(e).toEqual(apiError); + } + expect(spy).toHaveBeenCalled(); }); }); }); diff --git a/src/libs/agent-runtime/bedrock/index.ts b/src/libs/agent-runtime/bedrock/index.ts index f4b0f9599191..d55ddfcac35d 100644 --- a/src/libs/agent-runtime/bedrock/index.ts +++ b/src/libs/agent-runtime/bedrock/index.ts @@ -25,13 +25,6 @@ import { createBedrockStream, } from '../utils/streams'; -interface BedRockEmbeddingsParams { - dimensions: number; - index: number; - input: string; - modelId: string; -} - export interface LobeBedrockAIParams { accessKeyId?: string; accessKeySecret?: string; @@ -75,8 +68,8 @@ export class LobeBedrockAI implements LobeRuntimeAI { dimensions: payload.dimensions, index: index, input: inputText, - modelId: payload.model, - } as BedRockEmbeddingsParams, + model: payload.model, + }, options, ), ); @@ -84,7 +77,7 @@ export class LobeBedrockAI implements LobeRuntimeAI { } private invokeEmbeddingModel = async ( - payload: BedRockEmbeddingsParams, + payload: EmbeddingsPayload, options?: EmbeddingsOptions, ): Promise => { const command = new InvokeModelCommand({ @@ -95,12 +88,16 @@ export class LobeBedrockAI implements LobeRuntimeAI { normalize: true, }), contentType: 'application/json', - modelId: payload.modelId, + modelId: payload.model, }); try { const res = await this.client.send(command, { abortSignal: options?.signal }); const responseBody = JSON.parse(new TextDecoder().decode(res.body)); - return { embedding: responseBody.embedding, index: payload.index, object: 'embedding' }; + return { + embedding: responseBody.embedding, + index: payload.index as number, + object: 'embedding', + }; } catch (e) { const err = e as Error & { $metadata: any }; throw AgentRuntimeError.chat({ diff --git a/src/libs/agent-runtime/ollama/index.test.ts b/src/libs/agent-runtime/ollama/index.test.ts index d48cebfcc3de..1611637e9aac 100644 --- a/src/libs/agent-runtime/ollama/index.test.ts +++ b/src/libs/agent-runtime/ollama/index.test.ts @@ -132,6 +132,62 @@ describe('LobeOllamaAI', () => { }); }); + describe('embeddings', () => { + it('should return an array of EmbeddingItems', async () => { + // Arrange + const mockEmbeddingItem = { embedding: [0.1, 0.2], index: 0, object: 'embedding' }; + + const spy = vi + .spyOn(ollamaAI as any, 'invokeEmbeddingModel') + .mockResolvedValue(mockEmbeddingItem); + + // Act + const result = await ollamaAI.embeddings({ + input: ['Hello'], + dimensions: 128, + model: 'test-model', + index: 0, + }); + + // Assert + expect(spy).toHaveBeenCalled(); + expect(result).toEqual([mockEmbeddingItem]); + }); + + it('should call instance.embeddings with correct parameters', async () => { + // Arrange + const payload = { + dimensions: 1024, + index: 0, + input: 'Hello', + modelId: 'test-model', + model: 'test-model', // Add the missing model property + }; + + const apiError = AgentRuntimeError.chat({ + error: { + body: undefined, + message: 'Unexpected end of JSON input', + type: 'SyntaxError', + }, + errorType: AgentRuntimeErrorType.ProviderBizError, + provider: ModelProvider.Bedrock, + region: 'us-west-2', + }); + + // 使用 vi.spyOn 来模拟 instance.embeddings 方法 + const spy = vi.spyOn(ollamaAI as any, 'invokeEmbeddingModel').mockRejectedValue(apiError); + + try { + // Act + await ollamaAI['invokeEmbeddingModel'](payload); + } catch (e) { + expect(e).toEqual(apiError); + } + expect(spy).toHaveBeenCalled(); + }); + }); + describe('models', () => { it('should call Ollama client list method and return ChatModelCard array', async () => { const listMock = vi.fn().mockResolvedValue({ diff --git a/src/libs/agent-runtime/ollama/index.ts b/src/libs/agent-runtime/ollama/index.ts index 34042e49c056..79ea35b7e619 100644 --- a/src/libs/agent-runtime/ollama/index.ts +++ b/src/libs/agent-runtime/ollama/index.ts @@ -74,14 +74,12 @@ export class LobeOllamaAI implements LobeRuntimeAI { async embeddings(payload: EmbeddingsPayload): Promise { const input = Array.isArray(payload.input) ? payload.input : [payload.input]; const promises = input.map((inputText: string, index: number) => - this.invokeEmbeddingModel( - { - dimensions: payload.dimensions, - input: inputText, - model: payload.model, - }, - index, - ), + this.invokeEmbeddingModel({ + dimensions: payload.dimensions, + index: index, + input: inputText, + model: payload.model, + }), ); return await Promise.all(promises); } @@ -93,16 +91,17 @@ export class LobeOllamaAI implements LobeRuntimeAI { })); } - private invokeEmbeddingModel = async ( - payload: EmbeddingsPayload, - index: number, - ): Promise => { + private invokeEmbeddingModel = async (payload: EmbeddingsPayload): Promise => { try { const responseBody = await this.client.embeddings({ model: payload.model, prompt: payload.input as string, }); - return { embedding: responseBody.embedding, index: index, object: 'embedding' }; + return { + embedding: responseBody.embedding, + index: payload.index as number, + object: 'embedding', + }; } catch (error) { const e = error as { message: string; name: string; status_code: number }; diff --git a/src/libs/agent-runtime/types/embeddings.ts b/src/libs/agent-runtime/types/embeddings.ts index b13bc442874f..e3433a93da47 100644 --- a/src/libs/agent-runtime/types/embeddings.ts +++ b/src/libs/agent-runtime/types/embeddings.ts @@ -4,6 +4,7 @@ export interface EmbeddingsPayload { * supported in `text-embedding-3` and later models. */ dimensions?: number; + index?: number; /** * Input text to embed, encoded as a string or array of tokens. To embed multiple * inputs in a single request, pass an array of strings . diff --git a/src/utils/fetch/__tests__/fetchSSE.test.ts b/src/utils/fetch/__tests__/fetchSSE.test.ts index c1a8acbd9730..2b3469ca4395 100644 --- a/src/utils/fetch/__tests__/fetchSSE.test.ts +++ b/src/utils/fetch/__tests__/fetchSSE.test.ts @@ -437,7 +437,7 @@ describe('fetchSSE', () => { context: { chunk: 'abc', error: { - message: 'Unexpected token a in JSON at position 0', + message: `Unexpected token 'a', \"abc\" is not valid JSON`, name: 'SyntaxError', }, }, diff --git a/tests/setup-db.ts b/tests/setup-db.ts index ae39ca20ee84..71512da2b811 100644 --- a/tests/setup-db.ts +++ b/tests/setup-db.ts @@ -4,4 +4,7 @@ import * as dotenv from 'dotenv'; dotenv.config(); -global.crypto = new Crypto(); +Object.defineProperty(global, 'crypto', { + configurable: true, + value: new Crypto(), +});