@@ -18,6 +18,7 @@ import {
18
18
Action ,
19
19
defineAction ,
20
20
getStreamingCallback ,
21
+ Middleware ,
21
22
runWithStreamingCallback ,
22
23
} from '@genkit-ai/core' ;
23
24
import { lookupAction } from '@genkit-ai/core/registry' ;
@@ -26,6 +27,7 @@ import {
26
27
toJsonSchema ,
27
28
validateSchema ,
28
29
} from '@genkit-ai/core/schema' ;
30
+ import { runInNewSpan , SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing' ;
29
31
import { z } from 'zod' ;
30
32
import { DocumentDataSchema } from './document.js' ;
31
33
import {
@@ -37,7 +39,9 @@ import {
37
39
import {
38
40
CandidateData ,
39
41
GenerateRequest ,
42
+ GenerateRequestSchema ,
40
43
GenerateResponseChunkData ,
44
+ GenerateResponseData ,
41
45
GenerateResponseSchema ,
42
46
MessageData ,
43
47
MessageSchema ,
@@ -85,141 +89,193 @@ export const generateAction = defineAction(
85
89
inputSchema : GenerateUtilParamSchema ,
86
90
outputSchema : GenerateResponseSchema ,
87
91
} ,
88
- async ( input ) => {
89
- const model = ( await lookupAction ( `/model/${ input . model } ` ) ) as ModelAction ;
90
- if ( ! model ) {
91
- throw new Error ( `Model ${ input . model } not found` ) ;
92
+ async ( input ) => generate ( input )
93
+ ) ;
94
+
95
/**
 * Runs the complete generate flow inside a new tracing span.
 *
 * Behaves like `generateAction`, but is a plain function rather than an action
 * and additionally accepts model middleware.
 *
 * @param input - generate parameters matching `GenerateUtilParamSchema`.
 * @param middleware - optional middleware forwarded unchanged to `generate`.
 * @returns the response data produced by `generate`.
 */
export async function generateHelper(
  input: z.infer<typeof GenerateUtilParamSchema>,
  middleware?: Middleware[]
): Promise<GenerateResponseData> {
  return await runInNewSpan(
    {
      metadata: { name: 'generate' },
      labels: { [SPAN_TYPE_ATTR]: 'helper' },
    },
    async (spanMetadata) => {
      // Record the span name and the raw input before running the request.
      spanMetadata.name = 'generate';
      spanMetadata.input = input;

      const result = await generate(input, middleware);

      // The output is attached to the span in serialized (JSON string) form.
      spanMetadata.output = JSON.stringify(result);
      return result;
    }
  );
}
93
121
94
- let tools : ToolAction [ ] | undefined ;
95
- if ( input . tools ?. length ) {
96
- if ( ! model . __action . metadata ?. model . supports ?. tools ) {
97
- throw new Error (
98
- `Model ${ input . model } does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
99
- ) ;
100
- }
101
- tools = await Promise . all (
102
- input . tools . map ( async ( toolRef ) => {
103
- if ( typeof toolRef === 'string' ) {
104
- const tool = ( await lookupAction ( toolRef ) ) as ToolAction ;
105
- if ( ! tool ) {
106
- throw new Error ( `Tool ${ toolRef } not found` ) ;
107
- }
108
- return tool ;
109
- }
110
- throw '' ;
111
- } )
122
/**
 * Core generate implementation used by `generateAction` and `generateHelper`.
 *
 * Steps, in order:
 *  1. Look up the model action under `/model/<input.model>`.
 *  2. Look up every requested tool (tools are only allowed when the model's
 *     metadata declares tool support).
 *  3. Build the model request via `actionToGenerateRequest`.
 *  4. Call the model through the middleware chain, forwarding streamed chunks
 *     to the ambient streaming callback when one is set.
 *  5. Validate the candidates (finish reason, optional output JSON schema).
 *  6. If the selected candidate contains tool requests (and the caller did not
 *     ask for them to be returned), run the tools and issue a follow-up
 *     generate call with the tool responses as the new prompt.
 *
 * NOTE(review): the follow-up call in step 6 recurses through
 * `generateHelper` with no visible limit on the number of tool turns —
 * confirm whether a cap is enforced elsewhere.
 *
 * @param input - generate parameters matching `GenerateUtilParamSchema`.
 * @param middleware - optional middleware; invoked in array order, each one
 *   receiving the request and a `next` function that continues the chain.
 * @returns the JSON form of the model response, or the result of the
 *   follow-up call once tool requests have been executed.
 * @throws Error when the model or a tool is not found, when tools are
 *   supplied to a model that does not declare tool support, when a tool
 *   reference is not a string, or when a requested tool cannot be matched.
 * @throws NoValidCandidatesError when no candidate finished with `stop` or
 *   `length`, when no candidate matches the requested output schema, or when
 *   no candidate passes `isValidCandidate`.
 */
async function generate(
  input: z.infer<typeof GenerateUtilParamSchema>,
  middleware?: Middleware[]
): Promise<GenerateResponseData> {
  const model = (await lookupAction(`/model/${input.model}`)) as ModelAction;
  if (!model) {
    throw new Error(`Model ${input.model} not found`);
  }

  let tools: ToolAction[] | undefined;
  if (input.tools?.length) {
    if (!model.__action.metadata?.model.supports?.tools) {
      throw new Error(
        `Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
      );
    }
    tools = await Promise.all(
      input.tools.map(async (toolRef) => {
        if (typeof toolRef === 'string') {
          const tool = (await lookupAction(toolRef)) as ToolAction;
          if (!tool) {
            throw new Error(`Tool ${toolRef} not found`);
          }
          return tool;
        }
        // Only string references can be resolved through lookupAction.
        throw new Error(
          `Tool references must be strings, but got ${typeof toolRef}`
        );
      })
    );
  }

  const request = await actionToGenerateRequest(input, tools);

  // Every streamed chunk is collected so each emitted GenerateResponseChunk
  // also carries the chunks that preceded it.
  const accumulatedChunks: GenerateResponseChunkData[] = [];

  const streamingCallback = getStreamingCallback();
  const response = await runWithStreamingCallback(
    streamingCallback
      ? (chunk: GenerateResponseChunkData) => {
          // Store accumulated chunk data
          accumulatedChunks.push(chunk);
          streamingCallback(
            new GenerateResponseChunk(chunk, accumulatedChunks)
          );
        }
      : undefined,
    async () => {
      // Walks the middleware array; each middleware decides when (and with
      // which request) the next link is invoked.
      const dispatch = async (
        index: number,
        req: z.infer<typeof GenerateRequestSchema>
      ) => {
        if (!middleware || index === middleware.length) {
          // end of the chain, call the original model action
          return await model(req);
        }

        const currentMiddleware = middleware[index];
        // A middleware may call next() without arguments, in which case the
        // request it received is passed on unchanged.
        return currentMiddleware(req, async (modifiedReq) =>
          dispatch(index + 1, modifiedReq ?? req)
        );
      };

      return new GenerateResponse(await dispatch(0, request));
    }
  );

  // throw NoValidCandidates if no candidate finished with 'stop' or 'length'
  if (
    !response.candidates.some((c) =>
      ['stop', 'length'].includes(c.finishReason)
    )
  ) {
    throw new NoValidCandidatesError({
      message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`,
      response,
    });
  }

  if (input.output?.jsonSchema && !response.toolRequests()?.length) {
    // find a candidate with valid output schema
    const candidateErrors = response.candidates.map((c) => {
      // don't validate messages that have no text or data
      if (c.text() === '' && c.data() === null) return null;

      try {
        parseSchema(c.output(), {
          jsonSchema: input.output?.jsonSchema,
        });
        return null;
      } catch (e) {
        return e as Error;
      }
    });
    // if all candidates have a non-null error...
    if (candidateErrors.every((c) => !!c)) {
      throw new NoValidCandidatesError({
        // join('') avoids the commas an array gets when interpolated directly
        message: `Generation resulted in no candidates matching provided output schema.${candidateErrors
          .map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)
          .join('')}`,
        response,
        detail: {
          candidateErrors: candidateErrors,
        },
      });
    }
  }

  // Pick the first valid candidate.
  let selected: Candidate<any> | undefined;
  for (const candidate of response.candidates) {
    if (isValidCandidate(candidate, tools || [])) {
      selected = candidate;
      break;
    }
  }

  if (!selected) {
    throw new NoValidCandidatesError({
      message: 'No valid candidates found',
      response,
    });
  }

  const toolCalls = selected.message.content.filter(
    (part) => !!part.toolRequest
  );
  // Nothing to execute, or the caller wants the raw tool requests back.
  if (input.returnToolRequests || toolCalls.length === 0) {
    return response.toJSON();
  }
  const toolResponses: ToolResponsePart[] = await Promise.all(
    toolCalls.map(async (part) => {
      if (!part.toolRequest) {
        throw Error(
          'Tool request expected but not provided in tool request part'
        );
      }
      const tool = tools?.find(
        (tool) => tool.__action.name === part.toolRequest?.name
      );
      if (!tool) {
        throw Error('Tool not found');
      }
      return {
        toolResponse: {
          name: part.toolRequest.name,
          ref: part.toolRequest.ref,
          output: await tool(part.toolRequest?.input),
        },
      };
    })
  );
  // Continue the conversation: prior messages plus the selected candidate
  // become history, and the tool responses become the new prompt.
  const nextRequest = {
    ...input,
    history: [...request.messages, selected.message],
    prompt: toolResponses,
  };
  return await generateHelper(nextRequest, middleware);
}
223
279
224
280
async function actionToGenerateRequest (
225
281
options : z . infer < typeof GenerateUtilParamSchema > ,
0 commit comments