From 8c5751cabcc4b76daaab68e3d9ff28a8d5718aaf Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Mon, 8 Jul 2024 12:06:26 -0700 Subject: [PATCH] Fix bugs related to combining tools and output conformance (#542) --- js/ai/src/generate.ts | 3 ++ js/ai/src/model/middleware.ts | 13 +++++- js/ai/tests/model/middleware_test.ts | 61 ++++++++++++++++++++++--- js/testapps/flow-simple-ai/src/index.ts | 5 +- 4 files changed, 72 insertions(+), 10 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index f423df019..1bc28e2e7 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -635,6 +635,9 @@ export async function generate< if (resolvedOptions.output?.schema || resolvedOptions.output?.jsonSchema) { // find a candidate with valid output schema const candidateErrors = response.candidates.map((c) => { + // don't validate messages that have no text or data + if (c.text() === '' && c.data() === null) return null; + try { parseSchema(c.output(), { jsonSchema: resolvedOptions.output?.jsonSchema, diff --git a/js/ai/src/model/middleware.ts b/js/ai/src/model/middleware.ts index 122abe2ab..b34ade016 100644 --- a/js/ai/src/model/middleware.ts +++ b/js/ai/src/model/middleware.ts @@ -15,7 +15,7 @@ */ import { Document } from '../document.js'; -import { ModelInfo, ModelMiddleware, Part } from '../model.js'; +import { MessageData, ModelInfo, ModelMiddleware, Part } from '../model.js'; /** * Preprocess a GenerateRequest to download referenced http(s) media URLs and @@ -117,9 +117,18 @@ export function validateSupport(options: { }; } +function lastUserMessage(messages: MessageData[]) { + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'user') { + return messages[i]; + } + } +} + export function conformOutput(): ModelMiddleware { return async (req, next) => { - const lastMessage = req.messages.at(-1)!; + const lastMessage = lastUserMessage(req.messages); + if (!lastMessage) return next(req); const outputPartIndex = lastMessage.content.findIndex( (p) => p.metadata?.purpose === 'output' ); diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index 14ba730bc..f5fb4918d 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -170,17 +170,21 @@ describe('conformOutput (default middleware)', () => { const response = await echoModel(req); const treq = response.candidates[0].message.content[0] .data as GenerateRequest; + + const lastUserMessage = treq.messages + .reverse() + .find((m) => m.role === 'user'); if ( - treq.messages - .at(-1)! - ?.content.filter((p) => p.metadata?.purpose === 'output').length > 1 + lastUserMessage && + lastUserMessage.content.filter((p) => p.metadata?.purpose === 'output') + .length > 1 ) { throw new Error('too many output parts'); } - return treq.messages - .at(-1) - ?.content.find((p) => p.metadata?.purpose === 'output')!; + return lastUserMessage?.content.find( + (p) => p.metadata?.purpose === 'output' + )!; } it('adds output instructions to the last message', async () => { @@ -199,6 +203,51 @@ describe('conformOutput (default middleware)', () => { ); }); + it('adds output to the last message with "user" role', async () => { + const part = await testRequest({ + messages: [ + { + content: [ + { + text: 'First message.', + }, + ], + role: 'user', + }, + { + content: [ + { + toolRequest: { + name: 'localRestaurant', + input: { + location: 'wtf', + }, + }, + }, + ], + role: 'model', + }, + { + content: [ + { + toolResponse: { + name: 'localRestaurant', + output: 'McDonalds', + }, + }, + ], + role: 'tool', + }, + ], + output: { format: 'json', schema }, + }); + + assert( + part?.text?.includes(JSON.stringify(schema)), + "schema wasn't found in output part" + ); + }); + it('does not add output instructions if already provided', async () => { const part = await testRequest({ messages: [ diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 3cb1f17fd..bc6669bd3 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -215,15 +215,16 @@ export const jokeWithToolsFlow = defineFlow( modelName: z.enum([geminiPro.name, googleGeminiPro.name]), subject: z.string(), }), - outputSchema: z.string(), + outputSchema: z.object({ model: z.string(), joke: z.string() }), }, async (input) => { const llmResponse = await generate({ model: input.modelName, tools, + output: { schema: z.object({ joke: z.string() }) }, prompt: `Tell a joke about ${input.subject}.`, }); - return `From ${input.modelName}: ${llmResponse.text()}`; + return { ...llmResponse.output()!, model: input.modelName }; } );