Skip to content

Commit

Permalink
fix(js/ai): Fixes use of namespaced tools in model calls. (#1423)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbleigh authored Nov 27, 2024
1 parent 1d5259c commit c36cad7
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 18 deletions.
4 changes: 4 additions & 0 deletions genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ export const ToolDefinitionSchema = z.object({
.record(z.any())
.describe('Valid JSON Schema describing the output of the tool.')
.optional(),
metadata: z
.record(z.any())
.describe('additional metadata for this tool definition')
.optional(),
});
export type ToolDefinition = z.infer<typeof ToolDefinitionSchema>;

Expand Down
5 changes: 5 additions & 0 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,11 @@
"type": "object",
"additionalProperties": {},
"description": "Valid JSON Schema describing the output of the tool."
},
"metadata": {
"type": "object",
"additionalProperties": {},
"description": "additional metadata for this tool definition"
}
},
"required": [
Expand Down
2 changes: 1 addition & 1 deletion js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ export async function toGenerateRequest(
messages: injectInstructions(messages, instructions),
config: options.config,
docs: options.docs,
tools: tools?.map((tool) => toToolDefinition(tool)) || [],
tools: tools?.map(toToolDefinition) || [],
output: {
...(resolvedFormat?.config || {}),
schema: resolvedSchema,
Expand Down
20 changes: 16 additions & 4 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import {
GenkitError,
getStreamingCallback,
runWithStreamingCallback,
z,
Expand Down Expand Up @@ -116,6 +117,19 @@ async function generate(
const tools = await resolveTools(registry, rawRequest.tools);

const resolvedFormat = await resolveFormat(registry, rawRequest.output);
// Create a lookup of tool names with namespaces stripped to original names
const toolMap = tools.reduce<Record<string, ToolAction>>((acc, tool) => {
const name = tool.__action.name;
const shortName = name.substring(name.lastIndexOf('/') + 1);
if (acc[shortName]) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: `Cannot provide two tools with the same name: '${name}' and '${acc[shortName]}'`,
});
}
acc[shortName] = tool;
return acc;
}, {});

const request = await actionToGenerateRequest(
rawRequest,
Expand Down Expand Up @@ -184,9 +198,7 @@ async function generate(
'Tool request expected but not provided in tool request part'
);
}
const tool = tools?.find(
(tool) => tool.__action.name === part.toolRequest?.name
);
const tool = toolMap[part.toolRequest?.name];
if (!tool) {
throw Error(`Tool ${part.toolRequest?.name} not found`);
}
Expand Down Expand Up @@ -238,7 +250,7 @@ async function actionToGenerateRequest(
messages: options.messages,
config: options.config,
docs: options.docs,
tools: resolvedTools?.map((tool) => toToolDefinition(tool)) || [],
tools: resolvedTools?.map(toToolDefinition) || [],
output: {
...(resolvedFormat?.config || {}),
schema: toJsonSchema({
Expand Down
4 changes: 4 additions & 0 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ export const ToolDefinitionSchema = z.object({
.record(z.any())
.describe('Valid JSON Schema describing the output of the tool.')
.nullish(),
metadata: z
.record(z.any())
.describe('additional metadata for this tool definition')
.optional(),
});
export type ToolDefinition = z.infer<typeof ToolDefinitionSchema>;

Expand Down
21 changes: 18 additions & 3 deletions js/ai/src/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ export async function resolveTools<
} else if (typeof (ref as ExecutablePrompt).asTool === 'function') {
return await (ref as ExecutablePrompt).asTool();
} else if (ref.name) {
return await lookupToolByName(registry, ref.name);
return await lookupToolByName(
registry,
(ref as ToolDefinition).metadata?.originalName || ref.name
);
}
throw new Error('Tools must be strings, tool definitions, or actions.');
})
Expand All @@ -136,8 +139,14 @@ export async function lookupToolByName(
export function toToolDefinition(
tool: Action<z.ZodTypeAny, z.ZodTypeAny>
): ToolDefinition {
return {
name: tool.__action.name,
const originalName = tool.__action.name;
let name = originalName;
if (originalName.includes('/')) {
name = originalName.substring(originalName.lastIndexOf('/') + 1);
}

const out: ToolDefinition = {
name,
description: tool.__action.description || '',
outputSchema: toJsonSchema({
schema: tool.__action.outputSchema ?? z.void(),
Expand All @@ -148,6 +157,12 @@ export function toToolDefinition(
jsonSchema: tool.__action.inputJsonSchema,
})!,
};

if (originalName !== name) {
out.metadata = { originalName };
}

return out;
}

/**
Expand Down
51 changes: 50 additions & 1 deletion js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { z } from '@genkit-ai/core';
import { PluginProvider, z } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
Expand Down Expand Up @@ -43,6 +43,23 @@ describe('toGenerateRequest', () => {
}
);

const namespacedPlugin: PluginProvider = {
name: 'namespaced',
initializer: async () => {},
};
registry.registerPluginProvider('namespaced', namespacedPlugin);

defineTool(
registry,
{
name: 'namespaced/add',
description: 'add two numbers together',
inputSchema: z.object({ a: z.number(), b: z.number() }),
outputSchema: z.number(),
},
async ({ a, b }) => a + b
);

const testCases = [
{
should: 'translate a string prompt correctly',
Expand Down Expand Up @@ -95,6 +112,38 @@ describe('toGenerateRequest', () => {
output: {},
},
},
{
should: 'strip namespaces from tools when passing to the model',
prompt: {
model: 'vertexai/gemini-1.0-pro',
tools: ['namespaced/add'],
prompt: 'Add 10 and 5.',
},
expectedOutput: {
messages: [{ role: 'user', content: [{ text: 'Add 10 and 5.' }] }],
config: undefined,
docs: undefined,
tools: [
{
description: 'add two numbers together',
inputSchema: {
$schema: 'http://json-schema.org/draft-07/schema#',
additionalProperties: true,
properties: { a: { type: 'number' }, b: { type: 'number' } },
required: ['a', 'b'],
type: 'object',
},
name: 'add',
outputSchema: {
$schema: 'http://json-schema.org/draft-07/schema#',
type: 'number',
},
metadata: { originalName: 'namespaced/add' },
},
],
output: {},
},
},
{
should:
'translate a string prompt correctly with tools referenced by their action',
Expand Down
2 changes: 1 addition & 1 deletion js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ export class Genkit {
if (!response.tools && options.tools) {
response.tools = (
await resolveTools(this.registry, options.tools)
).map(toToolDefinition);
).map((t) => toToolDefinition(t));
}
if (!response.output && options.output) {
response.output = {
Expand Down
50 changes: 42 additions & 8 deletions js/testapps/flow-simple-ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { initializeApp } from 'firebase-admin/app';
import { getFirestore } from 'firebase-admin/firestore';
import { MessageSchema, genkit, run, z } from 'genkit';
import { logger } from 'genkit/logging';
import { PluginProvider } from 'genkit/plugin';
import { Allow, parse } from 'partial-json';

logger.setLogLevel('debug');
Expand All @@ -53,6 +54,32 @@ const ai = genkit({
plugins: [googleAI(), vertexAI()],
});

const math: PluginProvider = {
name: 'math',
initializer: async () => {
ai.defineTool(
{
name: 'math/add',
description: 'add two numbers',
inputSchema: z.object({ a: z.number(), b: z.number() }),
outputSchema: z.number(),
},
async ({ a, b }) => a + b
);

ai.defineTool(
{
name: 'math/subtract',
description: 'subtract two numbers',
inputSchema: z.object({ a: z.number(), b: z.number() }),
outputSchema: z.number(),
},
async ({ a, b }) => a - b
);
},
};
ai.registry.registerPluginProvider('math', math);

const app = initializeApp();

export const jokeFlow = ai.defineFlow(
Expand Down Expand Up @@ -538,11 +565,18 @@ export const arrayStreamTester = ai.defineStreamingFlow(
}
);

// async function main() {
// const { stream, output } = arrayStreamTester();
// for await (const chunk of stream) {
// console.log(chunk);
// }
// console.log(await output);
// }
// main();
ai.defineFlow(
{
name: 'math',
inputSchema: z.string(),
outputSchema: z.string(),
},
async (query) => {
const { text } = await ai.generate({
model: gemini15Flash,
prompt: query,
tools: ['math/add', 'math/subtract'],
});
return text;
}
);

0 comments on commit c36cad7

Please sign in to comment.