From 6a833de6596c9df0a84d26dbf3176720b582883e Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Sat, 22 Jun 2024 11:34:12 -0400 Subject: [PATCH 1/2] feat: allow specifying middleware on the `generate` function --- js/ai/src/generate.ts | 12 ++- js/ai/src/model.ts | 1 + js/ai/tests/generate/generate_test.ts | 99 ++++++++++++++++++++++- js/plugins/dotprompt/src/prompt.ts | 1 + js/plugins/dotprompt/tests/prompt_test.ts | 3 + 5 files changed, 114 insertions(+), 2 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 9f61eea3c7..342b55baca 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -16,6 +16,7 @@ import { Action, + actionWithMiddleware, config as genkitConfig, GenkitError, runWithStreamingCallback, @@ -36,6 +37,7 @@ import { MessageData, ModelAction, ModelArgument, + ModelMiddleware, ModelReference, Part, Role, @@ -494,6 +496,8 @@ export interface GenerateOptions< returnToolRequests?: boolean; /** When provided, models supporting streaming will call the provided callback with chunks as generation progresses. */ streamingCallback?: StreamingCallback; + /** Middlewera to be used with this model call. */ + use?: ModelMiddleware[]; } const isValidCandidate = ( @@ -619,7 +623,13 @@ export async function generate< ? (chunk: GenerateResponseChunkData) => resolvedOptions.streamingCallback!(new GenerateResponseChunk(chunk)) : undefined, - async () => new GenerateResponse>(await model(request), request) + async () => { + var modelAction = model; + if (resolvedOptions.use) { + modelAction = actionWithMiddleware(modelAction, resolvedOptions.use) as ModelAction; + } + return new GenerateResponse>(await modelAction(request), request) + } ); // throw NoValidCandidates if all candidates are blocked or diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 45cf8be059..a5cf9a0240 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -269,6 +269,7 @@ export function defineModel< configSchema?: CustomOptionsSchema; /** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */ label?: string; + /** Middlewera to be used with this model. */ use?: ModelMiddleware[]; }, runner: ( diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index f8717f9f86..1f4a8fb995 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -14,11 +14,13 @@ * limitations under the License. */ +import { __hardResetRegistryForTesting } from '@genkit-ai/core/registry'; import assert from 'node:assert'; -import { describe, it } from 'node:test'; +import { beforeEach, describe, it } from 'node:test'; import { z } from 'zod'; import { Candidate, + generate, GenerateOptions, GenerateResponse, Message, @@ -26,8 +28,11 @@ import { } from '../../src/generate.js'; import { CandidateData, + defineModel, GenerateRequest, MessageData, + ModelAction, + ModelMiddleware, } from '../../src/model.js'; import { defineTool } from '../../src/tool.js'; @@ -506,3 +511,95 @@ describe('toGenerateRequest', () => { }); } }); + +describe('generate', () => { + beforeEach(__hardResetRegistryForTesting); + + var echoModel: ModelAction; + + beforeEach(() => { + echoModel = defineModel( + { + name: 'echoModel', + }, + async (request) => { + return { + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + text: + 'Echo: ' + + request.messages + .map((m) => m.content.map((c) => c.text).join()) + .join(), + }, + ], + }, + }, + ], + }; + } + ); + }); + + it('applies middleware', async () => { + const wrapRequest: ModelMiddleware = async (req, next) => { + return next({ + ...req, + messages: [ + { + role: 'user', + content: [ + { + text: + '(' + + req.messages + .map((m) => m.content.map((c) => c.text).join()) + .join() + + ')', + }, + ], + }, + ], + }); + }; + const wrapResponse: ModelMiddleware = async (req, next) => { + const res = await next(req); + return { + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + text: + '[' + + res.candidates[0].message.content + .map((c) => c.text) + .join() + + ']', + }, + ], + }, + }, + ], + }; + }; + + const response = await generate({ + prompt: 'banana', + model: echoModel, + use: [wrapRequest, wrapResponse], + }); + + const want = '[Echo: (banana)]'; + assert.deepStrictEqual(response.text(), want); + }); +}); diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index 24bbc80aa0..0524f44f90 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -188,6 +188,7 @@ export class Dotprompt implements PromptMetadata { tools: (options.tools || []).concat(this.tools || []), streamingCallback: options.streamingCallback, returnToolRequests: options.returnToolRequests, + use: options.use, }; } diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index ce0f14464d..1533bfa14f 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -95,14 +95,17 @@ describe('Prompt', () => { const prompt = testPrompt(`Hello {{name}}, how are you?`); const streamingCallback = (c) => console.log(c); + const middleware = [] const rendered = await prompt.render({ input: { name: 'Michael' }, streamingCallback, returnToolRequests: true, + use: middleware, }); assert.strictEqual(rendered.streamingCallback, streamingCallback); assert.strictEqual(rendered.returnToolRequests, true); + assert.strictEqual(rendered.use, middleware); }); }); From 24ebe17028a1696e1965106ca3d9e3fc624b9a33 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Sat, 22 Jun 2024 12:34:19 -0400 Subject: [PATCH 2/2] format --- js/ai/src/generate.ts | 10 ++++++++-- js/ai/tests/generate/generate_test.ts | 4 ++-- js/plugins/dotprompt/tests/prompt_test.ts | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 342b55baca..2a791bc65d 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -626,9 +626,15 @@ export async function generate< async () => { var modelAction = model; if (resolvedOptions.use) { - modelAction = actionWithMiddleware(modelAction, resolvedOptions.use) as ModelAction; + modelAction = actionWithMiddleware( + modelAction, + resolvedOptions.use + ) as ModelAction; } - return new GenerateResponse>(await modelAction(request), request) + return new GenerateResponse>( + await modelAction(request), + request + ); } ); diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 1f4a8fb995..a949fdd10f 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -20,19 +20,19 @@ import { beforeEach, describe, it } from 'node:test'; import { z } from 'zod'; import { Candidate, - generate, GenerateOptions, GenerateResponse, Message, + generate, toGenerateRequest, } from '../../src/generate.js'; import { CandidateData, - defineModel, GenerateRequest, MessageData, ModelAction, ModelMiddleware, + defineModel, } from '../../src/model.js'; import { defineTool } from '../../src/tool.js'; diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index 1533bfa14f..ea75468488 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -95,7 +95,7 @@ describe('Prompt', () => { const prompt = testPrompt(`Hello {{name}}, how are you?`); const streamingCallback = (c) => console.log(c); - const middleware = [] + const middleware = []; const rendered = await prompt.render({ input: { name: 'Michael' },