Skip to content

Commit

Permalink
feat(js/plugins/ollama): add initial embedding support
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac committed Aug 30, 2024
1 parent 795cec2 commit 8055cb4
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 8 deletions.
7 changes: 6 additions & 1 deletion js/plugins/ollama/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
"compile": "tsup-node",
"build:clean": "rm -rf ./lib",
"build": "npm-run-all build:clean check compile",
"build:watch": "tsup-node --watch"
"build:watch": "tsup-node --watch",
"test": "find tests -name '*_test.ts' ! -name '*_live_test.ts' -exec node --import tsx --test {} +",
"test:live": "node --import tsx --test tests/*_test.ts"
},
"repository": {
"type": "git",
Expand All @@ -26,6 +28,9 @@
},
"author": "genkit",
"license": "Apache-2.0",
"dependencies": {
"zod": "^3.22.4"
},
"peerDependencies": {
"@genkit-ai/ai": "workspace:*",
"@genkit-ai/core": "workspace:*"
Expand Down
82 changes: 82 additions & 0 deletions js/plugins/ollama/src/embeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import { defineEmbedder, EmbedderReference } from '@genkit-ai/ai/embedder';
import { logger } from '@genkit-ai/core/logging';
import z from 'zod';
import { OllamaPluginParams } from './index.js';

// Define the schema for Ollama embedding configuration
export const OllamaEmbeddingConfigSchema = z.object({
modelName: z.string(),
serverAddress: z.string(),
});
export type OllamaEmbeddingConfig = z.infer<typeof OllamaEmbeddingConfigSchema>;

// Define the structure of the request and response for embedding
interface OllamaEmbeddingInstance {
content: string;
}

interface OllamaEmbeddingPrediction {
embedding: number[];
}

export function defineOllamaEmbedder(
name: string,
modelName: string,
dimensions: number,
options: OllamaPluginParams
): EmbedderReference<typeof OllamaEmbeddingConfigSchema> {
return defineEmbedder(
{
name,
configSchema: OllamaEmbeddingConfigSchema, // Use the Zod schema directly here
info: {
label: 'Embedding using Ollama',
dimensions,
supports: {
// TODO: do any ollama models support other modalities?
input: ['text'],
},
},
},
async (input, _config) => {
const serverAddress = options.serverAddress;

const responses = await Promise.all(
input.map(async (i) => {
const requestPayload = {
model: modelName,
prompt: i.text(),
};
let res: Response;
try {
console.log('MODEL NAME: ', modelName);
res = await fetch(`${serverAddress}/api/embeddings`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(requestPayload),
});
} catch (e) {
logger.error('Failed to fetch Ollama embedding');
throw new Error(`Error fetching embedding from Ollama: ${e}`);
}

if (!res.ok) {
logger.error('Failed to fetch Ollama embedding');
throw new Error(
`Error fetching embedding from Ollama: ${res.statusText}`
);
}

const responseData = (await res.json()) as OllamaEmbeddingPrediction;
return responseData;
})
);

return {
embeddings: responses,
};
}
);
}
43 changes: 43 additions & 0 deletions js/plugins/ollama/tests/embeddings_live_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { embed } from '@genkit-ai/ai/embedder';
import assert from 'node:assert';
import { describe, it } from 'node:test';

import { defineOllamaEmbedder } from '../src/embeddings.js'; // Adjust the import path as necessary
import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary

// Utility function to parse command-line arguments
function parseArgs() {
const args = process.argv.slice(2);
const serverAddress =
args.find((arg) => arg.startsWith('--server-address='))?.split('=')[1] ||
'http://localhost:11434';
const modelName =
args.find((arg) => arg.startsWith('--model-name='))?.split('=')[1] ||
'nomic-embed-text';
return { serverAddress, modelName };
}

const { serverAddress, modelName } = parseArgs();

describe('defineOllamaEmbedder - Live Tests', () => {
const options: OllamaPluginParams = {
models: [{ name: modelName }],
serverAddress,
};

it('should successfully return embeddings', async () => {
const embedder = defineOllamaEmbedder(
'live-test-embedder',
'nomic-embed-text',
768,
options
);

const result = await embed({
embedder,
content: 'Hello, world!',
});

assert.strictEqual(result.length, 768);
});
});
118 changes: 118 additions & 0 deletions js/plugins/ollama/tests/embeddings_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// import { embed } from '@genkit-ai/ai/embedder';
// import assert from 'node:assert';
// import { describe, it } from 'node:test';
// import {
// defineOllamaEmbedder,
// OllamaEmbeddingConfigSchema,
// } from '../src/embeddings.js'; // Adjust the import path as necessary
// import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary

// // Mock fetch to simulate API responses
// global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => {
// const url = typeof input === 'string' ? input : input.toString();

// if (url.includes('/api/embedding')) {
// if (options?.body && JSON.stringify(options.body).includes('fail')) {
// return {
// ok: false,
// statusText: 'Internal Server Error',
// json: async () => ({}),
// } as Response;
// }
// return {
// ok: true,
// json: async () => ({
// embeddings: {
// values: [0.1, 0.2, 0.3], // Example embedding values
// },
// }),
// } as Response;
// }

// throw new Error('Unknown API endpoint');
// };

// describe('defineOllamaEmbedder', () => {
// const options: OllamaPluginParams = {
// models: [{ name: 'test-model' }],
// serverAddress: 'http://localhost:3000',
// };

// it('should successfully return embeddings', async () => {
// const embedder = defineOllamaEmbedder(
// 'test-embedder',
// 'test-model',
// options
// );

// const result = await embed({
// embedder,
// content: 'Hello, world!',
// });
// assert.deepStrictEqual(result, [0.1, 0.2, 0.3]);
// });

// it('should handle API errors correctly', async () => {
// const embedder = defineOllamaEmbedder(
// 'test-embedder',
// 'test-model',
// options
// );

// await assert.rejects(
// async () => {
// await embed({
// embedder,
// content: 'fail',
// });
// },
// (error) => {
// // Check if error is an instance of Error
// assert(error instanceof Error);

// assert.strictEqual(
// error.message,
// 'Error fetching embedding from Ollama: Internal Server Error'
// );
// return true;
// }
// );
// });

// it('should validate the embedding configuration schema', async () => {
// const validConfig = {
// modelName: 'test-model',
// serverAddress: 'http://localhost:3000',
// };

// const invalidConfig = {
// modelName: 123, // Invalid type
// serverAddress: 'http://localhost:3000',
// };

// // Valid configuration should pass
// assert.doesNotThrow(() => {
// OllamaEmbeddingConfigSchema.parse(validConfig);
// });

// // Invalid configuration should throw
// assert.throws(() => {
// OllamaEmbeddingConfigSchema.parse(invalidConfig);
// });
// });

// it('should throw an error if the fetch response is not ok', async () => {
// const embedder = defineOllamaEmbedder(
// 'test-embedder',
// 'test-model',
// options
// );

// await assert.rejects(async () => {
// await embed({
// embedder,
// content: 'fail',
// });
// }, new Error('Error fetching embedding from Ollama: Internal Server Error'));
// });
// });
17 changes: 10 additions & 7 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8055cb4

Please sign in to comment.