diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 0b4af7b557..2758a485fb 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -18,6 +18,7 @@ import { Action, defineAction, getStreamingCallback, + Middleware, StreamingCallback, } from '@genkit-ai/core'; import { toJsonSchema } from '@genkit-ai/core/schema'; @@ -234,38 +235,10 @@ export type ModelAction< __configSchema: CustomOptionsSchema; }; -export interface ModelMiddleware { - ( - req: GenerateRequest, - next: (req?: GenerateRequest) => Promise - ): Promise; -} - -/** - * - */ -export function modelWithMiddleware( - model: ModelAction, - middleware: ModelMiddleware[] -): ModelAction { - const wrapped = (async (req: GenerateRequest) => { - const dispatch = async (index: number, req: GenerateRequest) => { - if (index === middleware.length) { - // end of the chain, call the original model action - return await model(req); - } - - const currentMiddleware = middleware[index]; - return currentMiddleware(req, async (modifiedReq) => - dispatch(index + 1, modifiedReq || req) - ); - }; - - return await dispatch(0, req); - }) as ModelAction; - wrapped.__action = model.__action; - return wrapped; -} +export type ModelMiddleware = Middleware< + z.infer, + z.infer +>; /** * Defines a new model and adds it to the registry. @@ -291,6 +264,11 @@ export function defineModel< ) => Promise ): ModelAction { const label = options.label || `${options.name} GenAI model`; + const middleware = [ + ...(options.use || []), + validateSupport(options), + conformOutput(), + ]; const act = defineAction( { actionType: 'model', @@ -308,6 +286,7 @@ export function defineModel< supports: options.supports, }, }, + use: middleware, }, (input) => { const startTimeMs = performance.now(); @@ -331,16 +310,7 @@ export function defineModel< Object.assign(act, { __configSchema: options.configSchema || z.unknown(), }); - const middleware = [ - ...(options.use || []), - validateSupport(options), - conformOutput(), - ]; - const ma = modelWithMiddleware( - act as ModelAction, - middleware - ) as ModelAction; - return ma; + return act as ModelAction; } export interface ModelReference { diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 09fd1466cd..a78c8a724f 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -72,8 +72,40 @@ type ActionParams< outputSchema?: O; outputJsonSchema?: JSONSchema7; metadata?: M; + use?: Middleware, z.infer>[]; }; +export interface Middleware { + (req: I, next: (req?: I) => Promise): Promise; +} + +export function actionWithMiddleware< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + M extends Record = Record, +>( + action: Action, + middleware: Middleware, z.infer>[] +): Action { + const wrapped = (async (req: z.infer) => { + const dispatch = async (index: number, req: z.infer) => { + if (index === middleware.length) { + // end of the chain, call the original model action + return await action(req); + } + + const currentMiddleware = middleware[index]; + return currentMiddleware(req, async (modifiedReq) => + dispatch(index + 1, modifiedReq || req) + ); + }; + + return await dispatch(0, req); + }) as Action; + wrapped.__action = action.__action; + return wrapped; +} + /** * Creates an action with the provided config. */ @@ -140,6 +172,10 @@ export function action< outputJsonSchema: config.outputJsonSchema, metadata: config.metadata, } as ActionMetadata; + + if (config.use) { + return actionWithMiddleware(actionFn, config.use); + } return actionFn; } diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts new file mode 100644 index 0000000000..3d257d02a5 --- /dev/null +++ b/js/core/tests/action_test.ts @@ -0,0 +1,47 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { z } from 'zod'; +import { action } from '../src/action.js'; +import { __hardResetRegistryForTesting } from '../src/registry.js'; + +describe('action', () => { + beforeEach(__hardResetRegistryForTesting); + + it('applies middleware', async () => { + const act = action( + { + name: 'foo', + inputSchema: z.string(), + outputSchema: z.number(), + use: [ + async (input, next) => (await next(input + 'middle1')) + 1, + async (input, next) => (await next(input + 'middle2')) + 2, + ], + }, + async (input) => { + return input.length; + } + ); + + assert.strictEqual( + await act('foo'), + 20 // "foomiddle1middle2".length + 1 + 2 + ); + }); +});