diff --git a/js/plugins/ollama/src/constants.ts b/js/plugins/ollama/src/constants.ts
new file mode 100644
index 0000000000..8efd0d9da8
--- /dev/null
+++ b/js/plugins/ollama/src/constants.ts
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { ModelInfo } from 'genkit/model';
+
+export const ANY_JSON_SCHEMA: Record<string, any> = {
+  $schema: 'http://json-schema.org/draft-07/schema#',
+};
+
+export const GENERIC_MODEL_INFO = {
+  supports: {
+    multiturn: true,
+    media: true,
+    tools: true,
+    toolChoice: true,
+    systemRole: true,
+    constrained: 'all',
+  },
+} as ModelInfo;
+
+export const DEFAULT_OLLAMA_SERVER_ADDRESS = 'http://localhost:11434';
diff --git a/js/plugins/ollama/src/embeddings.ts b/js/plugins/ollama/src/embeddings.ts
index 96a58b6164..5a9c7ae9bb 100644
--- a/js/plugins/ollama/src/embeddings.ts
+++ b/js/plugins/ollama/src/embeddings.ts
@@ -13,11 +13,14 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-import type { Document, EmbedderAction, Genkit } from 'genkit';
+import type { Document, EmbedderAction } from 'genkit';
+import { embedder } from 'genkit/plugin';
 import type { EmbedRequest, EmbedResponse } from 'ollama';
+import { DEFAULT_OLLAMA_SERVER_ADDRESS } from './constants.js';
 import type { DefineOllamaEmbeddingParams, RequestHeaders } from './types.js';
+import { OllamaEmbedderConfigSchema } from './types.js';
 
-async function toOllamaEmbedRequest(
+export async function toOllamaEmbedRequest(
   modelName: string,
   dimensions: number,
   documents: Document[],
@@ -59,13 +62,18 @@ async function toOllamaEmbedRequest(
   };
 }
 
-export function defineOllamaEmbedder(
-  ai: Genkit,
-  { name, modelName, dimensions, options }: DefineOllamaEmbeddingParams
-): EmbedderAction {
-  return ai.defineEmbedder(
+export function defineOllamaEmbedder({
+  name,
+  modelName,
+  dimensions,
+  options,
+}: DefineOllamaEmbeddingParams): EmbedderAction<
+  typeof OllamaEmbedderConfigSchema
+> {
+  return embedder(
     {
       name: `ollama/${name}`,
+      configSchema: OllamaEmbedderConfigSchema,
       info: {
         label: 'Ollama Embedding - ' + name,
         dimensions,
@@ -75,9 +83,11 @@ export function defineOllamaEmbedder(
         },
       },
     },
-    async (input, config) => {
-      const serverAddress = config?.serverAddress || options.serverAddress;
-
+    async ({ input, options: requestOptions }, config) => {
+      const serverAddress =
+        requestOptions?.serverAddress ||
+        options.serverAddress ||
+        DEFAULT_OLLAMA_SERVER_ADDRESS;
       const { url, requestPayload, headers } = await toOllamaEmbedRequest(
         modelName,
         dimensions,
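
Note: defineOllamaEmbedder no longer takes a Genkit instance, and the new OllamaEmbedderConfigSchema turns serverAddress into a per-request option, with DEFAULT_OLLAMA_SERVER_ADDRESS as the final fallback. A minimal usage sketch, not part of the diff; it assumes the published package name genkitx-ollama, a reachable local Ollama server, and a pulled nomic-embed-text model:

    import { genkit } from 'genkit';
    import { ollama } from 'genkitx-ollama';

    const ai = genkit({
      plugins: [
        ollama({
          serverAddress: 'http://localhost:11434',
          embedders: [{ name: 'nomic-embed-text', dimensions: 768 }],
        }),
      ],
    });

    // The plugin-level serverAddress is the default; a single call can still
    // target another host through the new per-request config schema.
    const embeddings = await ai.embed({
      embedder: 'ollama/nomic-embed-text',
      content: 'Hello, world!',
      options: { serverAddress: 'http://127.0.0.1:11434' },
    });
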
diff --git a/js/plugins/ollama/src/index.ts b/js/plugins/ollama/src/index.ts
index 9c823ab276..c9f64f1bbb 100644
--- a/js/plugins/ollama/src/index.ts
+++ b/js/plugins/ollama/src/index.ts
@@ -20,7 +20,6 @@ import {
   z,
   type ActionMetadata,
   type EmbedderReference,
-  type Genkit,
   type ModelReference,
   type ToolRequest,
   type ToolRequestPart,
@@ -35,11 +34,19 @@ import {
   type GenerateRequest,
   type GenerateResponseData,
   type MessageData,
-  type ModelInfo,
   type ToolDefinition,
 } from 'genkit/model';
-import { genkitPlugin, type GenkitPlugin } from 'genkit/plugin';
-import type { ActionType } from 'genkit/registry';
+import {
+  genkitPluginV2,
+  model,
+  type GenkitPluginV2,
+  type ResolvableAction,
+} from 'genkit/plugin';
+import {
+  ANY_JSON_SCHEMA,
+  DEFAULT_OLLAMA_SERVER_ADDRESS,
+  GENERIC_MODEL_INFO,
+} from './constants.js';
 import { defineOllamaEmbedder } from './embeddings.js';
 import type {
   ApiType,
@@ -51,12 +58,13 @@ import type {
   OllamaTool,
   OllamaToolCall,
   RequestHeaders,
+  ResolveActionOptions,
 } from './types.js';
 
 export type { OllamaPluginParams };
 
 export type OllamaPlugin = {
-  (params?: OllamaPluginParams): GenkitPlugin;
+  (params?: OllamaPluginParams): GenkitPluginV2;
 
   model(
     name: string,
@@ -65,59 +73,50 @@ export type OllamaPlugin = {
   embedder(name: string, config?: Record<string, any>): EmbedderReference;
 };
 
-const ANY_JSON_SCHEMA: Record<string, any> = {
-  $schema: 'http://json-schema.org/draft-07/schema#',
-};
+function initializer(serverAddress: string, params: OllamaPluginParams = {}) {
+  const actions: ResolvableAction[] = [];
 
-const GENERIC_MODEL_INFO = {
-  supports: {
-    multiturn: true,
-    media: true,
-    tools: true,
-    toolChoice: true,
-    systemRole: true,
-    constrained: 'all',
-  },
-} as ModelInfo;
-
-const DEFAULT_OLLAMA_SERVER_ADDRESS = 'http://localhost:11434';
-
-async function initializer(
-  ai: Genkit,
-  serverAddress: string,
-  params?: OllamaPluginParams
-) {
-  params?.models?.map((model) =>
-    defineOllamaModel(ai, model, serverAddress, params?.requestHeaders)
-  );
-  params?.embedders?.map((model) =>
-    defineOllamaEmbedder(ai, {
-      name: model.name,
-      modelName: model.name,
-      dimensions: model.dimensions,
-      options: params!,
-    })
-  );
+  if (params?.models) {
+    for (const model of params.models) {
+      actions.push(
+        defineOllamaModel(model, serverAddress, params.requestHeaders)
+      );
+    }
+  }
+
+  if (params?.embedders && params.serverAddress) {
+    for (const embedder of params.embedders) {
+      actions.push(
+        defineOllamaEmbedder({
+          name: embedder.name,
+          modelName: embedder.name,
+          dimensions: embedder.dimensions,
+          options: params,
+        })
+      );
+    }
+  }
+
+  return actions;
 }
 
-function resolveAction(
-  ai: Genkit,
-  actionType: ActionType,
-  actionName: string,
-  serverAddress: string,
-  requestHeaders?: RequestHeaders
-) {
-  // We can only dynamically resolve models, for embedders user must provide dimensions.
-  if (actionType === 'model') {
-    defineOllamaModel(
-      ai,
-      {
-        name: actionName,
-      },
-      serverAddress,
-      requestHeaders
-    );
+function resolveAction({
+  params,
+  actionType,
+  actionName,
+  serverAddress,
+}: ResolveActionOptions) {
+  switch (actionType) {
+    case 'model':
+      return defineOllamaModel(
+        {
+          name: actionName,
+        },
+        serverAddress,
+        params?.requestHeaders
+      );
   }
+  return undefined;
 }
 
 async function listActions(
@@ -138,30 +137,21 @@
   );
 }
 
-function ollamaPlugin(params?: OllamaPluginParams): GenkitPlugin {
-  if (!params) {
-    params = {};
-  }
-  if (!params.serverAddress) {
-    params.serverAddress = DEFAULT_OLLAMA_SERVER_ADDRESS;
-  }
-  const serverAddress = params.serverAddress;
-  return genkitPlugin(
-    'ollama',
-    async (ai: Genkit) => {
-      await initializer(ai, serverAddress, params);
+function ollamaPlugin(params: OllamaPluginParams = {}): GenkitPluginV2 {
+  const serverAddress = params.serverAddress || DEFAULT_OLLAMA_SERVER_ADDRESS;
+
+  return genkitPluginV2({
+    name: 'ollama',
+    init() {
+      return initializer(serverAddress, params);
     },
-    async (ai, actionType, actionName) => {
-      resolveAction(
-        ai,
-        actionType,
-        actionName,
-        serverAddress,
-        params?.requestHeaders
-      );
+    resolve(actionType, actionName) {
+      return resolveAction({ params, actionType, actionName, serverAddress });
     },
-    async () => await listActions(serverAddress, params?.requestHeaders)
-  );
+    async list() {
+      return await listActions(serverAddress, params?.requestHeaders);
+    },
+  });
 }
 
 async function listLocalModels(
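
Note: genkitPluginV2 replaces the positional callbacks of the old genkitPlugin with named hooks: init() returns the actions built for explicitly configured models and embedders, resolve() constructs a model on first use, and list() reports what the server exposes. Only models are resolved lazily; embedders still need explicit dimensions up front, per the deleted comment above. A sketch of the resulting behavior, under the same assumptions as the previous note:

    import { genkit } from 'genkit';
    import { ollama } from 'genkitx-ollama';

    const ai = genkit({
      plugins: [ollama({ serverAddress: 'http://localhost:11434' })],
    });

    // 'llama3.2' was never passed under `models`, so init() registered
    // nothing for it; the resolve() hook defines the model on demand.
    const { text } = await ai.generate({
      model: 'ollama/llama3.2',
      prompt: 'Why is the sky blue?',
    });
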
@@ -218,26 +208,25 @@ export const OllamaConfigSchema = GenerationCommonConfigSchema.extend({
 });
 
 function defineOllamaModel(
-  ai: Genkit,
-  model: ModelDefinition,
+  modelDef: ModelDefinition,
   serverAddress: string,
   requestHeaders?: RequestHeaders
 ) {
-  return ai.defineModel(
+  return model(
     {
-      name: `ollama/${model.name}`,
-      label: `Ollama - ${model.name}`,
+      name: `ollama/${modelDef.name}`,
+      label: `Ollama - ${modelDef.name}`,
       configSchema: OllamaConfigSchema,
       supports: {
-        multiturn: !model.type || model.type === 'chat',
+        multiturn: !modelDef.type || modelDef.type === 'chat',
         systemRole: true,
-        tools: model.supports?.tools,
+        tools: modelDef.supports?.tools,
       },
     },
-    async (input, streamingCallback) => {
+    async (request, opts) => {
       const { topP, topK, stopSequences, maxOutputTokens, ...rest } =
-        input.config as any;
-      const options: Record<string, any> = { ...rest };
+        request.config || {};
+      const options = { ...rest };
       if (topP !== undefined) {
         options.top_p = topP;
       }
@@ -250,21 +239,21 @@ function defineOllamaModel(
       if (maxOutputTokens !== undefined) {
         options.num_predict = maxOutputTokens;
       }
-      const type = model.type ?? 'chat';
-      const request = toOllamaRequest(
-        model.name,
-        input,
+      const type = modelDef.type ?? 'chat';
+      const ollamaRequest = toOllamaRequest(
+        modelDef.name,
+        request,
         options,
         type,
-        !!streamingCallback
+        opts?.streamingRequested
       );
-      logger.debug(request, `ollama request (${type})`);
+      logger.debug(ollamaRequest, `ollama request (${type})`);
       const extraHeaders = await getHeaders(
         serverAddress,
         requestHeaders,
-        model,
-        input
+        modelDef,
+        request
       );
       let res;
       try {
@@ -272,7 +261,7 @@
         res = await fetch(
           serverAddress + (type === 'chat' ? '/api/chat' : '/api/generate'),
           {
             method: 'POST',
-            body: JSON.stringify(request),
+            body: JSON.stringify(ollamaRequest),
             headers: {
               'Content-Type': 'application/json',
               ...extraHeaders,
             },
@@ -297,7 +286,7 @@
 
       let message: MessageData;
 
-      if (streamingCallback) {
+      if (opts.streamingRequested) {
         const reader = res.body.getReader();
         const textDecoder = new TextDecoder();
         let textResponse = '';
@@ -305,7 +294,7 @@
           const chunkText = textDecoder.decode(chunk);
           const json = JSON.parse(chunkText);
           const message = parseMessage(json, type);
-          streamingCallback({
+          opts.sendChunk({
             index: 0,
             content: message.content,
           });
@@ -329,7 +318,7 @@
 
       return {
         message,
-        usage: getBasicUsageStats(input.messages, message),
+        usage: getBasicUsageStats(request.messages, message),
         finishReason: 'stop',
       } as GenerateResponseData;
     }
@@ -500,7 +489,7 @@ function toGenkitToolRequest(tool_calls: OllamaToolCall[]): ToolRequestPart[] {
   }));
 }
 
-function readChunks(reader) {
+function readChunks(reader: ReadableStreamDefaultReader) {
   return {
     async *[Symbol.asyncIterator]() {
       let readResult = await reader.read();
@@ -536,6 +525,7 @@ function isValidOllamaTool(tool: ToolDefinition): boolean {
 }
 
 export const ollama = ollamaPlugin as OllamaPlugin;
+
 ollama.model = (
   name: string,
   config?: z.infer<typeof OllamaConfigSchema>
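
Note: the model handler contract also changes with v2: instead of (input, streamingCallback), the callback receives (request, opts) and streams through opts.sendChunk when opts.streamingRequested is set, as the rewritten defineOllamaModel above shows. The same contract in miniature, as a hypothetical echo model; 'demo/echo' and its minimal metadata are invented for illustration:

    import { model } from 'genkit/plugin';

    const echoModel = model(
      { name: 'demo/echo' }, // hypothetical model name
      async (request, opts) => {
        const text = request.messages
          .map((m) => m.content.map((p) => p.text ?? '').join(''))
          .join('\n');
        if (opts?.streamingRequested) {
          // Mirror the plugin: emit chunks first, then return the final message.
          opts.sendChunk({ index: 0, content: [{ text }] });
        }
        return {
          message: { role: 'model', content: [{ text }] },
          finishReason: 'stop',
        };
      }
    );
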
diff --git a/js/plugins/ollama/src/types.ts b/js/plugins/ollama/src/types.ts
index b166b3ca3e..bcb688f103 100644
--- a/js/plugins/ollama/src/types.ts
+++ b/js/plugins/ollama/src/types.ts
@@ -34,14 +34,6 @@ export interface EmbeddingModelDefinition {
   dimensions: number;
 }
 
-export const OllamaEmbeddingPredictionSchema = z.object({
-  embedding: z.array(z.number()),
-});
-
-export type OllamaEmbeddingPrediction = z.infer<
-  typeof OllamaEmbeddingPredictionSchema
->;
-
 export interface DefineOllamaEmbeddingParams {
   name: string;
   modelName: string;
@@ -49,6 +41,12 @@ export interface DefineOllamaEmbeddingParams {
   options: OllamaPluginParams;
 }
 
+export const OllamaEmbedderConfigSchema = z.object({
+  serverAddress: z.string().optional(),
+});
+
+export type OllamaEmbedderConfig = z.infer<typeof OllamaEmbedderConfigSchema>;
+
 /**
  * Parameters for the Ollama plugin configuration.
  */
@@ -178,3 +176,10 @@ export interface LocalModel {
 }
 
 export interface ListLocalModelsResponse {
   models: LocalModel[];
 }
+
+export interface ResolveActionOptions {
+  params: OllamaPluginParams;
+  actionType: string;
+  actionName: string;
+  serverAddress: string;
+}
diff --git a/js/plugins/ollama/tests/embedding_live_test.ts b/js/plugins/ollama/tests/embedding_live_test.ts
index 0f19150377..c155c73020 100644
--- a/js/plugins/ollama/tests/embedding_live_test.ts
+++ b/js/plugins/ollama/tests/embedding_live_test.ts
@@ -14,11 +14,12 @@
  * limitations under the License.
  */
 import * as assert from 'assert';
-import { genkit } from 'genkit';
-import { describe, it } from 'node:test';
+import { Genkit, genkit } from 'genkit';
+import { beforeEach, describe, it } from 'node:test';
 import { defineOllamaEmbedder } from '../src/embeddings.js'; // Adjust the import path as necessary
 import { ollama } from '../src/index.js';
 import type { OllamaPluginParams } from '../src/types.js'; // Adjust the import path as necessary
+
 // Utility function to parse command-line arguments
 function parseArgs() {
   const args = process.argv.slice(2);
@@ -31,27 +32,287 @@ function parseArgs() {
   return { serverAddress, modelName };
 }
 
 const { serverAddress, modelName } = parseArgs();
 
-describe('defineOllamaEmbedder - Live Tests', () => {
+describe('defineOllamaEmbedder - Live Tests (without genkit)', () => {
   const options: OllamaPluginParams = {
     models: [{ name: modelName }],
     serverAddress,
   };
-  it('should successfully return embeddings', async () => {
-    const ai = genkit({
+
+  it('should successfully return embeddings for single document', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder',
+      modelName: modelName,
+      dimensions: 768,
+      options,
+    });
+
+    const result = await embedder({
+      input: [{ content: [{ text: 'Hello, world!' }] }],
+    });
+
+    assert.strictEqual(result.embeddings.length, 1);
+    assert.strictEqual(result.embeddings[0].embedding.length, 768);
+    assert.ok(Array.isArray(result.embeddings[0].embedding));
+    assert.ok(
+      result.embeddings[0].embedding.every((val) => typeof val === 'number')
+    );
+  });
+
+  it('should successfully return embeddings for multiple documents', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder-multi',
+      modelName: modelName,
+      dimensions: 768,
+      options,
+    });
+
+    const result = await embedder({
+      input: [
+        { content: [{ text: 'First document about machine learning' }] },
+        {
+          content: [{ text: 'Second document about artificial intelligence' }],
+        },
+        { content: [{ text: 'Third document about neural networks' }] },
+      ],
+    });
+
+    assert.strictEqual(result.embeddings.length, 3);
+    result.embeddings.forEach((embedding, index) => {
+      assert.strictEqual(
+        embedding.embedding.length,
+        768,
+        `Embedding ${index} should have 768 dimensions`
+      );
+      assert.ok(
+        Array.isArray(embedding.embedding),
+        `Embedding ${index} should be an array`
+      );
+      assert.ok(
+        embedding.embedding.every((val) => typeof val === 'number'),
+        `Embedding ${index} should contain only numbers`
+      );
+    });
+  });
+
+  it('should return different embeddings for different texts', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder-different',
+      modelName: modelName,
+      dimensions: 768,
+      options,
+    });
+
+    const result1 = await embedder({
+      input: [
+        { content: [{ text: 'The quick brown fox jumps over the lazy dog' }] },
+      ],
+    });
+
+    const result2 = await embedder({
+      input: [
+        {
+          content: [
+            { text: 'Machine learning is a subset of artificial intelligence' },
+          ],
+        },
+      ],
+    });
+
+    assert.strictEqual(result1.embeddings.length, 1);
+    assert.strictEqual(result2.embeddings.length, 1);
+
+    const embedding1 = result1.embeddings[0].embedding;
+    const embedding2 = result2.embeddings[0].embedding;
+
+    assert.notDeepStrictEqual(
+      embedding1,
+      embedding2,
+      'Different texts should produce different embeddings'
+    );
+
+    assert.strictEqual(embedding1.length, 768);
+    assert.strictEqual(embedding2.length, 768);
+  });
+
+  it('should handle empty text gracefully', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder-empty',
+      modelName: modelName,
+      dimensions: 768,
+      options,
+    });
+
+    const result = await embedder({
+      input: [{ content: [{ text: '' }] }],
+    });
+
+    assert.strictEqual(result.embeddings.length, 1);
+    assert.strictEqual(result.embeddings[0].embedding.length, 768);
+    assert.ok(Array.isArray(result.embeddings[0].embedding));
+  });
+
+  it('should handle long text', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder-long',
+      modelName: modelName,
+      dimensions: 768,
+      options,
+    });
+
+    const longText =
+      'Lorem ipsum dolor sit amet, consectetur adipiscing elit. '.repeat(100);
+
+    const result = await embedder({
+      input: [{ content: [{ text: longText }] }],
+    });
+
+    assert.strictEqual(result.embeddings.length, 1);
+    assert.strictEqual(result.embeddings[0].embedding.length, 768);
+    assert.ok(Array.isArray(result.embeddings[0].embedding));
+    assert.ok(
+      result.embeddings[0].embedding.every((val) => typeof val === 'number')
+    );
+  });
+});
+
+describe('defineOllamaEmbedder - Live Tests (with genkit)', () => {
+  let ai: Genkit;
+  const options: OllamaPluginParams = {
+    models: [{ name: modelName }],
+    serverAddress,
+  };
+
+  beforeEach(() => {
+    ai = genkit({
       plugins: [ollama(options)],
     });
-    const embedder = defineOllamaEmbedder(ai, {
-      name: 'live-test-embedder',
-      modelName: 'nomic-embed-text',
+  });
+
+  it('should successfully return embeddings through genkit', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder-genkit',
+      modelName: modelName,
       dimensions: 768,
       options,
     });
-    const result = (
-      await ai.embed({
-        embedder,
-        content: 'Hello, world!',
-      })
-    )[0].embedding;
-    assert.strictEqual(result.length, 768);
+
+    const result = await ai.embed({
+      embedder,
+      content: 'Hello, world!',
+    });
+
+    assert.strictEqual(result.length, 1);
+    assert.strictEqual(result[0].embedding.length, 768);
+    assert.ok(Array.isArray(result[0].embedding));
+    assert.ok(result[0].embedding.every((val) => typeof val === 'number'));
+  });
+
+  it('should handle multiple documents through genkit', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder-genkit-multi',
+      modelName: modelName,
+      dimensions: 768,
+      options,
+    });
+
+    const result = await ai.embedMany({
+      embedder,
+      content: [
+        'First document about machine learning',
+        'Second document about artificial intelligence',
+        'Third document about neural networks',
+      ],
+    });
+
+    assert.strictEqual(result.length, 3);
+    result.forEach((embedding, index) => {
+      assert.strictEqual(
+        embedding.embedding.length,
+        768,
+        `Embedding ${index} should have 768 dimensions`
+      );
+      assert.ok(
+        Array.isArray(embedding.embedding),
+        `Embedding ${index} should be an array`
+      );
+      assert.ok(
+        embedding.embedding.every((val) => typeof val === 'number'),
+        `Embedding ${index} should contain only numbers`
+      );
+    });
+  });
+
+  it('should return different embeddings for different texts through genkit', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder-genkit-different',
+      modelName: modelName,
+      dimensions: 768,
+      options,
+    });
+
+    const result1 = await ai.embed({
+      embedder,
+      content: 'The quick brown fox jumps over the lazy dog',
+    });
+
+    const result2 = await ai.embed({
+      embedder,
+      content: 'Machine learning is a subset of artificial intelligence',
+    });
+
+    assert.strictEqual(result1.length, 1);
+    assert.strictEqual(result2.length, 1);
+
+    const embedding1 = result1[0].embedding;
+    const embedding2 = result2[0].embedding;
+
+    assert.notDeepStrictEqual(
+      embedding1,
+      embedding2,
+      'Different texts should produce different embeddings'
+    );
+
+    assert.strictEqual(embedding1.length, 768);
+    assert.strictEqual(embedding2.length, 768);
+  });
+
+  it('should handle empty text gracefully through genkit', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder-genkit-empty',
+      modelName: modelName,
+      dimensions: 768,
+      options,
+    });
+
+    const result = await ai.embed({
+      embedder,
+      content: '',
+    });
+
+    assert.strictEqual(result.length, 1);
+    assert.strictEqual(result[0].embedding.length, 768);
+    assert.ok(Array.isArray(result[0].embedding));
+  });
+
+  it('should handle long text through genkit', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'live-test-embedder-genkit-long',
+      modelName: modelName,
+      dimensions: 768,
+      options,
+    });
+
+    const longText =
+      'Lorem ipsum dolor sit amet, consectetur adipiscing elit. '.repeat(100);
+
+    const result = await ai.embed({
+      embedder,
+      content: longText,
+    });
+
+    assert.strictEqual(result.length, 1);
+    assert.strictEqual(result[0].embedding.length, 768);
+    assert.ok(Array.isArray(result[0].embedding));
+    assert.ok(result[0].embedding.every((val) => typeof val === 'number'));
   });
 });
diff --git a/js/plugins/ollama/tests/embeddings_test.ts b/js/plugins/ollama/tests/embeddings_test.ts
index a2a5e7adca..375569dfdd 100644
--- a/js/plugins/ollama/tests/embeddings_test.ts
+++ b/js/plugins/ollama/tests/embeddings_test.ts
@@ -13,11 +13,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-import * as assert from 'assert';
-import { genkit, type Genkit } from 'genkit';
+import { Genkit, genkit } from 'genkit';
+import assert from 'node:assert';
 import { beforeEach, describe, it } from 'node:test';
+import { ollama } from '../src';
 import { defineOllamaEmbedder } from '../src/embeddings.js';
-import { ollama } from '../src/index.js';
 import type { OllamaPluginParams } from '../src/types.js';
 
 // Mock fetch to simulate API responses
@@ -31,60 +31,135 @@ global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => {
       json: async () => ({}),
     } as Response;
   }
+
+    const body = options?.body ? JSON.parse(options.body as string) : {};
+    const inputCount = body.input ? body.input.length : 1;
+
     return {
       ok: true,
       json: async () => ({
-        embeddings: [[0.1, 0.2, 0.3]], // Example embedding values
+        embeddings: Array(inputCount).fill([0.1, 0.2, 0.3]), // Return embedding for each input
       }),
     } as Response;
   }
   throw new Error('Unknown API endpoint');
 };
 
-describe('defineOllamaEmbedder', () => {
-  const options: OllamaPluginParams = {
-    models: [{ name: 'test-model' }],
-    serverAddress: 'http://localhost:3000',
-  };
+const options: OllamaPluginParams = {
+  models: [{ name: 'test-model' }],
+  serverAddress: 'http://localhost:3000',
+};
+
+describe('defineOllamaEmbedder (without genkit initialization)', () => {
+  it('should successfully return embeddings when called directly', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'test-embedder',
+      modelName: 'test-model',
+      dimensions: 123,
+      options,
+    });
+
+    const result = await embedder({
+      input: [{ content: [{ text: 'Hello, world!' }] }],
+    });
+
+    assert.deepStrictEqual(result, {
+      embeddings: [{ embedding: [0.1, 0.2, 0.3] }],
+    });
+  });
+
+  it('should handle API errors correctly when called directly', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'test-embedder',
+      modelName: 'test-model',
+      dimensions: 123,
+      options,
+    });
+
+    await assert.rejects(
+      async () => {
+        await embedder({
+          input: [{ content: [{ text: 'fail' }] }],
+        });
+      },
+      (error) => {
+        assert.ok(error instanceof Error);
+        assert.strictEqual(
+          error.message,
+          'Error fetching embedding from Ollama: Internal Server Error. '
+        );
+        return true;
+      }
+    );
+  });
+
+  it('should handle multiple documents', async () => {
+    const embedder = defineOllamaEmbedder({
+      name: 'test-embedder',
+      modelName: 'test-model',
+      dimensions: 123,
+      options,
+    });
+
+    const result = await embedder({
+      input: [
+        { content: [{ text: 'First document' }] },
+        { content: [{ text: 'Second document' }] },
+      ],
+    });
+
+    assert.deepStrictEqual(result, {
+      embeddings: [
+        { embedding: [0.1, 0.2, 0.3] },
+        { embedding: [0.1, 0.2, 0.3] },
+      ],
+    });
+  });
+});
+
+describe('defineOllamaEmbedder (with genkit initialization)', () => {
   let ai: Genkit;
+
   beforeEach(() => {
     ai = genkit({
-      plugins: [
-        ollama({
-          serverAddress: 'http://localhost:3000',
-        }),
-      ],
+      plugins: [ollama(options)],
     });
   });
 
   it('should successfully return embeddings', async () => {
-    const embedder = defineOllamaEmbedder(ai, {
+    const embedder = defineOllamaEmbedder({
       name: 'test-embedder',
       modelName: 'test-model',
       dimensions: 123,
       options,
     });
+
     const result = await ai.embed({
       embedder,
       content: 'Hello, world!',
     });
+
     assert.deepStrictEqual(result, [{ embedding: [0.1, 0.2, 0.3] }]);
   });
 
   it('should handle API errors correctly', async () => {
-    const embedder = defineOllamaEmbedder(ai, {
+    const embedder = defineOllamaEmbedder({
       name: 'test-embedder',
       modelName: 'test-model',
       dimensions: 123,
       options,
     });
+
     await assert.rejects(
       async () => {
         await ai.embed({
           embedder,
           content: 'fail',
         });
+
+        await embedder({
+          input: [{ content: [{ text: 'fail' }] }],
+        });
       },
       (error) => {
         assert.ok(error instanceof Error);
@@ -96,4 +171,45 @@ describe('defineOllamaEmbedder', () => {
       }
     );
   });
+
+  it('should support per-call embedder serverAddress configuration', async () => {
+    const aiWithEmbedder = genkit({
+      plugins: [
+        ollama({
+          serverAddress: 'http://localhost:3000',
+          embedders: [{ name: 'test-embedder', dimensions: 768 }],
+        }),
+      ],
+    });
+
+    // Mock fetch to verify custom serverAddress is used
+    global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => {
+      const url = typeof input === 'string' ? input : input.toString();
+
+      if (url.includes('/api/embed')) {
+        // Verify the custom serverAddress was used
+        assert.ok(url.includes('http://custom-server:11434'));
+        return new Response(
+          JSON.stringify({
+            embeddings: [[0.1, 0.2, 0.3]],
+          }),
+          {
+            headers: { 'Content-Type': 'application/json' },
+          }
+        );
+      }
+
+      throw new Error(`Unknown API endpoint: ${url}`);
+    };
+
+    const result = await aiWithEmbedder.embed({
+      embedder: 'ollama/test-embedder',
+      content: 'test document',
+      options: { serverAddress: 'http://custom-server:11434' },
+    });
+
+    assert.ok(result);
+    assert.strictEqual(result.length, 1);
+    assert.strictEqual(result[0].embedding.length, 3);
+  });
 });
diff --git a/js/plugins/ollama/tests/list_test.ts b/js/plugins/ollama/tests/list_test.ts
new file mode 100644
index 0000000000..6ae1f371a2
--- /dev/null
+++ b/js/plugins/ollama/tests/list_test.ts
@@ -0,0 +1,142 @@
+/**
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import * as assert from 'assert';
+import { Genkit, genkit } from 'genkit';
+import { beforeEach, describe, it } from 'node:test';
+import { ollama } from '../src/index.js';
+import type {
+  ListLocalModelsResponse,
+  OllamaPluginParams,
+} from '../src/types.js';
+
+const MOCK_MODELS_RESPONSE: ListLocalModelsResponse = {
+  models: [
+    {
+      name: 'llama3.2:latest',
+      model: 'llama3.2:latest',
+      modified_at: '2024-07-22T20:33:28.123648Z',
+      size: 1234567890,
+      digest: 'sha256:abcdef123456',
+      details: {
+        parent_model: '',
+        format: 'gguf',
+        family: 'llama',
+        families: ['llama'],
+        parameter_size: '8B',
+        quantization_level: 'Q4_0',
+      },
+    },
+    {
+      name: 'gemma2:latest',
+      model: 'gemma2:latest',
+      modified_at: '2024-07-22T20:33:28.123648Z',
+      size: 987654321,
+      digest: 'sha256:fedcba654321',
+      details: {
+        parent_model: '',
+        format: 'gguf',
+        family: 'gemma',
+        families: ['gemma'],
+        parameter_size: '2B',
+        quantization_level: 'Q4_0',
+      },
+    },
+    {
+      name: 'nomic-embed-text:latest',
+      model: 'nomic-embed-text:latest',
+      modified_at: '2024-07-22T20:33:28.123648Z',
+      size: 456789123,
+      digest: 'sha256:123456789abc',
+      details: {
+        parent_model: '',
+        format: 'gguf',
+        family: 'nomic',
+        families: ['nomic'],
+        parameter_size: '137M',
+        quantization_level: 'Q4_0',
+      },
+    },
+  ],
+};
+
+// Mock fetch to simulate the Ollama API response
+global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => {
+  const url = typeof input === 'string' ? input : input.toString();
+
+  if (url.includes('/api/tags')) {
+    return new Response(JSON.stringify(MOCK_MODELS_RESPONSE), {
+      headers: { 'Content-Type': 'application/json' },
+    });
+  }
+
+  throw new Error(`Unknown API endpoint: ${url}`);
+};
+
+describe('ollama list', () => {
+  const options: OllamaPluginParams = {
+    serverAddress: 'http://localhost:3000',
+  };
+
+  let ai: Genkit;
+  beforeEach(() => {
+    ai = genkit({
+      plugins: [ollama(options)],
+    });
+  });
+
+  it('should return models with ollama/ prefix to maintain v1 compatibility', async () => {
+    const result = await ollama().list!();
+
+    // Should return 2 models (embedding models are filtered out)
+    assert.strictEqual(result.length, 2);
+
+    // Check that model names have the ollama/ prefix (maintaining v1 compatibility)
+    const modelNames = result.map((m) => m.name);
+    assert.ok(modelNames.includes('ollama/llama3.2:latest'));
+    assert.ok(modelNames.includes('ollama/gemma2:latest'));
+    assert.ok(!modelNames.includes('ollama/nomic-embed-text:latest')); // embedding model filtered out
+
+    // Check that each model has the correct structure
+    for (const model of result) {
+      assert.ok(model.name);
+      assert.ok(model.metadata);
+      assert.ok(model.metadata.model);
+      const modelInfo = model.metadata.model;
+      assert.strictEqual(modelInfo.supports?.multiturn, true);
+      assert.strictEqual(modelInfo.supports?.media, true);
+      assert.strictEqual(modelInfo.supports?.tools, true);
+      assert.strictEqual(modelInfo.supports?.toolChoice, true);
+      assert.strictEqual(modelInfo.supports?.systemRole, true);
+      assert.strictEqual(modelInfo.supports?.constrained, 'all');
+    }
+  });
+
+  it('should list models through Genkit instance', async () => {
+    const result = await ai.registry.listResolvableActions();
+
+    // Should return 2 models (embedding models are filtered out)
+    const modelActions = Object.values(result).filter(
+      (action) => action.actionType === 'model'
+    );
+    assert.strictEqual(modelActions.length, 2);
+
+    // Check that model names have the ollama/ prefix
+    const modelNames = modelActions.map((m) => m.name);
+    assert.ok(modelNames.includes('ollama/llama3.2:latest'));
+    assert.ok(modelNames.includes('ollama/gemma2:latest'));
+    assert.ok(!modelNames.includes('ollama/nomic-embed-text:latest')); // embedding model filtered out
+  });
+});
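
Note: list() is what backs dynamic model discovery (for example in the Genkit Dev UI), and the list_test.ts suite above pins down its two contracts: returned names keep the ollama/ prefix for v1 compatibility, and embedding-only models are filtered out because they cannot be resolved without explicit dimensions. A quick way to inspect the result against a live server, under the same assumptions as the earlier notes:

    import { ollama } from 'genkitx-ollama';

    const plugin = ollama({ serverAddress: 'http://localhost:11434' });
    const actions = await plugin.list!();
    // e.g. ['ollama/llama3.2:latest', 'ollama/gemma2:latest']
    console.log(actions.map((a) => a.name));
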
diff --git a/js/plugins/ollama/tests/model_test.ts b/js/plugins/ollama/tests/model_test.ts
index 3c81eb24b3..d2b99ac2ce 100644
--- a/js/plugins/ollama/tests/model_test.ts
+++ b/js/plugins/ollama/tests/model_test.ts
@@ -19,9 +19,11 @@ import { beforeEach, describe, it } from 'node:test';
 import { ollama } from '../src/index.js';
 import type { OllamaPluginParams } from '../src/types.js';
 
+const BASE_TIME = new Date('2024-07-22T20:33:28.123648Z').getTime();
+
 const MOCK_TOOL_CALL_RESPONSE = {
   model: 'llama3.2',
-  created_at: '2024-07-22T20:33:28.123648Z',
+  created_at: new Date(BASE_TIME).toISOString(),
   message: {
     role: 'assistant',
     content: '',
@@ -43,7 +45,7 @@
 
 const MOCK_END_RESPONSE = {
   model: 'llama3.2',
-  created_at: '2024-07-22T20:33:28.123648Z',
+  created_at: new Date(BASE_TIME).toISOString(),
   message: {
     role: 'assistant',
     content: 'The weather is sunny',
@@ -52,16 +54,116 @@ const MOCK_END_RESPONSE = {
   done: true,
 };
 
-const MAGIC_WORD = 'sunnnnnnny';
+const MOCK_NO_TOOLS_END_RESPONSE = {
+  model: 'llama3.2',
+  created_at: new Date(BASE_TIME).toISOString(),
+  message: {
+    role: 'assistant',
+    content: 'I have no way of knowing that',
+  },
+  done_reason: 'stop',
+  done: true,
+};
+
+// MockModel class to simulate the tool calling flow more clearly
+class MockModel {
+  private callCount = 0;
+  private hasTools = false;
+
+  // Handles non-streaming chat requests
+  async chat(request: any): Promise<any> {
+    this.callCount++;
+
+    // First call: initial request with tools → return tool call
+    if (this.callCount === 1 && request.tools && request.tools.length > 0) {
+      this.hasTools = true;
+      return MOCK_TOOL_CALL_RESPONSE;
+    }
+
+    // Second call: follow-up with tool results → return final answer
+    if (
+      this.callCount === 2 &&
+      this.hasTools &&
+      request.messages?.some((m: any) => m.role === 'tool')
+    ) {
+      return MOCK_END_RESPONSE;
+    }
+
+    // Basic request without tools → return end response
+    return MOCK_NO_TOOLS_END_RESPONSE;
+  }
+
+  // Create a streaming response for testing using a ReadableStream
+  createStreamingResponse(): ReadableStream {
+    const words = ['this', 'is', 'a', 'streaming', 'response'];
+
+    return new ReadableStream({
+      start(controller) {
+        let wordIndex = 0;
+
+        const sendNextChunk = () => {
+          if (wordIndex >= words.length) {
+            controller.close();
+            return;
+          }
+
+          // Stream individual words (not cumulative)
+          const currentWord = words[wordIndex];
+          const isLastChunk = wordIndex === words.length - 1;
+
+          // Increment timestamp for each chunk
+          const chunkTime = new Date(BASE_TIME + wordIndex * 100).toISOString();
+
+          const response = {
+            model: 'llama3.2',
+            created_at: chunkTime,
+            message: {
+              role: 'assistant',
+              content: currentWord + (isLastChunk ? '' : ' '), // Add space except for last word
+            },
+            done_reason: isLastChunk ? 'stop' : undefined,
+            done: isLastChunk,
+          };
+
+          controller.enqueue(
+            new TextEncoder().encode(JSON.stringify(response) + '\n')
+          );
+
+          wordIndex++;
+          setTimeout(sendNextChunk, 10); // Small delay to simulate streaming
+        };
+
+        sendNextChunk();
+      },
+    });
+  }
+
+  reset(): void {
+    this.callCount = 0;
+    this.hasTools = false;
+  }
+}
+
+// Create a mock model instance to simulate the tool calling flow
+const mockModel = new MockModel();
 
-// Mock fetch to simulate API responses
+// Mock fetch to simulate the multi-turn tool calling flow using MockModel
 global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => {
   const url = typeof input === 'string' ? input : input.toString();
 
   if (url.includes('/api/chat')) {
-    if (options?.body && JSON.stringify(options.body).includes(MAGIC_WORD)) {
-      return new Response(JSON.stringify(MOCK_END_RESPONSE));
+    const body = JSON.parse((options?.body as string) || '{}');
+
+    // Check if this is a streaming request
+    if (body.stream) {
+      const stream = mockModel.createStreamingResponse();
+      return new Response(stream, {
+        headers: { 'Content-Type': 'application/json' },
+      });
     }
-    return new Response(JSON.stringify(MOCK_TOOL_CALL_RESPONSE));
+
+    // Non-streaming request
+    const response = await mockModel.chat(body);
+    return new Response(JSON.stringify(response));
   }
   throw new Error('Unknown API endpoint');
 };
@@ -74,11 +176,20 @@ describe('ollama models', () => {
   let ai: Genkit;
 
   beforeEach(() => {
+    mockModel.reset(); // Reset mock state between tests
     ai = genkit({
       plugins: [ollama(options)],
     });
   });
 
+  it('should successfully return basic response', async () => {
+    const result = await ai.generate({
+      model: 'ollama/test-model',
+      prompt: 'Hello',
+    });
+    assert.ok(result.text === 'I have no way of knowing that');
+  });
+
   it('should successfully return tool call response', async () => {
     const get_current_weather = ai.defineTool(
       {
@@ -87,7 +198,7 @@
         inputSchema: z.object({ format: z.string(), location: z.string() }),
       },
       async () => {
-        return MAGIC_WORD;
+        return 'sunny';
       }
     );
 
@@ -99,34 +210,50 @@
     assert.ok(result.text === 'The weather is sunny');
   });
 
-  it('should throw for primitive tools', async () => {
-    const get_current_weather = ai.defineTool(
+  it('should throw for tools with primitive (non-object) input schema', async () => {
+    // This tool has a primitive (non-object) input schema, which Ollama rejects.
+    const toolWithNonObjectInput = ai.defineTool(
      {
-        name: 'get_current_weather',
-        description: 'gets weather',
-        inputSchema: z.object({ format: z.string(), location: z.string() }),
-      },
-      async () => {
-        return MAGIC_WORD;
-      }
-    );
-    const fooz = ai.defineTool(
-      {
-        name: 'fooz',
-        description: 'gets fooz',
+        name: 'toolWithNonObjectInput',
+        description: 'tool with non-object input schema',
         inputSchema: z.string(),
       },
       async () => {
-        return 1;
+        return 'anything';
       }
     );
-    await assert.rejects(async () => {
+    try {
       await ai.generate({
         model: 'ollama/test-model',
         prompt: 'Hello',
-        tools: [get_current_weather, fooz],
+        tools: [toolWithNonObjectInput],
       });
+      assert.fail('expected generate() to throw for a non-object tool input');
+    } catch (error) {
+      assert.ok(error instanceof Error);
+
+      assert.ok(
+        error.message.includes('Ollama only supports tools with object inputs')
+      );
+    }
+  });
+
+  it('should successfully return streaming response', async () => {
+    const streamingResult = ai.generateStream({
+      model: 'ollama/test-model',
+      prompt: 'Hello',
     });
+
+    let fullText = '';
+    let chunkCount = 0;
+    for await (const chunk of streamingResult.stream) {
+      fullText += chunk.text; // Each chunk contains individual words
+      chunkCount++;
+    }
+
+    // Should have received multiple chunks (one per word)
+    assert.ok(chunkCount > 1);
+    // Final text should be complete
+    assert.ok(fullText === 'this is a streaming response');
   });
 });