-
Notifications
You must be signed in to change notification settings - Fork 101
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(js/plugins/ollama): add initial embedding support
- Loading branch information
Showing
5 changed files
with
259 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import { defineEmbedder, EmbedderReference } from '@genkit-ai/ai/embedder'; | ||
import { logger } from '@genkit-ai/core/logging'; | ||
import z from 'zod'; | ||
import { OllamaPluginParams } from './index.js'; | ||
|
||
/**
 * Zod schema for the configuration object of the Ollama embedder.
 *
 * Both fields are required strings. The schema is handed to `defineEmbedder`
 * as `configSchema` by `defineOllamaEmbedder`.
 *
 * NOTE(review): the embedder callback in this file names its config argument
 * `_config` and does not read it; the model name and server address actually
 * used come from the arguments of `defineOllamaEmbedder`. Confirm whether
 * these fields are meant to override those values.
 */
export const OllamaEmbeddingConfigSchema = z.object({ | ||
modelName: z.string(), | ||
serverAddress: z.string(), | ||
}); | ||
/** Static type inferred from {@link OllamaEmbeddingConfigSchema}. */
export type OllamaEmbeddingConfig = z.infer<typeof OllamaEmbeddingConfigSchema>; | ||
|
||
/**
 * Shape holding a single piece of text under `content`.
 *
 * NOTE(review): nothing in the visible source references this interface; the
 * request payload that is actually sent uses the keys `model` and `prompt`.
 * Confirm whether it is still needed.
 */
interface OllamaEmbeddingInstance { | ||
content: string; | ||
} | ||
|
||
/**
 * Shape that the JSON body of an embeddings response is cast to: a single
 * vector of numbers under `embedding`.
 */
interface OllamaEmbeddingPrediction { | ||
embedding: number[]; | ||
} | ||
|
||
/**
 * Requests the embedding of one piece of text from an Ollama server.
 *
 * @param serverAddress - Base URL of the Ollama server, without a trailing slash.
 * @param modelName - Value sent as `model` in the request payload.
 * @param text - Value sent as `prompt` in the request payload.
 * @returns An object carrying only the `embedding` vector of the response.
 * @throws Error when the request cannot be sent, the server answers with a
 *   non-OK status, or the response body has no `embedding` array.
 */
async function requestOllamaEmbedding(
  serverAddress: string,
  modelName: string,
  text: string
): Promise<OllamaEmbeddingPrediction> {
  const requestPayload = {
    model: modelName,
    prompt: text,
  };

  let res: Response;
  try {
    res = await fetch(`${serverAddress}/api/embeddings`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(requestPayload),
    });
  } catch (e) {
    logger.error('Failed to fetch Ollama embedding');
    throw new Error(`Error fetching embedding from Ollama: ${e}`);
  }

  if (!res.ok) {
    logger.error('Failed to fetch Ollama embedding');
    throw new Error(`Error fetching embedding from Ollama: ${res.statusText}`);
  }

  // The body is external input: check it instead of trusting a type cast, so
  // a malformed response fails here rather than further down the pipeline.
  const responseData: unknown = await res.json();
  const embedding = (responseData as Partial<OllamaEmbeddingPrediction> | null)
    ?.embedding;
  if (!Array.isArray(embedding)) {
    logger.error('Failed to fetch Ollama embedding');
    throw new Error(
      'Error fetching embedding from Ollama: response did not contain an embedding'
    );
  }

  // Return only the vector so unrelated response fields do not leak through.
  return { embedding };
}

/**
 * Registers an embedder that obtains its embeddings from an Ollama server.
 *
 * When the embedder is invoked, one POST request per input document is sent
 * to `<serverAddress>/api/embeddings`; the requests run concurrently and the
 * results are returned in input order.
 *
 * @param name - Name under which the embedder is defined.
 * @param modelName - Ollama model used for every request.
 * @param dimensions - Vector size reported in the embedder's `info`.
 * @param options - Plugin parameters; only `serverAddress` is read here.
 * @returns The reference produced by `defineEmbedder`.
 * @throws Error (when the embedder is invoked, not when it is defined) if any
 *   request fails, returns a non-OK status, or yields no `embedding` array.
 */
