From 6d88e6a44a9fa8550e41f5437d4705f124d4cb5f Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 7 May 2024 10:37:45 -0400 Subject: [PATCH 1/2] Moved middleware up to actions --- js/ai/src/model.ts | 51 +++++++----------------------------- js/core/src/action.ts | 35 ++++++++++++++++++++++++- js/core/tests/action_test.ts | 44 +++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 43 deletions(-) create mode 100644 js/core/tests/action_test.ts diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 0b4af7b557..5dd2636bfc 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -18,6 +18,7 @@ import { Action, defineAction, getStreamingCallback, + Middleware, StreamingCallback, } from '@genkit-ai/core'; import { toJsonSchema } from '@genkit-ai/core/schema'; @@ -234,38 +235,7 @@ export type ModelAction< __configSchema: CustomOptionsSchema; }; -export interface ModelMiddleware { - ( - req: GenerateRequest, - next: (req?: GenerateRequest) => Promise - ): Promise; -} - -/** - * - */ -export function modelWithMiddleware( - model: ModelAction, - middleware: ModelMiddleware[] -): ModelAction { - const wrapped = (async (req: GenerateRequest) => { - const dispatch = async (index: number, req: GenerateRequest) => { - if (index === middleware.length) { - // end of the chain, call the original model action - return await model(req); - } - - const currentMiddleware = middleware[index]; - return currentMiddleware(req, async (modifiedReq) => - dispatch(index + 1, modifiedReq || req) - ); - }; - - return await dispatch(0, req); - }) as ModelAction; - wrapped.__action = model.__action; - return wrapped; -} +export type ModelMiddleware = Middleware, z.infer>; /** * Defines a new model and adds it to the registry. @@ -291,6 +261,11 @@ export function defineModel< ) => Promise ): ModelAction { const label = options.label || `${options.name} GenAI model`; + const middleware = [ + ...(options.use || []), + validateSupport(options), + conformOutput(), + ]; const act = defineAction( { actionType: 'model', @@ -308,6 +283,7 @@ export function defineModel< supports: options.supports, }, }, + use: middleware, }, (input) => { const startTimeMs = performance.now(); @@ -331,16 +307,7 @@ export function defineModel< Object.assign(act, { __configSchema: options.configSchema || z.unknown(), }); - const middleware = [ - ...(options.use || []), - validateSupport(options), - conformOutput(), - ]; - const ma = modelWithMiddleware( - act as ModelAction, - middleware - ) as ModelAction; - return ma; + return act as ModelAction; } export interface ModelReference { diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 09fd1466cd..c80f4bc2da 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -22,9 +22,9 @@ import { ActionType, lookupPlugin, registerAction } from './registry.js'; import { parseSchema } from './schema.js'; import * as telemetry from './telemetry.js'; import { - SPAN_TYPE_ATTR, runInNewSpan, setCustomMetadataAttributes, + SPAN_TYPE_ATTR, } from './tracing.js'; export { Status, StatusCodes, StatusSchema } from './statusTypes.js'; @@ -72,8 +72,37 @@ type ActionParams< outputSchema?: O; outputJsonSchema?: JSONSchema7; metadata?: M; + use?: Middleware, z.infer>[]; }; +export interface Middleware { + (req: I, next: (req?: I) => Promise): Promise; +} + +export function actionWithMiddleware< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + M extends Record = Record, +>(action: Action, middleware: Middleware, z.infer>[]): Action { + const wrapped = (async (req: z.infer) => { + const dispatch = async (index: number, req: z.infer) => { + if (index === middleware.length) { + // end of the chain, call the original model action + return await action(req); + } + + const currentMiddleware = middleware[index]; + return currentMiddleware(req, async (modifiedReq) => + dispatch(index + 1, modifiedReq || req) + ); + }; + + return await dispatch(0, req); + }) as Action; + wrapped.__action = action.__action; + return wrapped; +} + /** * Creates an action with the provided config. */ @@ -140,6 +169,10 @@ export function action< outputJsonSchema: config.outputJsonSchema, metadata: config.metadata, } as ActionMetadata; + + if (config.use) { + return actionWithMiddleware(actionFn, config.use); + } return actionFn; } diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts new file mode 100644 index 0000000000..7b354b257a --- /dev/null +++ b/js/core/tests/action_test.ts @@ -0,0 +1,44 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { __hardResetRegistryForTesting } from '../src/registry.js'; +import { action } from '../src/action.js'; +import { z } from 'zod'; + +describe('action', () => { + beforeEach(__hardResetRegistryForTesting); + + it('applies middleware', async () => { + const act = action({ + name: 'foo', + inputSchema: z.string(), + outputSchema: z.number(), + use: [ + async (input, next) => await next(input + 'middle1') + 1, + async (input, next) => await next(input + 'middle2') + 2, + ] + }, async (input) => { + return input.length + }); + + assert.strictEqual( + await act("foo"), + 20 // "foomiddle1middle2".length + 1 + 2 + ); + }); +}); From 996198cd0ccb6299873fb1adac228f991cbc9e4d Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 7 May 2024 10:41:27 -0400 Subject: [PATCH 2/2] format --- js/ai/src/model.ts | 5 ++++- js/core/src/action.ts | 7 +++++-- js/core/tests/action_test.ts | 31 +++++++++++++++++-------------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 5dd2636bfc..2758a485fb 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -235,7 +235,10 @@ export type ModelAction< __configSchema: CustomOptionsSchema; }; -export type ModelMiddleware = Middleware, z.infer>; +export type ModelMiddleware = Middleware< + z.infer, + z.infer +>; /** * Defines a new model and adds it to the registry. diff --git a/js/core/src/action.ts b/js/core/src/action.ts index c80f4bc2da..a78c8a724f 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -22,9 +22,9 @@ import { ActionType, lookupPlugin, registerAction } from './registry.js'; import { parseSchema } from './schema.js'; import * as telemetry from './telemetry.js'; import { + SPAN_TYPE_ATTR, runInNewSpan, setCustomMetadataAttributes, - SPAN_TYPE_ATTR, } from './tracing.js'; export { Status, StatusCodes, StatusSchema } from './statusTypes.js'; @@ -83,7 +83,10 @@ export function actionWithMiddleware< I extends z.ZodTypeAny, O extends z.ZodTypeAny, M extends Record = Record, ->(action: Action, middleware: Middleware, z.infer>[]): Action { +>( + action: Action, + middleware: Middleware, z.infer>[] +): Action { const wrapped = (async (req: z.infer) => { const dispatch = async (index: number, req: z.infer) => { if (index === middleware.length) { diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts index 7b354b257a..3d257d02a5 100644 --- a/js/core/tests/action_test.ts +++ b/js/core/tests/action_test.ts @@ -16,28 +16,31 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; -import { __hardResetRegistryForTesting } from '../src/registry.js'; -import { action } from '../src/action.js'; import { z } from 'zod'; +import { action } from '../src/action.js'; +import { __hardResetRegistryForTesting } from '../src/registry.js'; describe('action', () => { beforeEach(__hardResetRegistryForTesting); it('applies middleware', async () => { - const act = action({ - name: 'foo', - inputSchema: z.string(), - outputSchema: z.number(), - use: [ - async (input, next) => await next(input + 'middle1') + 1, - async (input, next) => await next(input + 'middle2') + 2, - ] - }, async (input) => { - return input.length - }); + const act = action( + { + name: 'foo', + inputSchema: z.string(), + outputSchema: z.number(), + use: [ + async (input, next) => (await next(input + 'middle1')) + 1, + async (input, next) => (await next(input + 'middle2')) + 2, + ], + }, + async (input) => { + return input.length; + } + ); assert.strictEqual( - await act("foo"), + await act('foo'), 20 // "foomiddle1middle2".length + 1 + 2 ); });