diff --git a/js/plugins/ollama/package.json b/js/plugins/ollama/package.json
index cc01f6c7c..2045aa7b5 100644
--- a/js/plugins/ollama/package.json
+++ b/js/plugins/ollama/package.json
@@ -17,7 +17,9 @@
     "compile": "tsup-node",
     "build:clean": "rm -rf ./lib",
     "build": "npm-run-all build:clean check compile",
-    "build:watch": "tsup-node --watch"
+    "build:watch": "tsup-node --watch",
+    "test": "find tests -name '*_test.ts' ! -name '*_live_test.ts' -exec node --import tsx --test {} +",
+    "test:live": "node --import tsx --test tests/*_test.ts"
   },
   "repository": {
     "type": "git",
@@ -26,6 +28,9 @@
   },
   "author": "genkit",
   "license": "Apache-2.0",
+  "dependencies": {
+    "zod": "^3.22.4"
+  },
   "peerDependencies": {
     "@genkit-ai/ai": "workspace:*",
     "@genkit-ai/core": "workspace:*"
diff --git a/js/plugins/ollama/src/embeddings.ts b/js/plugins/ollama/src/embeddings.ts
new file mode 100644
index 000000000..9907257f2
--- /dev/null
+++ b/js/plugins/ollama/src/embeddings.ts
@@ -0,0 +1,92 @@
+import { defineEmbedder, EmbedderReference } from '@genkit-ai/ai/embedder';
+import { logger } from '@genkit-ai/core/logging';
+import z from 'zod';
+import { OllamaPluginParams } from './index.js';
+
+/** Zod schema for the configuration accepted by Ollama embedders. */
+export const OllamaEmbeddingConfigSchema = z.object({
+  modelName: z.string(),
+  serverAddress: z.string(),
+});
+export type OllamaEmbeddingConfig = z.infer<typeof OllamaEmbeddingConfigSchema>;
+
+/** Shape this module expects from the JSON reply of `/api/embeddings`. */
+interface OllamaEmbeddingPrediction {
+  embedding: number[];
+}
+
+/**
+ * Defines a Genkit embedder that gets its vectors from an Ollama server.
+ *
+ * Each input document is sent in its own POST request to
+ * `${options.serverAddress}/api/embeddings`, and all requests run
+ * concurrently. The embedder rejects with an `Error` when a request cannot be
+ * sent or the server answers with a non-OK status.
+ *
+ * @param name - Name the embedder is defined under.
+ * @param modelName - Value sent as `model` in every request.
+ * @param dimensions - Embedding size advertised in the embedder's `info`.
+ * @param options - Plugin parameters; only `serverAddress` is read here.
+ * @returns The embedder produced by `defineEmbedder`.
+ */
+export function defineOllamaEmbedder(
+  name: string,
+  modelName: string,
+  dimensions: number,
+  options: OllamaPluginParams
+): EmbedderReference<typeof OllamaEmbeddingConfigSchema> {
+  return defineEmbedder(
+    {
+      name,
+      configSchema: OllamaEmbeddingConfigSchema,
+      info: {
+        label: 'Embedding using Ollama',
+        dimensions,
+        supports: {
+          // TODO: do any ollama models support other modalities?
+          input: ['text'],
+        },
+      },
+    },
+    async (input, _config) => {
+      const serverAddress = options.serverAddress;
+
+      const responses = await Promise.all(
+        input.map(async (i) => {
+          const requestPayload = {
+            model: modelName,
+            prompt: i.text(),
+          };
+          let res: Response;
+          try {
+            res = await fetch(`${serverAddress}/api/embeddings`, {
+              method: 'POST',
+              headers: {
+                'Content-Type': 'application/json',
+              },
+              body: JSON.stringify(requestPayload),
+            });
+          } catch (e) {
+            logger.error('Failed to fetch Ollama embedding');
+            throw new Error(`Error fetching embedding from Ollama: ${e}`);
+          }
+
+          if (!res.ok) {
+            logger.error('Failed to fetch Ollama embedding');
+            throw new Error(
+              `Error fetching embedding from Ollama: ${res.statusText}`
+            );
+          }
+
+          // The reply is cast, not validated.
+          const responseData = (await res.json()) as OllamaEmbeddingPrediction;
+          return responseData;
+        })
+      );
+
+      return {
+        embeddings: responses,
+      };
+    }
+  );
+}
diff --git a/js/plugins/ollama/tests/embeddings_live_test.ts b/js/plugins/ollama/tests/embeddings_live_test.ts
new file mode 100644
index 000000000..70d144ab9
--- /dev/null
+++ b/js/plugins/ollama/tests/embeddings_live_test.ts
@@ -0,0 +1,43 @@
+import { embed } from '@genkit-ai/ai/embedder';
+import assert from 'node:assert';
+import { describe, it } from 'node:test';
+
+import { defineOllamaEmbedder } from '../src/embeddings.js'; // Adjust the import path as necessary
+import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary
+
+// Utility function to parse command-line arguments
+function parseArgs() {
+  const args = process.argv.slice(2);
+  const serverAddress =
+    args.find((arg) => arg.startsWith('--server-address='))?.split('=')[1] ||
+    'http://localhost:11434';
+  const modelName =
+    args.find((arg) => arg.startsWith('--model-name='))?.split('=')[1] ||
+    'nomic-embed-text';
+  return { serverAddress, modelName };
+}
+
+const { serverAddress, modelName } = parseArgs();
+
+describe('defineOllamaEmbedder - Live Tests', () => {
+  const options: OllamaPluginParams = {
+    models: [{ name: modelName }],
+    serverAddress,
+  };
+
+  it('should successfully return embeddings', async () => {
+    const embedder = defineOllamaEmbedder(
+      'live-test-embedder',
+      modelName, // honour --model-name instead of a hard-coded model
+      768, // size asserted below; adjust when testing another model
+      options
+    );
+
+    const result = await embed({
+      embedder,
+      content: 'Hello, world!',
+    });
+
+    assert.strictEqual(result.length, 768);
+  });
+});
diff --git a/js/plugins/ollama/tests/embeddings_test.ts b/js/plugins/ollama/tests/embeddings_test.ts
new file mode 100644
index 000000000..ca9327976
--- /dev/null
+++ b/js/plugins/ollama/tests/embeddings_test.ts
@@ -0,0 +1,102 @@
+import { embed } from '@genkit-ai/ai/embedder';
+import assert from 'node:assert';
+import { after, before, describe, it } from 'node:test';
+import {
+  defineOllamaEmbedder,
+  OllamaEmbeddingConfigSchema,
+} from '../src/embeddings.js';
+import { OllamaPluginParams } from '../src/index.js';
+
+// Vector the mocked server returns for every successful request.
+const EMBEDDING = [0.1, 0.2, 0.3];
+
+// Stand-in for fetch: a request whose body contains "fail" gets a non-OK
+// reply; any other request to the embeddings endpoint gets EMBEDDING.
+const mockFetch = async (
+  input: RequestInfo | URL,
+  init?: RequestInit
+): Promise<Response> => {
+  const url = typeof input === 'string' ? input : input.toString();
+  if (!url.includes('/api/embeddings')) {
+    throw new Error('Unknown API endpoint');
+  }
+  if (String(init?.body ?? '').includes('fail')) {
+    return {
+      ok: false,
+      statusText: 'Internal Server Error',
+      json: async () => ({}),
+    } as Response;
+  }
+  return {
+    ok: true,
+    json: async () => ({ embedding: EMBEDDING }),
+  } as Response;
+};
+
+describe('defineOllamaEmbedder', () => {
+  const options: OllamaPluginParams = {
+    models: [{ name: 'test-model' }],
+    serverAddress: 'http://localhost:3000',
+  };
+  const embedder = defineOllamaEmbedder(
+    'test-embedder',
+    'test-model',
+    EMBEDDING.length,
+    options
+  );
+  const realFetch = global.fetch;
+
+  before(() => {
+    global.fetch = mockFetch;
+  });
+
+  after(() => {
+    // Put the real fetch back for anything else running in this process.
+    global.fetch = realFetch;
+  });
+
+  it('should successfully return embeddings', async () => {
+    const result = await embed({
+      embedder,
+      content: 'Hello, world!',
+    });
+    assert.deepStrictEqual(result, EMBEDDING);
+  });
+
+  it('should handle API errors correctly', async () => {
+    await assert.rejects(
+      async () => {
+        await embed({
+          embedder,
+          content: 'fail',
+        });
+      },
+      (error) => {
+        assert(error instanceof Error);
+        assert.strictEqual(
+          error.message,
+          'Error fetching embedding from Ollama: Internal Server Error'
+        );
+        return true;
+      }
+    );
+  });
+
+  it('should validate the embedding configuration schema', () => {
+    const validConfig = {
+      modelName: 'test-model',
+      serverAddress: 'http://localhost:3000',
+    };
+    const invalidConfig = {
+      modelName: 123, // Invalid type
+      serverAddress: 'http://localhost:3000',
+    };
+
+    assert.doesNotThrow(() => {
+      OllamaEmbeddingConfigSchema.parse(validConfig);
+    });
+    assert.throws(() => {
+      OllamaEmbeddingConfigSchema.parse(invalidConfig);
+    });
+  });
+});
diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml
index 4a68b3585..c7bfdcf8e 100644
--- a/js/pnpm-lock.yaml
+++ b/js/pnpm-lock.yaml
@@ -506,7 +506,7 @@ importers:
         version: link:../../flow
       '@langchain/community':
         specifier: ^0.0.53
-        version: 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2)
+        version: 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2)
       '@langchain/core':
         specifier: ^0.1.61
         version: 0.1.61
