diff --git a/js/ai/src/formats/json.ts b/js/ai/src/formats/json.ts index cc4beecd0a..5d7465546a 100644 --- a/js/ai/src/formats/json.ts +++ b/js/ai/src/formats/json.ts @@ -23,8 +23,19 @@ export const jsonFormatter: Formatter = { format: 'json', contentType: 'application/json', constrained: true, + defaultInstructions: false, }, - handler: () => { + handler: (schema) => { + let instructions: string | undefined; + + if (schema) { + instructions = `Output should be in JSON format and conform to the following schema: + +\`\`\` +${JSON.stringify(schema)} +\`\`\` +`; + } return { parseChunk: (chunk) => { return extractJson(chunk.accumulatedText); @@ -33,6 +44,8 @@ export const jsonFormatter: Formatter = { parseMessage: (message) => { return extractJson(message.text); }, + + instructions, }; }, }; diff --git a/js/ai/src/formats/types.ts b/js/ai/src/formats/types.ts index 7f0c9fbd5a..3cd17e3d3a 100644 --- a/js/ai/src/formats/types.ts +++ b/js/ai/src/formats/types.ts @@ -23,7 +23,9 @@ export type OutputContentTypes = 'application/json' | 'text/plain'; export interface Formatter { name: string; - config: ModelRequest['output']; + config: ModelRequest['output'] & { + defaultInstructions?: false; + }; handler: (schema?: JSONSchema) => { parseMessage(message: Message): O; parseChunk?: (chunk: GenerateResponseChunk) => CO; diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index d0b3ebef14..8d9a1766f6 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -33,7 +33,10 @@ import { resolveFormat, resolveInstructions, } from './formats/index.js'; -import { generateHelper } from './generate/action.js'; +import { + generateHelper, + shouldInjectFormatInstructions, +} from './generate/action.js'; import { GenerateResponseChunk } from './generate/chunk.js'; import { GenerateResponse } from './generate/response.js'; import { Message } from './message.js'; @@ -211,14 +214,19 @@ export async function toGenerateRequest( ); const out = { - messages: injectInstructions(messages, instructions), + messages: shouldInjectFormatInstructions( + resolvedFormat?.config, + options.output + ) + ? injectInstructions(messages, instructions) + : messages, config: options.config, docs: options.docs, tools: tools?.map(toToolDefinition) || [], output: { ...(resolvedFormat?.config || {}), - schema: resolvedSchema, ...options.output, + schema: resolvedSchema, }, } as GenerateRequest; if (!out?.output?.schema) delete out?.output?.schema; @@ -343,16 +351,11 @@ export async function generate< resolvedOptions.output.format = 'json'; } const resolvedFormat = await resolveFormat(registry, resolvedOptions.output); - const instructions = resolveInstructions( - resolvedFormat, - resolvedSchema, - resolvedOptions?.output?.instructions - ); const params: GenerateActionOptions = { model: resolvedModel.modelAction.__action.name, docs: resolvedOptions.docs, - messages: injectInstructions(messages, instructions), + messages: messages, tools, toolChoice: resolvedOptions.toolChoice, config: { diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 74f69ad072..2dc71bf8ad 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -41,6 +41,7 @@ import { import { GenerateActionOptions, GenerateActionOptionsSchema, + GenerateActionOutputConfig, GenerateRequest, GenerateRequestSchema, GenerateResponseChunkData, @@ -172,7 +173,14 @@ function applyFormat( ); if (resolvedFormat) { - outRequest.messages = injectInstructions(outRequest.messages, instructions); + if ( + shouldInjectFormatInstructions(resolvedFormat.config, rawRequest?.output) + ) { + outRequest.messages = injectInstructions( + outRequest.messages, + instructions + ); + } outRequest.output = { // use output config from the format ...resolvedFormat.config, @@ -184,6 +192,16 @@ function applyFormat( return outRequest; } +export function shouldInjectFormatInstructions( + formatConfig?: Formatter['config'], + rawRequestConfig?: z.infer +) { + return ( + formatConfig?.defaultInstructions !== false || + rawRequestConfig?.instructions + ); +} + function applyTransferPreamble( rawRequest: GenerateActionOptions, transferPreamble?: GenerateActionOptions diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 3f7e2ca576..368d48a24e 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -699,6 +699,14 @@ export async function resolveModel( return out; } +export const GenerateActionOutputConfig = z.object({ + format: z.string().optional(), + contentType: z.string().optional(), + instructions: z.union([z.boolean(), z.string()]).optional(), + jsonSchema: z.any().optional(), + constrained: z.boolean().optional(), +}); + export const GenerateActionOptionsSchema = z.object({ /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ model: z.string(), @@ -713,15 +721,7 @@ export const GenerateActionOptionsSchema = z.object({ /** Configuration for the generation request. */ config: z.any().optional(), /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */ - output: z - .object({ - format: z.string().optional(), - contentType: z.string().optional(), - instructions: z.union([z.boolean(), z.string()]).optional(), - jsonSchema: z.any().optional(), - constrained: z.boolean().optional(), - }) - .optional(), + output: GenerateActionOutputConfig.optional(), /** Options for resuming an interrupted generation. */ resume: z .object({ diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index 49cd624fd3..cdd70d1caa 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -396,7 +396,7 @@ describe('augmentWithContext', () => { }); }); -describe('simulateConstrainedGeneration', () => { +describe.only('simulateConstrainedGeneration', () => { let registry: Registry; beforeEach(() => { @@ -555,4 +555,70 @@ describe('simulateConstrainedGeneration', () => { tools: [], }); }); + + it('uses format instructions when instructions is explicitly set to true', async () => { + let pm = defineProgrammableModel(registry, { + supports: { constrained: 'all' }, + }); + pm.handleResponse = async (req, sc) => { + return { + message: { + role: 'model', + content: [{ text: '```\n{"foo": "bar"}\n```' }], + }, + }; + }; + + const { output } = await generate(registry, { + model: 'programmableModel', + prompt: 'generate json', + output: { + instructions: true, + constrained: false, + schema: z.object({ + foo: z.string(), + }), + }, + }); + assert.deepEqual(output, { foo: 'bar' }); + assert.deepStrictEqual(pm.lastRequest, { + config: {}, + messages: [ + { + role: 'user', + content: [ + { text: 'generate json' }, + { + metadata: { + purpose: 'output', + }, + text: + 'Output should be in JSON format and conform to the following schema:\n' + + '\n' + + '```\n' + + '{"type":"object","properties":{"foo":{"type":"string"}},"required":["foo"],"additionalProperties":true,"$schema":"http://json-schema.org/draft-07/schema#"}\n' + + '```\n', + }, + ], + }, + ], + output: { + constrained: false, + contentType: 'application/json', + format: 'json', + schema: { + $schema: 'http://json-schema.org/draft-07/schema#', + additionalProperties: true, + properties: { + foo: { + type: 'string', + }, + }, + required: ['foo'], + type: 'object', + }, + }, + tools: [], + }); + }); }); diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 4fee1628cd..4b9441f19b 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { JSONSchema7 } from 'json-schema'; +import { type JSONSchema7 } from 'json-schema'; import * as z from 'zod'; import { lazy } from './async.js'; import { ActionContext, getContext, runWithContext } from './context.js'; @@ -26,7 +26,7 @@ import { setCustomMetadataAttributes, } from './tracing.js'; -export { Status, StatusCodes, StatusSchema } from './statusTypes.js'; +export { StatusCodes, StatusSchema, type Status } from './statusTypes.js'; export { JSONSchema7 }; /** diff --git a/js/core/src/error.ts b/js/core/src/error.ts index bda27d595f..305220a48e 100644 --- a/js/core/src/error.ts +++ b/js/core/src/error.ts @@ -15,7 +15,7 @@ */ import { Registry } from './registry.js'; -import { httpStatusCode, StatusName } from './statusTypes.js'; +import { httpStatusCode, type StatusName } from './statusTypes.js'; export { StatusName }; diff --git a/js/genkit/src/registry.ts b/js/genkit/src/registry.ts index 8c45c10d46..367d697b5d 100644 --- a/js/genkit/src/registry.ts +++ b/js/genkit/src/registry.ts @@ -15,8 +15,8 @@ */ export { - ActionType, - AsyncProvider, Registry, - Schema, + type ActionType, + type AsyncProvider, + type Schema, } from '@genkit-ai/core/registry'; diff --git a/js/genkit/tests/formats_test.ts b/js/genkit/tests/formats_test.ts index 774bbd47e4..b23ef49678 100644 --- a/js/genkit/tests/formats_test.ts +++ b/js/genkit/tests/formats_test.ts @@ -93,7 +93,7 @@ describe('formats', () => { }); it('lets you define and use a custom output format with simulated constrained generation', async () => { - defineEchoModel(ai, { supports: { constrained: false } }); + defineEchoModel(ai, { supports: { constrained: 'none' } }); const { output } = await ai.generate({ model: 'echoModel', diff --git a/js/genkit/tests/prompts_test.ts b/js/genkit/tests/prompts_test.ts index 14d547c814..3f4ea276e2 100644 --- a/js/genkit/tests/prompts_test.ts +++ b/js/genkit/tests/prompts_test.ts @@ -173,7 +173,7 @@ describe('definePrompt', () => { }); }); -describe.only('definePrompt', () => { +describe('definePrompt', () => { describe('default model', () => { let ai: GenkitBeta; @@ -310,7 +310,7 @@ describe.only('definePrompt', () => { }); }); - describe.only('default model ref', () => { + describe('default model ref', () => { let ai: GenkitBeta; beforeEach(() => { diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 79da26e118..fea7b141f5 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -654,3 +654,85 @@ ai.defineFlow('blockingMiddleware', async () => { }); return text; }); + +ai.defineFlow('formatJson', async (input, { sendChunk }) => { + const { output, text } = await ai.generate({ + prompt: `generate an RPG game character of type ${input || 'archer'}`, + output: { + constrained: false, + instructions: true, + schema: z + .object({ + name: z.string(), + weapon: z.string(), + }) + .strict(), + }, + onChunk: (c) => sendChunk(c.output), + }); + return { output, text }; +}); + +ai.defineFlow('formatJsonManualSchema', async (input, { sendChunk }) => { + const { output, text } = await ai.generate({ + prompt: `generate one RPG game character of type ${input || 'archer'} and generated JSON must match this interface + + \`\`\`typescript + interface Character { + name: string; + weapon: string; + } + \`\`\` + `, + output: { + constrained: true, + instructions: false, + schema: z + .object({ + name: z.string(), + weapon: z.string(), + }) + .strict(), + }, + onChunk: (c) => sendChunk(c.output), + }); + return { output, text }; +}); + +ai.defineFlow('testArray', async (input, { sendChunk }) => { + const { output } = await ai.generate({ + prompt: `10 different weapons for ${input}`, + output: { + format: 'array', + schema: z.array(z.string()), + }, + onChunk: (c) => sendChunk(c.output), + }); + return output; +}); + +ai.defineFlow('formatEnum', async (input, { sendChunk }) => { + const { output } = await ai.generate({ + prompt: `classify the denger level of sky diving`, + output: { + format: 'enum', + schema: z.enum(['safe', 'dangerous', 'medium']), + }, + onChunk: (c) => sendChunk(c.output), + }); + return output; +}); + +ai.defineFlow('formatJsonl', async (input, { sendChunk }) => { + const { output } = await ai.generate({ + prompt: `generate 5 randon persons`, + output: { + format: 'jsonl', + schema: z.array( + z.object({ name: z.string(), surname: z.string() }).strict() + ), + }, + onChunk: (c) => sendChunk(c.output), + }); + return output; +}); diff --git a/js/testapps/format-tester/src/index.ts b/js/testapps/format-tester/src/index.ts index 84dcfb5905..59716c1629 100644 --- a/js/testapps/format-tester/src/index.ts +++ b/js/testapps/format-tester/src/index.ts @@ -154,6 +154,7 @@ if (!models.length) { 'vertexai/gemini-1.5-flash', 'googleai/gemini-1.5-pro', 'googleai/gemini-1.5-flash', + 'googleai/gemini-2.0-flash', ]; }