diff --git a/js/plugins/ollama/src/embeddings.ts b/js/plugins/ollama/src/embeddings.ts index 9907257f2..4fe9e1946 100644 --- a/js/plugins/ollama/src/embeddings.ts +++ b/js/plugins/ollama/src/embeddings.ts @@ -1,4 +1,20 @@ -import { defineEmbedder, EmbedderReference } from '@genkit-ai/ai/embedder'; +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { defineEmbedder } from '@genkit-ai/ai/embedder'; import { logger } from '@genkit-ai/core/logging'; import z from 'zod'; import { OllamaPluginParams } from './index.js'; @@ -19,18 +35,26 @@ interface OllamaEmbeddingPrediction { embedding: number[]; } -export function defineOllamaEmbedder( - name: string, - modelName: string, - dimensions: number, - options: OllamaPluginParams -): EmbedderReference { +interface DefineOllamaEmbeddingParams { + name: string; + modelName: string; + dimensions: number; + options: OllamaPluginParams; +} + +export function defineOllamaEmbedder({ + name, + modelName, + dimensions, + options, +}: DefineOllamaEmbeddingParams) { return defineEmbedder( { name, configSchema: OllamaEmbeddingConfigSchema, // Use the Zod schema directly here info: { - label: 'Embedding using Ollama', + // TODO: do we want users to be able to specify the label when they call this method directly? + label: 'Embedding using Ollama - ' + modelName, dimensions, supports: { // TODO: do any ollama models support other modalities? diff --git a/js/plugins/ollama/src/index.ts b/js/plugins/ollama/src/index.ts index 7a7bbb636..87296ee28 100644 --- a/js/plugins/ollama/src/index.ts +++ b/js/plugins/ollama/src/index.ts @@ -25,6 +25,7 @@ import { } from '@genkit-ai/ai/model'; import { genkitPlugin, Plugin } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; +import { defineOllamaEmbedder } from './embeddings'; type ApiType = 'chat' | 'generate'; @@ -37,8 +38,11 @@ type RequestHeaders = type ModelDefinition = { name: string; type?: ApiType }; +type EmbeddingModelDefinition = { name: string; dimensions: number }; + export interface OllamaPluginParams { models: ModelDefinition[]; + embeddingModels?: EmbeddingModelDefinition[]; /** * ollama server address. */ @@ -51,10 +55,19 @@ export const ollama: Plugin<[OllamaPluginParams]> = genkitPlugin( 'ollama', async (params: OllamaPluginParams) => { const serverAddress = params?.serverAddress; + return { models: params.models.map((model) => ollamaModel(model, serverAddress, params.requestHeaders) ), + embedders: params.embeddingModels?.map((model) => + defineOllamaEmbedder({ + name: `${ollama}/model.name`, + modelName: model.name, + dimensions: model.dimensions, + options: params, + }) + ), }; } ); diff --git a/js/plugins/ollama/tests/embeddings_live_test.ts b/js/plugins/ollama/tests/embeddings_live_test.ts index 70d144ab9..8cbab30ef 100644 --- a/js/plugins/ollama/tests/embeddings_live_test.ts +++ b/js/plugins/ollama/tests/embeddings_live_test.ts @@ -1,3 +1,19 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + import { embed } from '@genkit-ai/ai/embedder'; import assert from 'node:assert'; import { describe, it } from 'node:test'; @@ -26,12 +42,12 @@ describe('defineOllamaEmbedder - Live Tests', () => { }; it('should successfully return embeddings', async () => { - const embedder = defineOllamaEmbedder( - 'live-test-embedder', - 'nomic-embed-text', - 768, - options - ); + const embedder = defineOllamaEmbedder({ + name: 'live-test-embedder', + modelName: 'nomic-embed-text', + dimensions: 768, + options, + }); const result = await embed({ embedder, diff --git a/js/plugins/ollama/tests/embeddings_test.ts b/js/plugins/ollama/tests/embeddings_test.ts index ca9327976..10d2407e7 100644 --- a/js/plugins/ollama/tests/embeddings_test.ts +++ b/js/plugins/ollama/tests/embeddings_test.ts @@ -1,118 +1,135 @@ -// import { embed } from '@genkit-ai/ai/embedder'; -// import assert from 'node:assert'; -// import { describe, it } from 'node:test'; -// import { -// defineOllamaEmbedder, -// OllamaEmbeddingConfigSchema, -// } from '../src/embeddings.js'; // Adjust the import path as necessary -// import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary - -// // Mock fetch to simulate API responses -// global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { -// const url = typeof input === 'string' ? input : input.toString(); - -// if (url.includes('/api/embedding')) { -// if (options?.body && JSON.stringify(options.body).includes('fail')) { -// return { -// ok: false, -// statusText: 'Internal Server Error', -// json: async () => ({}), -// } as Response; -// } -// return { -// ok: true, -// json: async () => ({ -// embeddings: { -// values: [0.1, 0.2, 0.3], // Example embedding values -// }, -// }), -// } as Response; -// } - -// throw new Error('Unknown API endpoint'); -// }; - -// describe('defineOllamaEmbedder', () => { -// const options: OllamaPluginParams = { -// models: [{ name: 'test-model' }], -// serverAddress: 'http://localhost:3000', -// }; - -// it('should successfully return embeddings', async () => { -// const embedder = defineOllamaEmbedder( -// 'test-embedder', -// 'test-model', -// options -// ); - -// const result = await embed({ -// embedder, -// content: 'Hello, world!', -// }); -// assert.deepStrictEqual(result, [0.1, 0.2, 0.3]); -// }); - -// it('should handle API errors correctly', async () => { -// const embedder = defineOllamaEmbedder( -// 'test-embedder', -// 'test-model', -// options -// ); - -// await assert.rejects( -// async () => { -// await embed({ -// embedder, -// content: 'fail', -// }); -// }, -// (error) => { -// // Check if error is an instance of Error -// assert(error instanceof Error); - -// assert.strictEqual( -// error.message, -// 'Error fetching embedding from Ollama: Internal Server Error' -// ); -// return true; -// } -// ); -// }); - -// it('should validate the embedding configuration schema', async () => { -// const validConfig = { -// modelName: 'test-model', -// serverAddress: 'http://localhost:3000', -// }; - -// const invalidConfig = { -// modelName: 123, // Invalid type -// serverAddress: 'http://localhost:3000', -// }; - -// // Valid configuration should pass -// assert.doesNotThrow(() => { -// OllamaEmbeddingConfigSchema.parse(validConfig); -// }); - -// // Invalid configuration should throw -// assert.throws(() => { -// OllamaEmbeddingConfigSchema.parse(invalidConfig); -// }); -// }); - -// it('should throw an error if the fetch response is not ok', async () => { -// const embedder = defineOllamaEmbedder( -// 'test-embedder', -// 'test-model', -// options -// ); - -// await assert.rejects(async () => { -// await embed({ -// embedder, -// content: 'fail', -// }); -// }, new Error('Error fetching embedding from Ollama: Internal Server Error')); -// }); -// }); +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { embed } from '@genkit-ai/ai/embedder'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { + OllamaEmbeddingConfigSchema, + defineOllamaEmbedder, +} from '../src/embeddings.js'; // Adjust the import path as necessary +import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary + +// Mock fetch to simulate API responses +global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { + const url = typeof input === 'string' ? input : input.toString(); + + if (url.includes('/api/embedding')) { + if (options?.body && JSON.stringify(options.body).includes('fail')) { + return { + ok: false, + statusText: 'Internal Server Error', + json: async () => ({}), + } as Response; + } + return { + ok: true, + json: async () => ({ + embedding: [0.1, 0.2, 0.3], // Example embedding values + }), + } as Response; + } + + throw new Error('Unknown API endpoint'); +}; + +describe('defineOllamaEmbedder', () => { + const options: OllamaPluginParams = { + models: [{ name: 'test-model' }], + serverAddress: 'http://localhost:3000', + }; + + it('should successfully return embeddings', async () => { + const embedder = defineOllamaEmbedder({ + name: 'test-embedder', + modelName: 'test-model', + dimensions: 123, + options, + }); + + const result = await embed({ + embedder, + content: 'Hello, world!', + }); + assert.deepStrictEqual(result, [0.1, 0.2, 0.3]); + }); + + it('should handle API errors correctly', async () => { + const embedder = defineOllamaEmbedder({ + name: 'test-embedder', + modelName: 'test-model', + dimensions: 123, + options, + }); + + await assert.rejects( + async () => { + await embed({ + embedder, + content: 'fail', + }); + }, + (error) => { + // Check if error is an instance of Error + assert(error instanceof Error); + + assert.strictEqual( + error.message, + 'Error fetching embedding from Ollama: Internal Server Error' + ); + return true; + } + ); + }); + + it('should validate the embedding configuration schema', async () => { + const validConfig = { + modelName: 'test-model', + serverAddress: 'http://localhost:3000', + }; + + const invalidConfig = { + modelName: 123, // Invalid type + serverAddress: 'http://localhost:3000', + }; + + // Valid configuration should pass + assert.doesNotThrow(() => { + OllamaEmbeddingConfigSchema.parse(validConfig); + }); + + // Invalid configuration should throw + assert.throws(() => { + OllamaEmbeddingConfigSchema.parse(invalidConfig); + }); + }); + + it('should throw an error if the fetch response is not ok', async () => { + const embedder = defineOllamaEmbedder({ + name: 'test-embedder', + modelName: 'test-model', + dimensions: 123, + options, + }); + + await assert.rejects(async () => { + await embed({ + embedder, + content: 'fail', + }); + }, new Error('Error fetching embedding from Ollama: Internal Server Error')); + }); +});