Skip to content

Commit

Permalink
🔨 chore: add embedding ci file
Browse files Browse the repository at this point in the history
  • Loading branch information
cookieY committed Oct 15, 2024
1 parent ac13474 commit 42a231b
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 38 deletions.
3 changes: 3 additions & 0 deletions src/database/server/models/__tests__/file.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ vi.mock('@/config/db', async () => ({
DATABASE_DRIVER: 'node',
};
},
getServerDBConfig: vi.fn().mockReturnValue({
NEXT_PUBLIC_ENABLED_SERVER_SERVICE: true,
}),
}));

const userId = 'file-model-test-user-id';
Expand Down
61 changes: 50 additions & 11 deletions src/libs/agent-runtime/bedrock/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import { InvokeModelWithResponseStreamCommand } from '@aws-sdk/client-bedrock-runtime';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { AgentRuntimeErrorType, ModelProvider } from '@/libs/agent-runtime';
import { AgentRuntimeError, AgentRuntimeErrorType, ModelProvider } from '@/libs/agent-runtime';

import * as debugStreamModule from '../utils/debugStream';
import { LobeBedrockAI } from './index';
Expand Down Expand Up @@ -393,18 +393,57 @@ describe('LobeBedrockAI', () => {
});

describe('embeddings', () => {
it('should call invokeEmbeddingsModel with correct payload', async () => {
process.env.DEBUG_BEDROCK_CHAT_COMPLETION = '1';
// @ts-ignore
const spy = vi.spyOn(instance, 'invokeEmbeddingModel');
await instance.embeddings({
dimensions: 1024,
input: 'Hello,World',
model: 'amazon.titan-embed-text-v2:0',
it('should return an array of EmbeddingItems', async () => {
// Arrange
const mockEmbeddingItem = { embedding: [0.1, 0.2], index: 0, object: 'embedding' };

const spy = vi
.spyOn(instance as any, 'invokeEmbeddingModel')
.mockResolvedValue(mockEmbeddingItem);

// Act
const result = await instance.embeddings({
input: ['Hello'],
dimensions: 128,
model: 'test-model',
});

// Assert
expect(spy).toHaveBeenCalled;
delete process.env.DEBUG_BEDROCK_CHAT_COMPLETION;
expect(spy).toHaveBeenCalled();
expect(result).toEqual([mockEmbeddingItem]);
});

it('should call instance.embeddings with correct parameters', async () => {
// Arrange
const payload = {
dimensions: 1024,
index: 0,
input: 'Hello',
modelId: 'test-model',
model: 'test-model', // Add the missing model property
};

const apiError = AgentRuntimeError.chat({
error: {
body: undefined,
message: 'Unexpected end of JSON input',
type: 'SyntaxError',
},
errorType: AgentRuntimeErrorType.ProviderBizError,
provider: ModelProvider.Bedrock,
region: 'us-west-2',
});

// Use vi.spyOn to mock the private invokeEmbeddingModel method.
const spy = vi.spyOn(instance as any, 'invokeEmbeddingModel').mockRejectedValue(apiError);

try {
// Act
await instance['invokeEmbeddingModel'](payload);
} catch (e) {
expect(e).toEqual(apiError);
}
expect(spy).toHaveBeenCalled();
});
});
});
21 changes: 9 additions & 12 deletions src/libs/agent-runtime/bedrock/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@ import {
createBedrockStream,
} from '../utils/streams';

interface BedRockEmbeddingsParams {
dimensions: number;
index: number;
input: string;
modelId: string;
}

export interface LobeBedrockAIParams {
accessKeyId?: string;
accessKeySecret?: string;
Expand Down Expand Up @@ -75,16 +68,16 @@ export class LobeBedrockAI implements LobeRuntimeAI {
dimensions: payload.dimensions,
index: index,
input: inputText,
modelId: payload.model,
} as BedRockEmbeddingsParams,
model: payload.model,
},
options,
),
);
return Promise.all(promises);
}

