Skip to content

Commit 280c60d

Browse files
authored
feat: allow specifying middleware on the generate function #450 (#858)
1 parent c35603a commit 280c60d

File tree

7 files changed

+289
-130
lines changed

7 files changed

+289
-130
lines changed

js/ai/src/generate.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import { z } from 'zod';
2727
import { DocumentData } from './document.js';
2828
import { extractJson } from './extract.js';
2929
import {
30-
generateAction,
30+
generateHelper,
3131
GenerateUtilParamSchema,
3232
inferRoleFromParts,
3333
} from './generateAction.js';
@@ -41,6 +41,7 @@ import {
4141
MessageData,
4242
ModelAction,
4343
ModelArgument,
44+
ModelMiddleware,
4445
ModelReference,
4546
Part,
4647
ToolDefinition,
@@ -490,6 +491,8 @@ export interface GenerateOptions<
490491
returnToolRequests?: boolean;
491492
/** When provided, models supporting streaming will call the provided callback with chunks as generation progresses. */
492493
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
494+
/** Middleware to be used with this model call. */
495+
use?: ModelMiddleware[];
493496
}
494497

495498
async function resolveModel(options: GenerateOptions): Promise<ModelAction> {
@@ -612,7 +615,7 @@ export async function generate<
612615
resolvedOptions.streamingCallback,
613616
async () =>
614617
new GenerateResponse<O>(
615-
await generateAction(params),
618+
await generateHelper(params, resolvedOptions.use),
616619
await toGenerateRequest(resolvedOptions)
617620
)
618621
);

js/ai/src/generateAction.ts

