diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 9f61eea3c7..2a791bc65d 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -16,6 +16,7 @@ import { Action, + actionWithMiddleware, config as genkitConfig, GenkitError, runWithStreamingCallback, @@ -36,6 +37,7 @@ import { MessageData, ModelAction, ModelArgument, + ModelMiddleware, ModelReference, Part, Role, @@ -494,6 +496,8 @@ export interface GenerateOptions< returnToolRequests?: boolean; /** When provided, models supporting streaming will call the provided callback with chunks as generation progresses. */ streamingCallback?: StreamingCallback; + /** Middlewera to be used with this model call. */ + use?: ModelMiddleware[]; } const isValidCandidate = ( @@ -619,7 +623,19 @@ export async function generate< ? (chunk: GenerateResponseChunkData) => resolvedOptions.streamingCallback!(new GenerateResponseChunk(chunk)) : undefined, - async () => new GenerateResponse>(await model(request), request) + async () => { + var modelAction = model; + if (resolvedOptions.use) { + modelAction = actionWithMiddleware( + modelAction, + resolvedOptions.use + ) as ModelAction; + } + return new GenerateResponse>( + await modelAction(request), + request + ); + } ); // throw NoValidCandidates if all candidates are blocked or diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 45cf8be059..a5cf9a0240 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -269,6 +269,7 @@ export function defineModel< configSchema?: CustomOptionsSchema; /** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */ label?: string; + /** Middlewera to be used with this model. */ use?: ModelMiddleware[]; }, runner: ( diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index f8717f9f86..a949fdd10f 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -14,20 +14,25 @@ * limitations under the License. */ +import { __hardResetRegistryForTesting } from '@genkit-ai/core/registry'; import assert from 'node:assert'; -import { describe, it } from 'node:test'; +import { beforeEach, describe, it } from 'node:test'; import { z } from 'zod'; import { Candidate, GenerateOptions, GenerateResponse, Message, + generate, toGenerateRequest, } from '../../src/generate.js'; import { CandidateData, GenerateRequest, MessageData, + ModelAction, + ModelMiddleware, + defineModel, } from '../../src/model.js'; import { defineTool } from '../../src/tool.js'; @@ -506,3 +511,95 @@ describe('toGenerateRequest', () => { }); } }); + +describe('generate', () => { + beforeEach(__hardResetRegistryForTesting); + + var echoModel: ModelAction; + + beforeEach(() => { + echoModel = defineModel( + { + name: 'echoModel', + }, + async (request) => { + return { + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + text: + 'Echo: ' + + request.messages + .map((m) => m.content.map((c) => c.text).join()) + .join(), + }, + ], + }, + }, + ], + }; + } + ); + }); + + it('applies middleware', async () => { + const wrapRequest: ModelMiddleware = async (req, next) => { + return next({ + ...req, + messages: [ + { + role: 'user', + content: [ + { + text: + '(' + + req.messages + .map((m) => m.content.map((c) => c.text).join()) + .join() + + ')', + }, + ], + }, + ], + }); + }; + const wrapResponse: ModelMiddleware = async (req, next) => { + const res = await next(req); + return { + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + text: + '[' + + res.candidates[0].message.content + .map((c) => c.text) + .join() + + ']', + }, + ], + }, + }, + ], + }; + }; + + const response = await generate({ + prompt: 'banana', + model: echoModel, + use: [wrapRequest, wrapResponse], + }); + + const want = '[Echo: (banana)]'; + assert.deepStrictEqual(response.text(), want); + }); +}); diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index 24bbc80aa0..0524f44f90 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -188,6 +188,7 @@ export class Dotprompt implements PromptMetadata { tools: (options.tools || []).concat(this.tools || []), streamingCallback: options.streamingCallback, returnToolRequests: options.returnToolRequests, + use: options.use, }; } diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index ce0f14464d..ea75468488 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -95,14 +95,17 @@ describe('Prompt', () => { const prompt = testPrompt(`Hello {{name}}, how are you?`); const streamingCallback = (c) => console.log(c); + const middleware = []; const rendered = await prompt.render({ input: { name: 'Michael' }, streamingCallback, returnToolRequests: true, + use: middleware, }); assert.strictEqual(rendered.streamingCallback, streamingCallback); assert.strictEqual(rendered.returnToolRequests, true); + assert.strictEqual(rendered.use, middleware); }); });