Skip to content

Commit

Permalink
feat(js/plugins/ollama): integrate ollama embeddings into plugin proper
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac committed Aug 30, 2024
1 parent 8055cb4 commit c196700
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 132 deletions.
40 changes: 32 additions & 8 deletions js/plugins/ollama/src/embeddings.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
import { defineEmbedder, EmbedderReference } from '@genkit-ai/ai/embedder';
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { defineEmbedder } from '@genkit-ai/ai/embedder';
import { logger } from '@genkit-ai/core/logging';
import z from 'zod';
import { OllamaPluginParams } from './index.js';
Expand All @@ -19,18 +35,26 @@ interface OllamaEmbeddingPrediction {
embedding: number[];
}

export function defineOllamaEmbedder(
name: string,
modelName: string,
dimensions: number,
options: OllamaPluginParams
): EmbedderReference<typeof OllamaEmbeddingConfigSchema> {
interface DefineOllamaEmbeddingParams {
name: string;
modelName: string;
dimensions: number;
options: OllamaPluginParams;
}

export function defineOllamaEmbedder({
name,
modelName,
dimensions,
options,
}: DefineOllamaEmbeddingParams) {
return defineEmbedder(
{
name,
configSchema: OllamaEmbeddingConfigSchema, // Use the Zod schema directly here
info: {
label: 'Embedding using Ollama',
// TODO: do we want users to be able to specify the label when they call this method directly?
label: 'Embedding using Ollama - ' + modelName,
dimensions,
supports: {
// TODO: do any ollama models support other modalities?
Expand Down
13 changes: 13 additions & 0 deletions js/plugins/ollama/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
} from '@genkit-ai/ai/model';
import { genkitPlugin, Plugin } from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import { defineOllamaEmbedder } from './embeddings';

type ApiType = 'chat' | 'generate';

Expand All @@ -37,8 +38,11 @@ type RequestHeaders =

type ModelDefinition = { name: string; type?: ApiType };

type EmbeddingModelDefinition = { name: string; dimensions: number };

export interface OllamaPluginParams {
models: ModelDefinition[];
embeddingModels?: EmbeddingModelDefinition[];
/**
* ollama server address.
*/
Expand All @@ -51,10 +55,19 @@ export const ollama: Plugin<[OllamaPluginParams]> = genkitPlugin(
'ollama',
async (params: OllamaPluginParams) => {
const serverAddress = params?.serverAddress;

return {
models: params.models.map((model) =>
ollamaModel(model, serverAddress, params.requestHeaders)
),
embedders: params.embeddingModels?.map((model) =>
defineOllamaEmbedder({
name: `${ollama}/model.name`,
modelName: model.name,
dimensions: model.dimensions,
options: params,
})
),
};
}
);
Expand Down
28 changes: 22 additions & 6 deletions js/plugins/ollama/tests/embeddings_live_test.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { embed } from '@genkit-ai/ai/embedder';
import assert from 'node:assert';
import { describe, it } from 'node:test';
Expand Down Expand Up @@ -26,12 +42,12 @@ describe('defineOllamaEmbedder - Live Tests', () => {
};

it('should successfully return embeddings', async () => {
const embedder = defineOllamaEmbedder(
'live-test-embedder',
'nomic-embed-text',
768,
options
);
const embedder = defineOllamaEmbedder({
name: 'live-test-embedder',
modelName: 'nomic-embed-text',
dimensions: 768,
options,
});

const result = await embed({
embedder,
Expand Down
253 changes: 135 additions & 118 deletions js/plugins/ollama/tests/embeddings_test.ts
Original file line number Diff line number Diff line change
@@ -1,118 +1,135 @@
// import { embed } from '@genkit-ai/ai/embedder';
// import assert from 'node:assert';
// import { describe, it } from 'node:test';
// import {
// defineOllamaEmbedder,
// OllamaEmbeddingConfigSchema,
// } from '../src/embeddings.js'; // Adjust the import path as necessary
// import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary

// // Mock fetch to simulate API responses
// global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => {
// const url = typeof input === 'string' ? input : input.toString();

// if (url.includes('/api/embedding')) {
// if (options?.body && JSON.stringify(options.body).includes('fail')) {
// return {
// ok: false,
// statusText: 'Internal Server Error',
// json: async () => ({}),
// } as Response;
// }
// return {
// ok: true,
// json: async () => ({
// embeddings: {
// values: [0.1, 0.2, 0.3], // Example embedding values
// },
// }),
// } as Response;
// }

// throw new Error('Unknown API endpoint');
// };

// describe('defineOllamaEmbedder', () => {
// const options: OllamaPluginParams = {
// models: [{ name: 'test-model' }],
// serverAddress: 'http://localhost:3000',
// };

// it('should successfully return embeddings', async () => {
// const embedder = defineOllamaEmbedder(
// 'test-embedder',
// 'test-model',
// options
// );

// const result = await embed({
// embedder,
// content: 'Hello, world!',
// });
// assert.deepStrictEqual(result, [0.1, 0.2, 0.3]);
// });

// it('should handle API errors correctly', async () => {
// const embedder = defineOllamaEmbedder(
// 'test-embedder',
// 'test-model',
// options
// );

// await assert.rejects(
// async () => {
// await embed({
// embedder,
// content: 'fail',
// });
// },
// (error) => {
// // Check if error is an instance of Error
// assert(error instanceof Error);

// assert.strictEqual(
// error.message,
// 'Error fetching embedding from Ollama: Internal Server Error'
// );
// return true;
// }
// );
// });

// it('should validate the embedding configuration schema', async () => {
// const validConfig = {
// modelName: 'test-model',
// serverAddress: 'http://localhost:3000',
// };

// const invalidConfig = {
// modelName: 123, // Invalid type
// serverAddress: 'http://localhost:3000',
// };

// // Valid configuration should pass
// assert.doesNotThrow(() => {
// OllamaEmbeddingConfigSchema.parse(validConfig);
// });

// // Invalid configuration should throw
// assert.throws(() => {
// OllamaEmbeddingConfigSchema.parse(invalidConfig);
// });
// });

// it('should throw an error if the fetch response is not ok', async () => {
// const embedder = defineOllamaEmbedder(
// 'test-embedder',
// 'test-model',
// options
// );

// await assert.rejects(async () => {
// await embed({
// embedder,
// content: 'fail',
// });
// }, new Error('Error fetching embedding from Ollama: Internal Server Error'));
// });
// });
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { embed } from '@genkit-ai/ai/embedder';
import assert from 'node:assert';
import { describe, it } from 'node:test';
import {
OllamaEmbeddingConfigSchema,
defineOllamaEmbedder,
} from '../src/embeddings.js'; // Adjust the import path as necessary
import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary

// Mock fetch to simulate API responses
global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => {
const url = typeof input === 'string' ? input : input.toString();

if (url.includes('/api/embedding')) {
if (options?.body && JSON.stringify(options.body).includes('fail')) {
return {
ok: false,
statusText: 'Internal Server Error',
json: async () => ({}),
} as Response;
}
return {
ok: true,
json: async () => ({
embedding: [0.1, 0.2, 0.3], // Example embedding values
}),
} as Response;
}

throw new Error('Unknown API endpoint');
};

describe('defineOllamaEmbedder', () => {
const options: OllamaPluginParams = {
models: [{ name: 'test-model' }],
serverAddress: 'http://localhost:3000',
};

it('should successfully return embeddings', async () => {
const embedder = defineOllamaEmbedder({
name: 'test-embedder',
modelName: 'test-model',
dimensions: 123,
options,
});

const result = await embed({
embedder,
content: 'Hello, world!',
});
assert.deepStrictEqual(result, [0.1, 0.2, 0.3]);
});

it('should handle API errors correctly', async () => {
const embedder = defineOllamaEmbedder({
name: 'test-embedder',
modelName: 'test-model',
dimensions: 123,
options,
});

await assert.rejects(
async () => {
await embed({
embedder,
content: 'fail',
});
},
(error) => {
// Check if error is an instance of Error
assert(error instanceof Error);

assert.strictEqual(
error.message,
'Error fetching embedding from Ollama: Internal Server Error'
);
return true;
}
);
});

it('should validate the embedding configuration schema', async () => {
const validConfig = {
modelName: 'test-model',
serverAddress: 'http://localhost:3000',
};

const invalidConfig = {
modelName: 123, // Invalid type
serverAddress: 'http://localhost:3000',
};

// Valid configuration should pass
assert.doesNotThrow(() => {
OllamaEmbeddingConfigSchema.parse(validConfig);
});

// Invalid configuration should throw
assert.throws(() => {
OllamaEmbeddingConfigSchema.parse(invalidConfig);
});
});

it('should throw an error if the fetch response is not ok', async () => {
const embedder = defineOllamaEmbedder({
name: 'test-embedder',
modelName: 'test-model',
dimensions: 123,
options,
});

await assert.rejects(async () => {
await embed({
embedder,
content: 'fail',
});
}, new Error('Error fetching embedding from Ollama: Internal Server Error'));
});
});

0 comments on commit c196700

Please sign in to comment.