Lines changed: 171 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import {
1818
Action,
1919
defineAction,
2020
getStreamingCallback,
21+
Middleware,
2122
runWithStreamingCallback,
2223
} from '@genkit-ai/core';
2324
import { lookupAction } from '@genkit-ai/core/registry';
@@ -26,6 +27,7 @@ import {
2627
toJsonSchema,
2728
validateSchema,
2829
} from '@genkit-ai/core/schema';
30+
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
2931
import { z } from 'zod';
3032
import { DocumentDataSchema } from './document.js';
3133
import {
@@ -37,7 +39,9 @@ import {
3739
import {
3840
CandidateData,
3941
GenerateRequest,
42+
GenerateRequestSchema,
4043
GenerateResponseChunkData,
44+
GenerateResponseData,
4145
GenerateResponseSchema,
4246
MessageData,
4347
MessageSchema,
@@ -85,141 +89,193 @@ export const generateAction = defineAction(
8589
inputSchema: GenerateUtilParamSchema,
8690
outputSchema: GenerateResponseSchema,
8791
},
88-
async (input) => {
89-
const model = (await lookupAction(`/model/${input.model}`)) as ModelAction;
90-
if (!model) {
91-
throw new Error(`Model ${input.model} not found`);
92+
async (input) => generate(input)
93+
);
94+
95+
/**
96+
* Encapsulates all generate logic. This is similar to `generateAction` except not an action and can take middleware.
97+
*/
98+
export async function generateHelper(
99+
input: z.infer<typeof GenerateUtilParamSchema>,
100+
middleware?: Middleware[]
101+
): Promise<GenerateResponseData> {
102+
// do tracing
103+
return await runInNewSpan(
104+
{
105+
metadata: {
106+
name: 'generate',
107+
},
108+
labels: {
109+
[SPAN_TYPE_ATTR]: 'helper',
110+
},
111+
},
112+
async (metadata) => {
113+
metadata.name = 'generate';
114+
metadata.input = input;
115+
const output = await generate(input, middleware);
116+
metadata.output = JSON.stringify(output);
117+
return output;
92118
}
119+
);
120+
}
93121

94-
let tools: ToolAction[] | undefined;
95-
if (input.tools?.length) {
96-
if (!model.__action.metadata?.model.supports?.tools) {
97-
throw new Error(
98-
`Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
99-
);
100-
}
101-
tools = await Promise.all(
102-
input.tools.map(async (toolRef) => {
103-
if (typeof toolRef === 'string') {
104-
const tool = (await lookupAction(toolRef)) as ToolAction;
105-
if (!tool) {
106-
throw new Error(`Tool ${toolRef} not found`);
107-
}
108-
return tool;
109-
}
110-
throw '';
111-
})
122+
async function generate(
123+
input: z.infer<typeof GenerateUtilParamSchema>,
124+
middleware?: Middleware[]
125+
): Promise<GenerateResponseData> {
126+
const model = (await lookupAction(`/model/${input.model}`)) as ModelAction;
127+
if (!model) {
128+
throw new Error(`Model ${input.model} not found`);
129+
}
130+
131+
let tools: ToolAction[] | undefined;
132+
if (input.tools?.length) {
133+
if (!model.__action.metadata?.model.supports?.tools) {
134+
throw new Error(
135+
`Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
112136
);
113137
}
138+
tools = await Promise.all(
139+
input.tools.map(async (toolRef) => {
140+
if (typeof toolRef === 'string') {
141+
const tool = (await lookupAction(toolRef)) as ToolAction;
142+
if (!tool) {
143+
throw new Error(`Tool ${toolRef} not found`);
144+
}
145+
return tool;
146+
}
147+
throw '';
148+
})
149+
);
150+
}
114151

115-
const request = await actionToGenerateRequest(input, tools);
152+
const request = await actionToGenerateRequest(input, tools);
116153

117-
const accumulatedChunks: GenerateResponseChunkData[] = [];
154+
const accumulatedChunks: GenerateResponseChunkData[] = [];
118155

119-
const streamingCallback = getStreamingCallback();
120-
const response = await runWithStreamingCallback(
121-
streamingCallback
122-
? (chunk: GenerateResponseChunkData) => {
123-
// Store accumulated chunk data
124-
accumulatedChunks.push(chunk);
125-
if (streamingCallback) {
126-
streamingCallback!(
127-
new GenerateResponseChunk(chunk, accumulatedChunks)
128-
);
129-
}
156+
const streamingCallback = getStreamingCallback();
157+
const response = await runWithStreamingCallback(
158+
streamingCallback
159+
? (chunk: GenerateResponseChunkData) => {
160+
// Store accumulated chunk data
161+
accumulatedChunks.push(chunk);
162+
if (streamingCallback) {
163+
streamingCallback!(
164+
new GenerateResponseChunk(chunk, accumulatedChunks)
165+
);
130166
}
131-
: undefined,
132-
async () => new GenerateResponse(await model(request))
133-
);
167+
}
168+
: undefined,
169+
async () => {
170+
const dispatch = async (
171+
index: number,
172+
req: z.infer<typeof GenerateRequestSchema>
173+
) => {
174+
if (!middleware || index === middleware.length) {
175+
// end of the chain, call the original model action
176+
return await model(req);
177+
}
134178

135-
// throw NoValidCandidates if all candidates are blocked or
136-
if (
137-
!response.candidates.some((c) =>
138-
['stop', 'length'].includes(c.finishReason)
139-
)
140-
) {
141-
throw new NoValidCandidatesError({
142-
message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`,
143-
response,
144-
});
179+
const currentMiddleware = middleware[index];
180+
return currentMiddleware(req, async (modifiedReq) =>
181+
dispatch(index + 1, modifiedReq || req)
182+
);
183+
};
184+
185+
return new GenerateResponse(await dispatch(0, request));
145186
}
187+
);
146188

147-
if (input.output?.jsonSchema && !response.toolRequests()?.length) {
148-
// find a candidate with valid output schema
149-
const candidateErrors = response.candidates.map((c) => {
150-
// don't validate messages that have no text or data
151-
if (c.text() === '' && c.data() === null) return null;
189+
// throw NoValidCandidates if all candidates are blocked or
190+
if (
191+
!response.candidates.some((c) =>
192+
['stop', 'length'].includes(c.finishReason)
193+
)
194+
) {
195+
throw new NoValidCandidatesError({
196+
message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`,
197+
response,
198+
});
199+
}
152200

153-
try {
154-
parseSchema(c.output(), {
155-
jsonSchema: input.output?.jsonSchema,
156-
});
157-
return null;
158-
} catch (e) {
159-
return e as Error;
160-
}
161-
});
162-
// if all candidates have a non-null error...
163-
if (candidateErrors.every((c) => !!c)) {
164-
throw new NoValidCandidatesError({
165-
message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`,
166-
response,
167-
detail: {
168-
candidateErrors: candidateErrors,
169-
},
201+
if (input.output?.jsonSchema && !response.toolRequests()?.length) {
202+
// find a candidate with valid output schema
203+
const candidateErrors = response.candidates.map((c) => {
204+
// don't validate messages that have no text or data
205+
if (c.text() === '' && c.data() === null) return null;
206+
207+
try {
208+
parseSchema(c.output(), {
209+
jsonSchema: input.output?.jsonSchema,
170210
});
211+
return null;
212+
} catch (e) {
213+
return e as Error;
171214
}
215+
});
216+
// if all candidates have a non-null error...
217+
if (candidateErrors.every((c) => !!c)) {
218+
throw new NoValidCandidatesError({
219+
message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`,
220+
response,
221+
detail: {
222+
candidateErrors: candidateErrors,
223+
},
224+
});
172225
}
226+
}
173227

174-
// Pick the first valid candidate.
175-
let selected: Candidate<any> | undefined;
176-
for (const candidate of response.candidates) {
177-
if (isValidCandidate(candidate, tools || [])) {
178-
selected = candidate;
179-
break;
180-
}
228+
// Pick the first valid candidate.
229+
let selected: Candidate<any> | undefined;
230+
for (const candidate of response.candidates) {
231+
if (isValidCandidate(candidate, tools || [])) {
232+
selected = candidate;
233+
break;
181234
}
235+
}
182236

183-
if (!selected) {
184-
throw new Error('No valid candidates found');
185-
}
237+
if (!selected) {
238+
throw new NoValidCandidatesError({
239+
message: 'No valid candidates found',
240+
response,
241+
});
242+
}
186243

187-
const toolCalls = selected.message.content.filter(
188-
(part) => !!part.toolRequest
189-
);
190-
if (input.returnToolRequests || toolCalls.length === 0) {
191-
return response.toJSON();
192-
}
193-
const toolResponses: ToolResponsePart[] = await Promise.all(
194-
toolCalls.map(async (part) => {
195-
if (!part.toolRequest) {
196-
throw Error(
197-
'Tool request expected but not provided in tool request part'
198-
);
199-
}
200-
const tool = tools?.find(
201-
(tool) => tool.__action.name === part.toolRequest?.name
202-
);
203-
if (!tool) {
204-
throw Error('Tool not found');
205-
}
206-
return {
207-
toolResponse: {
208-
name: part.toolRequest.name,
209-
ref: part.toolRequest.ref,
210-
output: await tool(part.toolRequest?.input),
211-
},
212-
};
213-
})
214-
);
215-
const nextRequest = {
216-
...input,
217-
history: [...request.messages, selected.message],
218-
prompt: toolResponses,
219-
};
220-
return await generateAction(nextRequest);
244+
const toolCalls = selected.message.content.filter(
245+
(part) => !!part.toolRequest
246+
);
247+
if (input.returnToolRequests || toolCalls.length === 0) {
248+
return response.toJSON();
221249
}
222-
);
250+
const toolResponses: ToolResponsePart[] = await Promise.all(
251+
toolCalls.map(async (part) => {
252+
if (!part.toolRequest) {
253+
throw Error(
254+
'Tool request expected but not provided in tool request part'
255+
);
256+
}
257+
const tool = tools?.find(
258+
(tool) => tool.__action.name === part.toolRequest?.name
259+
);
260+
if (!tool) {
261+
throw Error('Tool not found');
262+
}
263+
return {
264+
toolResponse: {
265+
name: part.toolRequest.name,
266+
ref: part.toolRequest.ref,
267+
output: await tool(part.toolRequest?.input),
268+
},
269+
};
270+
})
271+
);
272+
const nextRequest = {
273+
...input,
274+
history: [...request.messages, selected.message],
275+
prompt: toolResponses,
276+
};
277+
return await generateHelper(nextRequest, middleware);
278+
}
223279

224280
async function actionToGenerateRequest(
225281
options: z.infer<typeof GenerateUtilParamSchema>,

js/ai/src/model.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ export function defineModel<
289289
configSchema?: CustomOptionsSchema;
290290
/** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */
291291
label?: string;
292+
/** Middleware to be used with this model. */
292293
use?: ModelMiddleware[];
293294
},
294295
runner: (

0 commit comments

Comments
 (0)