From 057788f25d2424a9fc1867d2042ae06650428a98 Mon Sep 17 00:00:00 2001 From: Jacob Cable Date: Tue, 8 Oct 2024 09:22:56 +0100 Subject: [PATCH] feat(js/plugins/ollama): add ollama embeddings back in (after rollback) for main --- js/plugins/ollama/package.json | 3 + js/plugins/ollama/src/embeddings.ts | 96 ++++++++++++++ js/plugins/ollama/src/index.ts | 12 +- .../ollama/tests/embeddings_live_test.ts | 51 ++++++++ js/plugins/ollama/tests/embeddings_test.ts | 119 ++++++++++++++++++ js/pnpm-lock.yaml | 3 + 6 files changed, 283 insertions(+), 1 deletion(-) create mode 100644 js/plugins/ollama/src/embeddings.ts create mode 100644 js/plugins/ollama/tests/embeddings_live_test.ts create mode 100644 js/plugins/ollama/tests/embeddings_test.ts diff --git a/js/plugins/ollama/package.json b/js/plugins/ollama/package.json index 30f9190a23..29cb72c119 100644 --- a/js/plugins/ollama/package.json +++ b/js/plugins/ollama/package.json @@ -26,6 +26,9 @@ }, "author": "genkit", "license": "Apache-2.0", + "dependencies": { + "zod": "^3.22.4" + }, "peerDependencies": { "@genkit-ai/ai": "workspace:*", "@genkit-ai/core": "workspace:*" diff --git a/js/plugins/ollama/src/embeddings.ts b/js/plugins/ollama/src/embeddings.ts new file mode 100644 index 0000000000..0b991672fb --- /dev/null +++ b/js/plugins/ollama/src/embeddings.ts @@ -0,0 +1,96 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { defineEmbedder } from '@genkit-ai/ai/embedder'; +import { logger } from '@genkit-ai/core/logging'; +import z from 'zod'; +import { OllamaPluginParams } from './index.js'; +// Define the schema for Ollama embedding configuration +export const OllamaEmbeddingConfigSchema = z.object({ + modelName: z.string(), + serverAddress: z.string(), +}); +export type OllamaEmbeddingConfig = z.infer; +// Define the structure of the request and response for embedding +interface OllamaEmbeddingInstance { + content: string; +} +interface OllamaEmbeddingPrediction { + embedding: number[]; +} +interface DefineOllamaEmbeddingParams { + name: string; + modelName: string; + dimensions: number; + options: OllamaPluginParams; +} +export function defineOllamaEmbedder({ + name, + modelName, + dimensions, + options, +}: DefineOllamaEmbeddingParams) { + return defineEmbedder( + { + name, + configSchema: OllamaEmbeddingConfigSchema, // Use the Zod schema directly here + info: { + // TODO: do we want users to be able to specify the label when they call this method directly? + label: 'Ollama Embedding - ' + modelName, + dimensions, + supports: { + // TODO: do any ollama models support other modalities? + input: ['text'], + }, + }, + }, + async (input, _config) => { + const serverAddress = options.serverAddress; + const responses = await Promise.all( + input.map(async (i) => { + const requestPayload = { + model: modelName, + prompt: i.text(), + }; + let res: Response; + try { + console.log('MODEL NAME: ', modelName); + res = await fetch(`${serverAddress}/api/embeddings`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(requestPayload), + }); + } catch (e) { + logger.error('Failed to fetch Ollama embedding'); + throw new Error(`Error fetching embedding from Ollama: ${e}`); + } + if (!res.ok) { + logger.error('Failed to fetch Ollama embedding'); + throw new Error( + `Error fetching embedding from Ollama: ${res.statusText}` + ); + } + const responseData = (await res.json()) as OllamaEmbeddingPrediction; + return responseData; + }) + ); + return { + embeddings: responses, + }; + } + ); +} diff --git a/js/plugins/ollama/src/index.ts b/js/plugins/ollama/src/index.ts index 7a7bbb636c..8e5b7ceef6 100644 --- a/js/plugins/ollama/src/index.ts +++ b/js/plugins/ollama/src/index.ts @@ -25,6 +25,7 @@ import { } from '@genkit-ai/ai/model'; import { genkitPlugin, Plugin } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; +import { defineOllamaEmbedder } from './embeddings'; type ApiType = 'chat' | 'generate'; @@ -36,9 +37,10 @@ type RequestHeaders = ) => Promise | void>); type ModelDefinition = { name: string; type?: ApiType }; - +type EmbeddingModelDefinition = { name: string; dimensions: number }; export interface OllamaPluginParams { models: ModelDefinition[]; + embeddingModels?: EmbeddingModelDefinition[]; /** * ollama server address. */ @@ -55,6 +57,14 @@ export const ollama: Plugin<[OllamaPluginParams]> = genkitPlugin( models: params.models.map((model) => ollamaModel(model, serverAddress, params.requestHeaders) ), + embedders: params.embeddingModels?.map((model) => + defineOllamaEmbedder({ + name: `${ollama}/model.name`, + modelName: model.name, + dimensions: model.dimensions, + options: params, + }) + ), }; } ); diff --git a/js/plugins/ollama/tests/embeddings_live_test.ts b/js/plugins/ollama/tests/embeddings_live_test.ts new file mode 100644 index 0000000000..f89500df1c --- /dev/null +++ b/js/plugins/ollama/tests/embeddings_live_test.ts @@ -0,0 +1,51 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { embed } from '@genkit-ai/ai/embedder'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { defineOllamaEmbedder } from '../src/embeddings.js'; // Adjust the import path as necessary +import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary +// Utility function to parse command-line arguments +function parseArgs() { + const args = process.argv.slice(2); + const serverAddress = + args.find((arg) => arg.startsWith('--server-address='))?.split('=')[1] || + 'http://localhost:11434'; + const modelName = + args.find((arg) => arg.startsWith('--model-name='))?.split('=')[1] || + 'nomic-embed-text'; + return { serverAddress, modelName }; +} +const { serverAddress, modelName } = parseArgs(); +describe('defineOllamaEmbedder - Live Tests', () => { + const options: OllamaPluginParams = { + models: [{ name: modelName }], + serverAddress, + }; + it('should successfully return embeddings', async () => { + const embedder = defineOllamaEmbedder({ + name: 'live-test-embedder', + modelName: 'nomic-embed-text', + dimensions: 768, + options, + }); + const result = await embed({ + embedder, + content: 'Hello, world!', + }); + assert.strictEqual(result.length, 768); + }); +}); diff --git a/js/plugins/ollama/tests/embeddings_test.ts b/js/plugins/ollama/tests/embeddings_test.ts new file mode 100644 index 0000000000..61255028a4 --- /dev/null +++ b/js/plugins/ollama/tests/embeddings_test.ts @@ -0,0 +1,119 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { embed } from '@genkit-ai/ai/embedder'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { + OllamaEmbeddingConfigSchema, + defineOllamaEmbedder, +} from '../src/embeddings.js'; // Adjust the import path as necessary +import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary +// Mock fetch to simulate API responses +global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { + const url = typeof input === 'string' ? input : input.toString(); + if (url.includes('/api/embedding')) { + if (options?.body && JSON.stringify(options.body).includes('fail')) { + return { + ok: false, + statusText: 'Internal Server Error', + json: async () => ({}), + } as Response; + } + return { + ok: true, + json: async () => ({ + embedding: [0.1, 0.2, 0.3], // Example embedding values + }), + } as Response; + } + throw new Error('Unknown API endpoint'); +}; +describe('defineOllamaEmbedder', () => { + const options: OllamaPluginParams = { + models: [{ name: 'test-model' }], + serverAddress: 'http://localhost:3000', + }; + it('should successfully return embeddings', async () => { + const embedder = defineOllamaEmbedder({ + name: 'test-embedder', + modelName: 'test-model', + dimensions: 123, + options, + }); + const result = await embed({ + embedder, + content: 'Hello, world!', + }); + assert.deepStrictEqual(result, [0.1, 0.2, 0.3]); + }); + it('should handle API errors correctly', async () => { + const embedder = defineOllamaEmbedder({ + name: 'test-embedder', + modelName: 'test-model', + dimensions: 123, + options, + }); + await assert.rejects( + async () => { + await embed({ + embedder, + content: 'fail', + }); + }, + (error) => { + // Check if error is an instance of Error + assert(error instanceof Error); + assert.strictEqual( + error.message, + 'Error fetching embedding from Ollama: Internal Server Error' + ); + return true; + } + ); + }); + it('should validate the embedding configuration schema', async () => { + const validConfig = { + modelName: 'test-model', + serverAddress: 'http://localhost:3000', + }; + const invalidConfig = { + modelName: 123, // Invalid type + serverAddress: 'http://localhost:3000', + }; + // Valid configuration should pass + assert.doesNotThrow(() => { + OllamaEmbeddingConfigSchema.parse(validConfig); + }); + // Invalid configuration should throw + assert.throws(() => { + OllamaEmbeddingConfigSchema.parse(invalidConfig); + }); + }); + it('should throw an error if the fetch response is not ok', async () => { + const embedder = defineOllamaEmbedder({ + name: 'test-embedder', + modelName: 'test-model', + dimensions: 123, + options, + }); + await assert.rejects(async () => { + await embed({ + embedder, + content: 'fail', + }); + }, new Error('Error fetching embedding from Ollama: Internal Server Error')); + }); +}); diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index c6b897d9d0..b85032a1b6 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -550,6 +550,9 @@ importers: '@genkit-ai/core': specifier: workspace:* version: link:../../core + zod: + specifier: ^3.22.4 + version: 3.22.4 devDependencies: '@types/node': specifier: ^20.11.16