From 6de87f58e9494dbf7cb2e3fde3d0df50ae4728bf Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Wed, 27 Nov 2024 11:13:51 -0800 Subject: [PATCH] fix(js/ai): Fixes use of namespaced tools in model calls. (#1423) --- genkit-tools/common/src/types/model.ts | 4 ++ genkit-tools/genkit-schema.json | 5 +++ js/ai/src/generate.ts | 2 +- js/ai/src/generate/action.ts | 20 ++++++++-- js/ai/src/model.ts | 4 ++ js/ai/src/tool.ts | 21 ++++++++-- js/ai/tests/generate/generate_test.ts | 51 ++++++++++++++++++++++++- js/genkit/src/genkit.ts | 2 +- js/testapps/flow-simple-ai/src/index.ts | 50 ++++++++++++++++++++---- 9 files changed, 141 insertions(+), 18 deletions(-) diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index 547652c2f..5ce711c17 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -133,6 +133,10 @@ export const ToolDefinitionSchema = z.object({ .record(z.any()) .describe('Valid JSON Schema describing the output of the tool.') .optional(), + metadata: z + .record(z.any()) + .describe('additional metadata for this tool definition') + .optional(), }); export type ToolDefinition = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 437fe51b0..e1115db3f 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -839,6 +839,11 @@ "type": "object", "additionalProperties": {}, "description": "Valid JSON Schema describing the output of the tool." + }, + "metadata": { + "type": "object", + "additionalProperties": {}, + "description": "additional metadata for this tool definition" } }, "required": [ diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index cb49c6d08..e6d1bb2af 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -129,7 +129,7 @@ export async function toGenerateRequest( messages: injectInstructions(messages, instructions), config: options.config, docs: options.docs, - tools: tools?.map((tool) => toToolDefinition(tool)) || [], + tools: tools?.map(toToolDefinition) || [], output: { ...(resolvedFormat?.config || {}), schema: resolvedSchema, diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 8a4d01753..48d6b6e9b 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -15,6 +15,7 @@ */ import { + GenkitError, getStreamingCallback, runWithStreamingCallback, z, @@ -116,6 +117,19 @@ async function generate( const tools = await resolveTools(registry, rawRequest.tools); const resolvedFormat = await resolveFormat(registry, rawRequest.output); + // Create a lookup of tool names with namespaces stripped to original names + const toolMap = tools.reduce>((acc, tool) => { + const name = tool.__action.name; + const shortName = name.substring(name.lastIndexOf('/') + 1); + if (acc[shortName]) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Cannot provide two tools with the same name: '${name}' and '${acc[shortName]}'`, + }); + } + acc[shortName] = tool; + return acc; + }, {}); const request = await actionToGenerateRequest( rawRequest, @@ -184,9 +198,7 @@ async function generate( 'Tool request expected but not provided in tool request part' ); } - const tool = tools?.find( - (tool) => tool.__action.name === part.toolRequest?.name - ); + const tool = toolMap[part.toolRequest?.name]; if (!tool) { throw Error(`Tool ${part.toolRequest?.name} not found`); } @@ -238,7 +250,7 @@ async function actionToGenerateRequest( messages: options.messages, config: options.config, docs: options.docs, - tools: resolvedTools?.map((tool) => toToolDefinition(tool)) || [], + tools: resolvedTools?.map(toToolDefinition) || [], output: { ...(resolvedFormat?.config || {}), schema: toJsonSchema({ diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 18609fd8d..c0eb0b9c8 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -166,6 +166,10 @@ export const ToolDefinitionSchema = z.object({ .record(z.any()) .describe('Valid JSON Schema describing the output of the tool.') .nullish(), + metadata: z + .record(z.any()) + .describe('additional metadata for this tool definition') + .optional(), }); export type ToolDefinition = z.infer; diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index bfeb37efd..a56f9ad59 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -109,7 +109,10 @@ export async function resolveTools< } else if (typeof (ref as ExecutablePrompt).asTool === 'function') { return await (ref as ExecutablePrompt).asTool(); } else if (ref.name) { - return await lookupToolByName(registry, ref.name); + return await lookupToolByName( + registry, + (ref as ToolDefinition).metadata?.originalName || ref.name + ); } throw new Error('Tools must be strings, tool definitions, or actions.'); }) @@ -136,8 +139,14 @@ export async function lookupToolByName( export function toToolDefinition( tool: Action ): ToolDefinition { - return { - name: tool.__action.name, + const originalName = tool.__action.name; + let name = originalName; + if (originalName.includes('/')) { + name = originalName.substring(originalName.lastIndexOf('/') + 1); + } + + const out: ToolDefinition = { + name, description: tool.__action.description || '', outputSchema: toJsonSchema({ schema: tool.__action.outputSchema ?? z.void(), @@ -148,6 +157,12 @@ export function toToolDefinition( jsonSchema: tool.__action.inputJsonSchema, })!, }; + + if (originalName !== name) { + out.metadata = { originalName }; + } + + return out; } /** diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 9c98ae6cc..72c3a565a 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { z } from '@genkit-ai/core'; +import { PluginProvider, z } from '@genkit-ai/core'; import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; @@ -43,6 +43,23 @@ describe('toGenerateRequest', () => { } ); + const namespacedPlugin: PluginProvider = { + name: 'namespaced', + initializer: async () => {}, + }; + registry.registerPluginProvider('namespaced', namespacedPlugin); + + defineTool( + registry, + { + name: 'namespaced/add', + description: 'add two numbers together', + inputSchema: z.object({ a: z.number(), b: z.number() }), + outputSchema: z.number(), + }, + async ({ a, b }) => a + b + ); + const testCases = [ { should: 'translate a string prompt correctly', @@ -95,6 +112,38 @@ describe('toGenerateRequest', () => { output: {}, }, }, + { + should: 'strip namespaces from tools when passing to the model', + prompt: { + model: 'vertexai/gemini-1.0-pro', + tools: ['namespaced/add'], + prompt: 'Add 10 and 5.', + }, + expectedOutput: { + messages: [{ role: 'user', content: [{ text: 'Add 10 and 5.' }] }], + config: undefined, + docs: undefined, + tools: [ + { + description: 'add two numbers together', + inputSchema: { + $schema: 'http://json-schema.org/draft-07/schema#', + additionalProperties: true, + properties: { a: { type: 'number' }, b: { type: 'number' } }, + required: ['a', 'b'], + type: 'object', + }, + name: 'add', + outputSchema: { + $schema: 'http://json-schema.org/draft-07/schema#', + type: 'number', + }, + metadata: { originalName: 'namespaced/add' }, + }, + ], + output: {}, + }, + }, { should: 'translate a string prompt correctly with tools referenced by their action', diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 593594c31..e091ca222 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -466,7 +466,7 @@ export class Genkit { if (!response.tools && options.tools) { response.tools = ( await resolveTools(this.registry, options.tools) - ).map(toToolDefinition); + ).map((t) => toToolDefinition(t)); } if (!response.output && options.output) { response.output = { diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 57ad3b404..ecc27cfe7 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -28,6 +28,7 @@ import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; import { MessageSchema, genkit, run, z } from 'genkit'; import { logger } from 'genkit/logging'; +import { PluginProvider } from 'genkit/plugin'; import { Allow, parse } from 'partial-json'; logger.setLogLevel('debug'); @@ -53,6 +54,32 @@ const ai = genkit({ plugins: [googleAI(), vertexAI()], }); +const math: PluginProvider = { + name: 'math', + initializer: async () => { + ai.defineTool( + { + name: 'math/add', + description: 'add two numbers', + inputSchema: z.object({ a: z.number(), b: z.number() }), + outputSchema: z.number(), + }, + async ({ a, b }) => a + b + ); + + ai.defineTool( + { + name: 'math/subtract', + description: 'subtract two numbers', + inputSchema: z.object({ a: z.number(), b: z.number() }), + outputSchema: z.number(), + }, + async ({ a, b }) => a - b + ); + }, +}; +ai.registry.registerPluginProvider('math', math); + const app = initializeApp(); export const jokeFlow = ai.defineFlow( @@ -538,11 +565,18 @@ export const arrayStreamTester = ai.defineStreamingFlow( } ); -// async function main() { -// const { stream, output } = arrayStreamTester(); -// for await (const chunk of stream) { -// console.log(chunk); -// } -// console.log(await output); -// } -// main(); +ai.defineFlow( + { + name: 'math', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async (query) => { + const { text } = await ai.generate({ + model: gemini15Flash, + prompt: query, + tools: ['math/add', 'math/subtract'], + }); + return text; + } +);