export function defineOllamaEmbedder(
  name: string,
  modelName: string,
  dimensions: number,
  options: OllamaPluginParams
): EmbedderReference<typeof OllamaEmbeddingConfigSchema> {
  return defineEmbedder(
    {
      name,
      configSchema: OllamaEmbeddingConfigSchema, // Use the Zod schema directly here
      info: {
        label: 'Embedding using Ollama',
        dimensions,
        supports: {
          // TODO: do any ollama models support other modalities?
          input: ['text'],
        },
      },
    },
    async (input, _config) => {
      // Strip trailing slashes so the URL never becomes `...//api/embeddings`.
      const serverAddress = options.serverAddress.replace(/\/+$/, '');

      const embeddings = await Promise.all(
        input.map((doc) =>
          requestOllamaEmbedding(serverAddress, modelName, doc.text())
        )
      );

      return { embeddings };
    }
  );
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import { embed } from '@genkit-ai/ai/embedder'; | ||
import assert from 'node:assert'; | ||
import { describe, it } from 'node:test'; | ||
|
||
import { defineOllamaEmbedder } from '../src/embeddings.js'; // Adjust the import path as necessary | ||
import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary | ||
|
||
/**
 * Returns the value of the first `--<flag>=<value>` argument in `args`.
 *
 * The value is everything after the first `=`, so values that themselves
 * contain `=` (for example a URL with a query string) are kept intact.
 * A missing flag or an empty value yields `fallback`.
 */
function readFlagValue(
  args: readonly string[],
  flag: string,
  fallback: string
): string {
  const prefix = `--${flag}=`;
  const value = args.find((arg) => arg.startsWith(prefix))?.slice(prefix.length);
  // `||` on purpose: an empty value (`--flag=`) falls back to the default too.
  return value || fallback;
}

/**
 * Parses the command-line arguments of the test run.
 *
 * Recognised flags:
 * - `--server-address=<url>` (default `http://localhost:11434`)
 * - `--model-name=<name>` (default `nomic-embed-text`)
 *
 * @returns The server address and model name to run the live tests against.
 */
function parseArgs(): { serverAddress: string; modelName: string } {
  const args = process.argv.slice(2);
  return {
    serverAddress: readFlagValue(args, 'server-address', 'http://localhost:11434'),
    modelName: readFlagValue(args, 'model-name', 'nomic-embed-text'),
  };
}
|
||
// Server and model under test, taken from the command line (see parseArgs).
const { serverAddress, modelName } = parseArgs();

/**
 * Live tests: they send real requests to the Ollama server at
 * `serverAddress`, so that server must be running and must have `modelName`
 * available.
 */
describe('defineOllamaEmbedder - Live Tests', () => {
  // Model used when no `--model-name` flag is given, and its vector size.
  const defaultModelName = 'nomic-embed-text';
  const defaultModelDimensions = 768;

  const options: OllamaPluginParams = {
    models: [{ name: modelName }],
    serverAddress,
  };

  it('should successfully return embeddings', async () => {
    // Use the model chosen on the command line so `--model-name` actually
    // controls which model is queried, consistent with `options.models`.
    const embedder = defineOllamaEmbedder(
      'live-test-embedder',
      modelName,
      defaultModelDimensions,
      options
    );

    const result = await embed({
      embedder,
      content: 'Hello, world!',
    });

    assert.ok(result.length > 0);
    // The exact vector size is only known for the default model.
    if (modelName === defaultModelName) {
      assert.strictEqual(result.length, defaultModelDimensions);
    }
  });
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
// import { embed } from '@genkit-ai/ai/embedder'; | ||
// import assert from 'node:assert'; | ||
// import { describe, it } from 'node:test'; | ||
// import { | ||
// defineOllamaEmbedder, | ||
// OllamaEmbeddingConfigSchema, | ||
// } from '../src/embeddings.js'; // Adjust the import path as necessary | ||
// import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary | ||
|
||
// // Mock fetch to simulate API responses | ||
// global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { | ||
// const url = typeof input === 'string' ? input : input.toString(); | ||
|
||
// if (url.includes('/api/embedding')) { | ||
// if (options?.body && JSON.stringify(options.body).includes('fail')) { | ||
// return { | ||
// ok: false, | ||
// statusText: 'Internal Server Error', | ||
// json: async () => ({}), | ||
// } as Response; | ||
// } | ||
// return { | ||
// ok: true, | ||
// json: async () => ({ | ||
// embeddings: { | ||
// values: [0.1, 0.2, 0.3], // Example embedding values | ||
// }, | ||
// }), | ||
// } as Response; | ||
// } | ||
|
||
// throw new Error('Unknown API endpoint'); | ||
// }; | ||
|
||
// describe('defineOllamaEmbedder', () => { | ||
// const options: OllamaPluginParams = { | ||
// models: [{ name: 'test-model' }], | ||
// serverAddress: 'http://localhost:3000', | ||
// }; | ||
|
||
// it('should successfully return embeddings', async () => { | ||
// const embedder = defineOllamaEmbedder( | ||
// 'test-embedder', | ||
// 'test-model', | ||
// options | ||
// ); | ||
|
||
// const result = await embed({ | ||
// embedder, | ||
// content: 'Hello, world!', | ||
// }); | ||
// assert.deepStrictEqual(result, [0.1, 0.2, 0.3]); | ||
// }); | ||
|
||
// it('should handle API errors correctly', async () => { | ||
// const embedder = defineOllamaEmbedder( | ||
// 'test-embedder', | ||
// 'test-model', | ||
// options | ||
// ); | ||
|
||
// await assert.rejects( | ||
// async () => { | ||
// await embed({ | ||
// embedder, | ||
// content: 'fail', | ||
// }); | ||
// }, | ||
// (error) => { | ||
// // Check if error is an instance of Error | ||
// assert(error instanceof Error); | ||
|
||
// assert.strictEqual( | ||
// error.message, | ||
// 'Error fetching embedding from Ollama: Internal Server Error' | ||
// ); | ||
// return true; | ||
// } | ||
// ); | ||
// }); | ||
|
||
// it('should validate the embedding configuration schema', async () => { | ||
// const validConfig = { | ||
// modelName: 'test-model', | ||
// serverAddress: 'http://localhost:3000', | ||
// }; | ||
|
||
// const invalidConfig = { | ||
// modelName: 123, // Invalid type | ||
// serverAddress: 'http://localhost:3000', | ||
// }; | ||
|
||
// // Valid configuration should pass | ||
// assert.doesNotThrow(() => { | ||
// OllamaEmbeddingConfigSchema.parse(validConfig); | ||
// }); | ||
|
||
// // Invalid configuration should throw | ||
// assert.throws(() => { | ||
// OllamaEmbeddingConfigSchema.parse(invalidConfig); | ||
// }); | ||
// }); | ||
|
||
// it('should throw an error if the fetch response is not ok', async () => { | ||
// const embedder = defineOllamaEmbedder( | ||
// 'test-embedder', | ||
// 'test-model', | ||
// options | ||
// ); | ||
|
||
// await assert.rejects(async () => { | ||
// await embed({ | ||
// embedder, | ||
// content: 'fail', | ||
// }); | ||
// }, new Error('Error fetching embedding from Ollama: Internal Server Error')); | ||
// }); | ||
// }); |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.