private invokeEmbeddingModel = async (
payload: BedRockEmbeddingsParams,
payload: EmbeddingsPayload,
options?: EmbeddingsOptions,
): Promise<EmbeddingItem> => {
const command = new InvokeModelCommand({
Expand All @@ -95,12 +88,16 @@ export class LobeBedrockAI implements LobeRuntimeAI {
normalize: true,
}),
contentType: 'application/json',
modelId: payload.modelId,
modelId: payload.model,
});
try {
const res = await this.client.send(command, { abortSignal: options?.signal });
const responseBody = JSON.parse(new TextDecoder().decode(res.body));
return { embedding: responseBody.embedding, index: payload.index, object: 'embedding' };
return {
embedding: responseBody.embedding,
index: payload.index as number,
object: 'embedding',
};
} catch (e) {
const err = e as Error & { $metadata: any };
throw AgentRuntimeError.chat({
Expand Down
56 changes: 56 additions & 0 deletions src/libs/agent-runtime/ollama/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,62 @@ describe('LobeOllamaAI', () => {
});
});

describe('embeddings', () => {
  it('should return an array of EmbeddingItems', async () => {
    // Arrange: stub the private invokeEmbeddingModel so no Ollama server is needed.
    const mockEmbeddingItem = { embedding: [0.1, 0.2], index: 0, object: 'embedding' };

    const spy = vi
      .spyOn(ollamaAI as any, 'invokeEmbeddingModel')
      .mockResolvedValue(mockEmbeddingItem);

    // Act: a single-element input array should yield a single EmbeddingItem.
    const result = await ollamaAI.embeddings({
      dimensions: 128,
      input: ['Hello'],
      model: 'test-model',
    });

    // Assert: embeddings fans out to invokeEmbeddingModel once per input.
    expect(spy).toHaveBeenCalledTimes(1);
    expect(result).toEqual([mockEmbeddingItem]);
  });

  it('should propagate errors thrown by invokeEmbeddingModel', async () => {
    // Arrange: payload shaped like EmbeddingsPayload (no stale `modelId` field —
    // the runtime now reads `model`).
    const payload = {
      dimensions: 1024,
      index: 0,
      input: 'Hello',
      model: 'test-model',
    };

    // Fixture error for this provider. NOTE(review): the original used
    // ModelProvider.Bedrock and a `region` — copy-paste from the bedrock test.
    const apiError = AgentRuntimeError.chat({
      error: {
        body: undefined,
        message: 'Unexpected end of JSON input',
        type: 'SyntaxError',
      },
      errorType: AgentRuntimeErrorType.ProviderBizError,
      provider: ModelProvider.Ollama,
    });

    // Use vi.spyOn to mock the private invokeEmbeddingModel method.
    const spy = vi.spyOn(ollamaAI as any, 'invokeEmbeddingModel').mockRejectedValue(apiError);

    // Act & Assert: `rejects` fails the test when nothing throws, unlike the
    // bare try/catch pattern, which would pass silently on a missing rejection.
    await expect(ollamaAI['invokeEmbeddingModel'](payload)).rejects.toEqual(apiError);
    expect(spy).toHaveBeenCalledTimes(1);
  });
});

describe('models', () => {
it('should call Ollama client list method and return ChatModelCard array', async () => {
const listMock = vi.fn().mockResolvedValue({
Expand Down
25 changes: 12 additions & 13 deletions src/libs/agent-runtime/ollama/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,12 @@ export class LobeOllamaAI implements LobeRuntimeAI {
async embeddings(payload: EmbeddingsPayload): Promise<EmbeddingItem[]> {
const input = Array.isArray(payload.input) ? payload.input : [payload.input];
const promises = input.map((inputText: string, index: number) =>
this.invokeEmbeddingModel(
{
dimensions: payload.dimensions,
input: inputText,
model: payload.model,
},
index,
),
this.invokeEmbeddingModel({
dimensions: payload.dimensions,
index: index,
input: inputText,
model: payload.model,
}),
);
return await Promise.all(promises);
}
Expand All @@ -93,16 +91,17 @@ export class LobeOllamaAI implements LobeRuntimeAI {
}));
}

private invokeEmbeddingModel = async (
payload: EmbeddingsPayload,
index: number,
): Promise<EmbeddingItem> => {
private invokeEmbeddingModel = async (payload: EmbeddingsPayload): Promise<EmbeddingItem> => {
try {
const responseBody = await this.client.embeddings({
model: payload.model,
prompt: payload.input as string,
});
return { embedding: responseBody.embedding, index: index, object: 'embedding' };
return {
embedding: responseBody.embedding,
index: payload.index as number,
object: 'embedding',
};
} catch (error) {
const e = error as { message: string; name: string; status_code: number };

Expand Down
1 change: 1 addition & 0 deletions src/libs/agent-runtime/types/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export interface EmbeddingsPayload {
* supported in `text-embedding-3` and later models.
*/
dimensions?: number;
index?: number;
/**
* Input text to embed, encoded as a string or array of tokens. To embed multiple
* inputs in a single request, pass an array of strings .
Expand Down
2 changes: 1 addition & 1 deletion src/utils/fetch/__tests__/fetchSSE.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ describe('fetchSSE', () => {
context: {
chunk: 'abc',
error: {
message: 'Unexpected token a in JSON at position 0',
message: `Unexpected token 'a', \"abc\" is not valid JSON`,
name: 'SyntaxError',
},
},
Expand Down
5 changes: 4 additions & 1 deletion tests/setup-db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ import * as dotenv from 'dotenv';

dotenv.config();

global.crypto = new Crypto();
Object.defineProperty(global, 'crypto', {
configurable: true,
value: new Crypto(),
});

0 comments on commit 42a231b

Please sign in to comment.