@@ -515,7 +515,7 @@ importers:
         version: 1.9.0
       langchain:
         specifier: ^0.1.36
-        version: 0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1)
+        version: 
0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1) zod: specifier: ^3.22.4 version: 3.22.4 @@ -544,6 +544,9 @@ importers: '@genkit-ai/core': specifier: workspace:* version: link:../../core + zod: + specifier: ^3.22.4 + version: 3.22.4 devDependencies: '@types/node': specifier: ^20.11.16 @@ -1126,7 +1129,7 @@ importers: version: link:../../plugins/vertexai '@langchain/community': specifier: ^0.0.53 - version: 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) + version: 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) '@langchain/core': specifier: ^0.1.61 version: 0.1.61 @@ -1144,7 +1147,7 @@ importers: version: link:../../plugins/ollama langchain: specifier: ^0.1.36 - version: 0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1) + version: 
0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1) pdf-parse: specifier: ^1.1.1 version: 1.1.1 @@ -5627,7 +5630,7 @@ snapshots: '@js-sdsl/ordered-map@4.4.2': {} - '@langchain/community@0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2)': + '@langchain/community@0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2)': dependencies: '@langchain/core': 0.1.61 '@langchain/openai': 0.0.28(encoding@0.1.13) @@ -7987,10 +7990,10 @@ snapshots: kuler@2.0.0: {} - langchain@0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1): + langchain@0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1): dependencies: '@anthropic-ai/sdk': 0.9.1(encoding@0.1.13) - '@langchain/community': 
0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) + '@langchain/community': 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) '@langchain/core': 0.1.61 '@langchain/openai': 0.0.28(encoding@0.1.13) '@langchain/textsplitters': 0.0.0