Skip to content

Commit 75a16af

Browse files
authored
feat(js/plugins/vertexai): strict json mode (#890)
1 parent 09014eb commit 75a16af

File tree

5 files changed

+105
-51
lines changed

5 files changed

+105
-51
lines changed

js/plugins/vertexai/package.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838
"@anthropic-ai/sdk": "^0.24.3",
3939
"@anthropic-ai/vertex-sdk": "^0.4.0",
4040
"@google-cloud/aiplatform": "^3.23.0",
41-
"@google-cloud/vertexai": "^1.1.0",
42-
"google-auth-library": "^9.6.3",
41+
"@google-cloud/vertexai": "^1.7.0",
42+
"google-auth-library": "^9.14.1",
4343
"googleapis": "^140.0.1",
4444
"node-fetch": "^3.3.2",
4545
"openai": "^4.52.7",
@@ -51,8 +51,8 @@
5151
"@genkit-ai/flow": "workspace:*"
5252
},
5353
"optionalDependencies": {
54-
"firebase-admin": "^12.1.0",
55-
"@google-cloud/bigquery": "^7.8.0"
54+
"@google-cloud/bigquery": "^7.8.0",
55+
"firebase-admin": "^12.1.0"
5656
},
5757
"devDependencies": {
5858
"@types/node": "^20.11.16",

js/plugins/vertexai/src/gemini.ts

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,10 @@ function fromGeminiFunctionResponsePart(part: GeminiPart): Part {
399399
}
400400

401401
// Converts vertex part to genkit part
402-
function fromGeminiPart(part: GeminiPart): Part {
402+
function fromGeminiPart(part: GeminiPart, jsonMode: boolean): Part {
403+
if (jsonMode && part.text !== undefined) {
404+
return { data: JSON.parse(part.text) };
405+
}
403406
if (part.text !== undefined) return { text: part.text };
404407
if (part.functionCall) return fromGeminiFunctionCallPart(part);
405408
if (part.functionResponse) return fromGeminiFunctionResponsePart(part);
@@ -411,14 +414,15 @@ function fromGeminiPart(part: GeminiPart): Part {
411414
}
412415

413416
export function fromGeminiCandidate(
414-
candidate: GenerateContentCandidate
417+
candidate: GenerateContentCandidate,
418+
jsonMode: boolean
415419
): CandidateData {
416420
const parts = candidate.content.parts || [];
417421
const genkitCandidate: CandidateData = {
418422
index: candidate.index || 0, // reasonable default?
419423
message: {
420424
role: 'model',
421-
content: parts.map(fromGeminiPart),
425+
content: parts.map((p) => fromGeminiPart(p, jsonMode)),
422426
},
423427
finishReason: fromGeminiFinishReason(candidate.finishReason),
424428
finishMessage: candidate.finishMessage,
@@ -518,11 +522,18 @@ export function geminiModel(
518522
}
519523
}
520524

525+
const tools = request.tools?.length
526+
? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
527+
: [];
528+
529+
// JSON mode cannot be used together with function calling (tools), so it is only enabled when no tools are supplied
530+
const jsonMode =
531+
(request.output?.format === 'json' || !!request.output?.schema) &&
532+
tools.length === 0;
533+
521534
const chatRequest: StartChatParams = {
522535
systemInstruction,
523-
tools: request.tools?.length
524-
? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
525-
: [],
536+
tools,
526537
history: messages
527538
.slice(0, -1)
528539
.map((message) => toGeminiMessage(message, model)),
@@ -532,6 +543,7 @@ export function geminiModel(
532543
maxOutputTokens: request.config?.maxOutputTokens,
533544
topK: request.config?.topK,
534545
topP: request.config?.topP,
546+
responseMimeType: jsonMode ? 'application/json' : undefined,
535547
stopSequences: request.config?.stopSequences,
536548
},
537549
safetySettings: request.config?.safetySettings,
@@ -566,7 +578,7 @@ export function geminiModel(
566578
.sendMessageStream(msg.parts);
567579
for await (const item of result.stream) {
568580
(item as GenerateContentResponse).candidates?.forEach((candidate) => {
569-
const c = fromGeminiCandidate(candidate);
581+
const c = fromGeminiCandidate(candidate, jsonMode);
570582
streamingCallback({
571583
index: c.index,
572584
content: c.message.content,
@@ -578,7 +590,9 @@ export function geminiModel(
578590
throw new Error('No valid candidates returned.');
579591
}
580592
return {
581-
candidates: response.candidates?.map(fromGeminiCandidate) || [],
593+
candidates:
594+
response.candidates?.map((c) => fromGeminiCandidate(c, jsonMode)) ||
595+
[],
582596
custom: response,
583597
};
584598
} else {
@@ -592,7 +606,9 @@ export function geminiModel(
592606
throw new Error('No valid candidates returned.');
593607
}
594608
const responseCandidates =
595-
result.response.candidates?.map(fromGeminiCandidate) || [];
609+
result.response.candidates?.map((c) =>
610+
fromGeminiCandidate(c, jsonMode)
611+
) || [];
596612
return {
597613
candidates: responseCandidates,
598614
custom: result.response,

js/pnpm-lock.yaml

Lines changed: 34 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

js/testapps/basic-gemini/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"@genkit-ai/dotprompt": "workspace:*",
1919
"@genkit-ai/flow": "workspace:*",
2020
"@genkit-ai/googleai": "workspace:*",
21+
"@genkit-ai/vertexai": "workspace:*",
2122
"express": "^4.20.0",
2223
"zod": "^3.22.4"
2324
},

js/testapps/basic-gemini/src/index.ts

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,24 @@ import * as z from 'zod';
2020
import { defineTool, generate } from '@genkit-ai/ai';
2121
import { configureGenkit } from '@genkit-ai/core';
2222
import { defineFlow, startFlowsServer } from '@genkit-ai/flow';
23-
import { googleAI } from '@genkit-ai/googleai';
23+
import {
24+
gemini15Flash as gemini15FlashGoogleAi,
25+
googleAI,
26+
} from '@genkit-ai/googleai';
2427

2528
// Import models from the Google AI plugin. The Google AI API provides access to
26-
// several generative models. Here, we import Gemini 1.5 Flash.
27-
import { gemini15Flash } from '@genkit-ai/googleai';
29+
30+
import {
31+
gemini15Flash as gemini15FlashVertexAi,
32+
vertexAI,
33+
} from '@genkit-ai/vertexai';
34+
35+
const provider = process.env.PROVIDER || 'vertexai';
36+
37+
const plugin = provider === 'vertexai' ? vertexAI : googleAI;
2838

2939
configureGenkit({
30-
plugins: [
31-
// Load the Google AI plugin. You can optionally specify your API key
32-
// by passing in a config object; if you don't, the Google AI plugin uses
33-
// the value from the GOOGLE_GENAI_API_KEY environment variable, which is
34-
// the recommended practice.
35-
googleAI(),
36-
],
40+
plugins: [plugin()],
3741
// Log debug output to the console.
3842
logLevel: 'debug',
3943
// Perform OpenTelemetry instrumentation and enable trace collection.
@@ -62,23 +66,37 @@ export const jokeFlow = defineFlow(
6266
},
6367
async () => {
6468
// Construct a request and send it to the model API.
69+
if (provider === 'vertexai') {
70+
const llmResponse = await generate({
71+
model: gemini15FlashVertexAi,
72+
config: {
73+
temperature: 2,
74+
},
75+
output: {
76+
schema: z.object({ jokeSubject: z.string() }),
77+
},
78+
tools: [jokeSubjectGenerator],
79+
prompt: `come up with a subject to joke about (using the function provided)`,
80+
});
6581

66-
const llmResponse = await generate({
67-
model: gemini15Flash,
68-
config: {
69-
temperature: 2,
70-
},
71-
output: {
72-
schema: z.object({ jokeSubject: z.string() }),
73-
},
74-
tools: [jokeSubjectGenerator],
75-
prompt: `come up with a subject to joke about (using the function provided)`,
76-
});
82+
return llmResponse.output();
83+
} else {
84+
const llmResponse = await generate({
85+
model: gemini15FlashGoogleAi,
86+
config: {
87+
temperature: 2,
88+
},
89+
output: {
90+
schema: z.object({ jokeSubject: z.string() }),
91+
},
92+
tools: [jokeSubjectGenerator],
93+
prompt: `come up with a subject to joke about (using the function provided)`,
94+
});
95+
return llmResponse.output();
96+
}
7797

7898
// Handle the response from the model API. In this sample, we just convert
7999
// it to a string, but more complicated flows might coerce the response into
80-
// structured output or chain the response into another LLM call, etc.
81-
return llmResponse.output();
82100
}
83101
);
84102

0 commit comments

Comments
 (0)