Skip to content

Commit

Permalink
Fix bugs related to combining tools and output conformance (#542)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbleigh authored Jul 8, 2024
1 parent af977cc commit 8c5751c
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 10 deletions.
3 changes: 3 additions & 0 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,9 @@ export async function generate<
if (resolvedOptions.output?.schema || resolvedOptions.output?.jsonSchema) {
// find a candidate with valid output schema
const candidateErrors = response.candidates.map((c) => {
// don't validate messages that have no text or data
if (c.text() === '' && c.data() === null) return null;

try {
parseSchema(c.output(), {
jsonSchema: resolvedOptions.output?.jsonSchema,
Expand Down
13 changes: 11 additions & 2 deletions js/ai/src/model/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import { Document } from '../document.js';
import { ModelInfo, ModelMiddleware, Part } from '../model.js';
import { MessageData, ModelInfo, ModelMiddleware, Part } from '../model.js';

/**
* Preprocess a GenerateRequest to download referenced http(s) media URLs and
Expand Down Expand Up @@ -117,9 +117,18 @@ export function validateSupport(options: {
};
}

function lastUserMessage(messages: MessageData[]) {
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i].role === 'user') {
return messages[i];
}
}
}

export function conformOutput(): ModelMiddleware {
return async (req, next) => {
const lastMessage = req.messages.at(-1)!;
const lastMessage = lastUserMessage(req.messages);
if (!lastMessage) return next(req);
const outputPartIndex = lastMessage.content.findIndex(
(p) => p.metadata?.purpose === 'output'
);
Expand Down
61 changes: 55 additions & 6 deletions js/ai/tests/model/middleware_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,21 @@ describe('conformOutput (default middleware)', () => {
const response = await echoModel(req);
const treq = response.candidates[0].message.content[0]
.data as GenerateRequest;

const lastUserMessage = treq.messages
.reverse()
.find((m) => m.role === 'user');
if (
treq.messages
.at(-1)!
?.content.filter((p) => p.metadata?.purpose === 'output').length > 1
lastUserMessage &&
lastUserMessage.content.filter((p) => p.metadata?.purpose === 'output')
.length > 1
) {
throw new Error('too many output parts');
}

return treq.messages
.at(-1)
?.content.find((p) => p.metadata?.purpose === 'output')!;
return lastUserMessage?.content.find(
(p) => p.metadata?.purpose === 'output'
)!;
}

it('adds output instructions to the last message', async () => {
Expand All @@ -199,6 +203,51 @@ describe('conformOutput (default middleware)', () => {
);
});

it('adds output to the last message with "user" role', async () => {
const part = await testRequest({
messages: [
{
content: [
{
text: 'First message.',
},
],
role: 'user',
},
{
content: [
{
toolRequest: {
name: 'localRestaurant',
input: {
location: 'wtf',
},
},
},
],
role: 'model',
},
{
content: [
{
toolResponse: {
name: 'localRestaurant',
output: 'McDonalds',
},
},
],
role: 'tool',
},
],
output: { format: 'json', schema },
});

assert(
part?.text?.includes(JSON.stringify(schema)),
"schema wasn't found in output part"
);
});

it('does not add output instructions if already provided', async () => {
const part = await testRequest({
messages: [
Expand Down
5 changes: 3 additions & 2 deletions js/testapps/flow-simple-ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,16 @@ export const jokeWithToolsFlow = defineFlow(
modelName: z.enum([geminiPro.name, googleGeminiPro.name]),
subject: z.string(),
}),
outputSchema: z.string(),
outputSchema: z.object({ model: z.string(), joke: z.string() }),
},
async (input) => {
const llmResponse = await generate({
model: input.modelName,
tools,
output: { schema: z.object({ joke: z.string() }) },
prompt: `Tell a joke about ${input.subject}.`,
});
return `From ${input.modelName}: ${llmResponse.text()}`;
return { ...llmResponse.output()!, model: input.modelName };
}
);

Expand Down

0 comments on commit 8c5751c

Please sign in to comment.