From 2cbab2501aeb307074695b247b7eff58649af72b Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Mon, 26 Aug 2024 17:03:02 -0700 Subject: [PATCH 1/5] Makes inputSchema optional for tools. --- js/ai/src/model.ts | 3 +- js/testapps/flow-simple-ai/src/index.ts | 41 +++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 70bc0936d..142cd3da7 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -153,7 +153,8 @@ export const ToolDefinitionSchema = z.object({ description: z.string(), inputSchema: z .record(z.any()) - .describe('Valid JSON Schema representing the input of the tool.'), + .describe('Valid JSON Schema representing the input of the tool.') + .optional(), outputSchema: z .record(z.any()) .describe('Valid JSON Schema describing the output of the tool.') diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 934fda755..7ac354945 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -14,8 +14,7 @@ * limitations under the License. */ -import { generate, generateStream, retrieve } from '@genkit-ai/ai'; -import { defineTool } from '@genkit-ai/ai/tool'; +import { defineTool, generate, generateStream, retrieve } from '@genkit-ai/ai'; import { configureGenkit } from '@genkit-ai/core'; import { dotprompt, prompt } from '@genkit-ai/dotprompt'; import { defineFirestoreRetriever, firebase } from '@genkit-ai/firebase'; @@ -23,8 +22,8 @@ import { defineFlow, run } from '@genkit-ai/flow'; import { googleCloud } from '@genkit-ai/google-cloud'; import { gemini15Flash, - googleAI, geminiPro as googleGeminiPro, + googleAI, } from '@genkit-ai/googleai'; import { gemini15ProPreview, @@ -429,6 +428,7 @@ export const invalidOutput = defineFlow( } ); +import { MessageSchema } from '@genkit-ai/ai/model'; import { GoogleAIFileManager } from '@google/generative-ai/server'; const fileManager = new GoogleAIFileManager( process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY! @@ -465,3 +465,38 @@ export const fileApi = defineFlow( return result.text(); } ); + +export const testTools = [ + // test a tool with no input / output schema + defineTool( + { name: 'getColor', description: 'gets a random color' }, + async () => { + const colors = [ + 'red', + 'orange', + 'yellow', + 'blue', + 'green', + 'indigo', + 'violet', + ]; + return [Math.floor(Math.random() * colors.length)]; + } + ), +]; + +export const toolTester = defineFlow( + { + name: 'toolTester', + inputSchema: z.string(), + outputSchema: z.array(MessageSchema), + }, + async (query) => { + const result = await generate({ + model: gemini15Flash, + prompt: query, + tools: testTools, + }); + return result.toHistory(); + } +); From 8ed0774c6c0e14119009cc5ac21647d3ce79844d Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Mon, 26 Aug 2024 17:21:47 -0700 Subject: [PATCH 2/5] fixes --- js/ai/src/generate.ts | 14 +++++++++----- js/ai/src/model.ts | 4 ++-- js/testapps/flow-simple-ai/src/index.ts | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index e97ac93cc..8d412942e 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -16,10 +16,10 @@ import { Action, - GenkitError, - StreamingCallback, config as genkitConfig, + GenkitError, runWithStreamingCallback, + StreamingCallback, } from '@genkit-ai/core'; import { lookupAction } from '@genkit-ai/core/registry'; import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema'; @@ -27,8 +27,8 @@ import { z } from 'zod'; import { DocumentData } from './document.js'; import { extractJson } from './extract.js'; import { - GenerateUtilParamSchema, generateAction, + GenerateUtilParamSchema, inferRoleFromParts, } from './generateAction.js'; import { @@ -47,7 +47,7 @@ import { ToolRequestPart, ToolResponsePart, } from './model.js'; -import { ToolArgument, resolveTools, toToolDefinition } from './tool.js'; +import { resolveTools, ToolArgument, toToolDefinition } from './tool.js'; /** * Message represents a single role's contribution to a generation. Each message @@ -610,7 +610,11 @@ export async function generate< return await runWithStreamingCallback( resolvedOptions.streamingCallback, - async () => new GenerateResponse(await generateAction(params)) + async () => + new GenerateResponse( + await generateAction(params), + await toGenerateRequest(resolvedOptions) + ) ); } diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 142cd3da7..1b84f8ba2 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -154,11 +154,11 @@ export const ToolDefinitionSchema = z.object({ inputSchema: z .record(z.any()) .describe('Valid JSON Schema representing the input of the tool.') - .optional(), + .nullish(), outputSchema: z .record(z.any()) .describe('Valid JSON Schema describing the output of the tool.') - .optional(), + .nullish(), }); export type ToolDefinition = z.infer; diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 7ac354945..e36cdf2b6 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -480,7 +480,7 @@ export const testTools = [ 'indigo', 'violet', ]; - return [Math.floor(Math.random() * colors.length)]; + return colors[Math.floor(Math.random() * colors.length)]; } ), ]; From 72f33b99e649abd549d527ab792e7aafb8a260c5 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Mon, 26 Aug 2024 17:22:06 -0700 Subject: [PATCH 3/5] format --- js/testapps/flow-simple-ai/src/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index e36cdf2b6..0b64d41d4 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -22,8 +22,8 @@ import { defineFlow, run } from '@genkit-ai/flow'; import { googleCloud } from '@genkit-ai/google-cloud'; import { gemini15Flash, - geminiPro as googleGeminiPro, googleAI, + geminiPro as googleGeminiPro, } from '@genkit-ai/googleai'; import { gemini15ProPreview, From ecb3223ebdb4aef0b1f58268b253c2b45bb9ca62 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Mon, 26 Aug 2024 17:31:51 -0700 Subject: [PATCH 4/5] fix --- js/plugins/vertexai/src/openai_compatibility.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/vertexai/src/openai_compatibility.ts b/js/plugins/vertexai/src/openai_compatibility.ts index 8e1d27ca6..6d89565b6 100644 --- a/js/plugins/vertexai/src/openai_compatibility.ts +++ b/js/plugins/vertexai/src/openai_compatibility.ts @@ -75,7 +75,7 @@ function toOpenAiTool(tool: ToolDefinition): ChatCompletionTool { type: 'function', function: { name: tool.name, - parameters: tool.inputSchema, + parameters: tool.inputSchema || undefined, }, }; } From ba6b65761ea90c79d9bc71a99811cca7a36670a4 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 27 Aug 2024 09:50:19 -0700 Subject: [PATCH 5/5] Adds toHistory() test for generate() --- js/ai/tests/generate/generate_test.ts | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 2388c401b..a20d856f2 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -17,7 +17,7 @@ import assert from 'node:assert'; import { describe, it } from 'node:test'; import { z } from 'zod'; -import { GenerateResponseChunk } from '../../src/generate'; +import { generate, GenerateResponseChunk } from '../../src/generate'; import { Candidate, GenerateOptions, @@ -25,7 +25,7 @@ import { Message, toGenerateRequest, } from '../../src/generate.js'; -import { GenerateResponseChunkData } from '../../src/model'; +import { defineModel, GenerateResponseChunkData } from '../../src/model'; import { CandidateData, GenerateRequest, @@ -581,3 +581,26 @@ describe('GenerateResponseChunk', () => { } }); }); + +const echo = defineModel( + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + candidates: [ + { index: 0, message: input.messages[0], finishReason: 'stop' }, + ], + }) +); + +describe('generate', () => { + it('should preserve the request in the returned response, enabling toHistory()', async () => { + const response = await generate({ + model: echo, + prompt: 'Testing toHistory', + }); + + assert.deepEqual( + response.toHistory().map((m) => m.content[0].text), + ['Testing toHistory', 'Testing toHistory'] + ); + }); +});