From c47bdc21784661da6d9cdba4d99e595ac1458ace Mon Sep 17 00:00:00 2001
From: Jacob Cable
Date: Wed, 11 Sep 2024 11:59:07 +0100
Subject: [PATCH] feat(js/plugins/vertexai): strict json mode

---
 js/plugins/vertexai/package.json      |  8 ++--
 js/plugins/vertexai/src/gemini.ts     | 34 ++++++++++----
 js/pnpm-lock.yaml                     | 49 +++++++++++++-------
 js/testapps/basic-gemini/package.json |  1 +
 js/testapps/basic-gemini/src/index.ts | 72 ++++++++++++++++++-----------
 5 files changed, 108 insertions(+), 56 deletions(-)

diff --git a/js/plugins/vertexai/package.json b/js/plugins/vertexai/package.json
index f6aac424ad..4cd8730c83 100644
--- a/js/plugins/vertexai/package.json
+++ b/js/plugins/vertexai/package.json
@@ -38,8 +38,8 @@
     "@anthropic-ai/sdk": "^0.24.3",
     "@anthropic-ai/vertex-sdk": "^0.4.0",
     "@google-cloud/aiplatform": "^3.23.0",
-    "@google-cloud/vertexai": "^1.1.0",
-    "google-auth-library": "^9.6.3",
+    "@google-cloud/vertexai": "^1.7.0",
+    "google-auth-library": "^9.14.1",
     "googleapis": "^140.0.1",
     "node-fetch": "^3.3.2",
     "openai": "^4.52.7",
@@ -51,8 +51,8 @@
     "@genkit-ai/flow": "workspace:*"
   },
   "optionalDependencies": {
-    "firebase-admin": "^12.1.0",
-    "@google-cloud/bigquery": "^7.8.0"
+    "@google-cloud/bigquery": "^7.8.0",
+    "firebase-admin": "^12.1.0"
   },
   "devDependencies": {
     "@types/node": "^20.11.16",
diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts
index 49d347bce8..7d0a8a2ea2 100644
--- a/js/plugins/vertexai/src/gemini.ts
+++ b/js/plugins/vertexai/src/gemini.ts
@@ -399,7 +399,10 @@ function fromGeminiFunctionResponsePart(part: GeminiPart): Part {
 }
 
 // Converts vertex part to genkit part
-function fromGeminiPart(part: GeminiPart): Part {
+function fromGeminiPart(part: GeminiPart, jsonMode: boolean): Part {
+  if (jsonMode && part.text !== undefined) {
+    return { data: JSON.parse(part.text) };
+  }
   if (part.text !== undefined) return { text: part.text };
   if (part.functionCall) return fromGeminiFunctionCallPart(part);
   if (part.functionResponse) return fromGeminiFunctionResponsePart(part);
@@ -411,14 +414,15 @@ function fromGeminiPart(part: GeminiPart): Part {
 }
 
 export function fromGeminiCandidate(
-  candidate: GenerateContentCandidate
+  candidate: GenerateContentCandidate,
+  jsonMode: boolean
 ): CandidateData {
   const parts = candidate.content.parts || [];
   const genkitCandidate: CandidateData = {
     index: candidate.index || 0, // reasonable default?
     message: {
       role: 'model',
       content: parts.map((p) => fromGeminiPart(p, jsonMode)),
     },
     finishReason: fromGeminiFinishReason(candidate.finishReason),
     finishMessage: candidate.finishMessage,
@@ -518,11 +522,18 @@ export function geminiModel(
       }
     }
 
+    const tools = request.tools?.length
+      ? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
+      : [];
+
+    // Cannot use JSON mode and tools (function calling) at the same time
+    const jsonMode =
+      (request.output?.format === 'json' || !!request.output?.schema) &&
+      tools.length === 0;
+
     const chatRequest: StartChatParams = {
       systemInstruction,
-      tools: request.tools?.length
-        ? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
-        : [],
+      tools,
       history: messages
         .slice(0, -1)
         .map((message) => toGeminiMessage(message, model)),
@@ -532,6 +543,7 @@
         maxOutputTokens: request.config?.maxOutputTokens,
         topK: request.config?.topK,
         topP: request.config?.topP,
+        responseMimeType: jsonMode ? 'application/json' : undefined,
        stopSequences: request.config?.stopSequences,
       },
       safetySettings: request.config?.safetySettings,
@@ -566,7 +578,7 @@ export function geminiModel(
         .sendMessageStream(msg.parts);
       for await (const item of result.stream) {
         (item as GenerateContentResponse).candidates?.forEach((candidate) => {
-          const c = fromGeminiCandidate(candidate);
+          const c = fromGeminiCandidate(candidate, jsonMode);
           streamingCallback({
             index: c.index,
             content: c.message.content,
@@ -578,7 +590,9 @@
         throw new Error('No valid candidates returned.');
       }
       return {
-        candidates: response.candidates?.map(fromGeminiCandidate) || [],
+        candidates:
+          response.candidates?.map((c) => fromGeminiCandidate(c, jsonMode)) ||
+          [],
         custom: response,
       };
     } else {
@@ -592,7 +606,9 @@
         throw new Error('No valid candidates returned.');
       }
       const responseCandidates =
-        result.response.candidates?.map(fromGeminiCandidate) || [];
+        result.response.candidates?.map((c) =>
+          fromGeminiCandidate(c, jsonMode)
+        ) || [];
       return {
         candidates: responseCandidates,
         custom: result.response,
diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml
index 27c4aa5a47..cc19472d3a 100644
--- a/js/pnpm-lock.yaml
+++ b/js/pnpm-lock.yaml
@@ -622,11 +622,11 @@ importers:
         specifier: ^3.23.0
         version: 3.25.0(encoding@0.1.13)
       '@google-cloud/vertexai':
-        specifier: ^1.1.0
-        version: 1.1.0(encoding@0.1.13)
+        specifier: ^1.7.0
+        version: 1.7.0(encoding@0.1.13)
       google-auth-library:
-        specifier: ^9.6.3
-        version: 9.7.0(encoding@0.1.13)
+        specifier: ^9.14.1
+        version: 9.14.1(encoding@0.1.13)
       googleapis:
         specifier: ^140.0.1
         version: 140.0.1(encoding@0.1.13)
@@ -711,6 +711,9 @@ importers:
       '@genkit-ai/googleai':
         specifier: workspace:*
         version: link:../../plugins/googleai
+      '@genkit-ai/vertexai':
+        specifier: workspace:*
+        version: link:../../plugins/vertexai
       express:
         specifier: ^4.20.0
         version: 4.20.0
@@ -1794,8 +1797,8 @@ packages:
     resolution: {integrity: sha512-sZW14pfxEQZSIbBPs6doFYtcbK31Bs3E4jH5Ly3jJnBkYfkMPX8sXG3ZQXCJa88MKtUNPlgBdMN2OJUzmFe5/g==}
     engines: {node: '>=14'}
 
-  '@google-cloud/vertexai@1.1.0':
-    resolution: {integrity: sha512-hfwfdlVpJ+kM6o2b5UFfPnweBcz8tgHAFRswnqUKYqLJsvKU0DDD0Z2/YKoHyAUoPJAv20qg6KlC3msNeUKUiw==}
+  '@google-cloud/vertexai@1.7.0':
+    resolution: {integrity: sha512-N4YcVzFQ+sPN9c3SeMhbpLfWVbeaLxPbICKsJ6yKthcr4G7tdu9pCs3HUw+Mip0M2xgiKZ8/WWvq6FXbPnlrUA==}
     engines: {node: '>=18.0.0'}
 
   '@google/generative-ai@0.15.0':
@@ -3611,6 +3614,10 @@ packages:
     resolution: {integrity: sha512-epX3ww/mNnhl6tL45EQ/oixsY8JLEgUFoT4A5E/5iAR4esld9Kqv6IJGk7EmGuOgDvaarwF95hU2+v7Irql9lw==}
     engines: {node: '>=14'}
 
+  google-auth-library@9.14.1:
+    resolution: {integrity: sha512-Rj+PMjoNFGFTmtItH7gHfbHpGVSb3vmnGK3nwNBqxQF9NoBpttSZI/rc0WiM63ma2uGDQtYEkMHkK9U6937NiA==}
+    engines: {node: '>=14'}
+
   google-auth-library@9.7.0:
     resolution: {integrity: sha512-I/AvzBiUXDzLOy4iIZ2W+Zq33W4lcukQv1nl7C8WUA6SQwyQwUwu3waNmWNAvzds//FG8SZ+DnKnW/2k6mQS8A==}
     engines: {node: '>=14'}
@@ -5321,7 +5328,7 @@ snapshots:
   '@anthropic-ai/vertex-sdk@0.4.0(encoding@0.1.13)':
     dependencies:
       '@anthropic-ai/sdk': 0.24.3(encoding@0.1.13)
-      google-auth-library: 9.7.0(encoding@0.1.13)
+      google-auth-library: 9.14.1(encoding@0.1.13)
     transitivePeerDependencies:
       - encoding
       - supports-color
@@ -5483,7 +5490,7 @@ snapshots:
       duplexify: 4.1.3
       ent: 2.2.0
      extend: 3.0.2
-      google-auth-library: 9.11.0(encoding@0.1.13)
+      google-auth-library: 9.14.1(encoding@0.1.13)
       retry-request: 7.0.2(encoding@0.1.13)
       teeny-request: 9.0.0(encoding@0.1.13)
    transitivePeerDependencies:
@@ -5562,7 +5569,7 @@ snapshots:
       '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0)
       '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0)
       '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0)
-      google-auth-library: 9.7.0(encoding@0.1.13)
+      google-auth-library: 9.11.0(encoding@0.1.13)
       googleapis: 137.1.0(encoding@0.1.13)
     transitivePeerDependencies:
       - encoding
@@ -5630,7 +5637,7 @@ snapshots:
       ent: 2.2.0
       fast-xml-parser: 4.3.6
       gaxios: 6.3.0(encoding@0.1.13)
-      google-auth-library: 9.11.0(encoding@0.1.13)
+      google-auth-library: 9.14.1(encoding@0.1.13)
       mime: 3.0.0
       p-limit: 3.1.0
       retry-request: 7.0.2(encoding@0.1.13)
@@ -5641,9 +5648,9 @@ snapshots:
       - supports-color
     optional: true
 
-  '@google-cloud/vertexai@1.1.0(encoding@0.1.13)':
+  '@google-cloud/vertexai@1.7.0(encoding@0.1.13)':
     dependencies:
-      google-auth-library: 9.7.0(encoding@0.1.13)
+      google-auth-library: 9.14.1(encoding@0.1.13)
     transitivePeerDependencies:
       - encoding
       - supports-color
@@ -7693,6 +7700,18 @@ snapshots:
     - encoding
     - supports-color
 
+  google-auth-library@9.14.1(encoding@0.1.13):
+    dependencies:
+      base64-js: 1.5.1
+      ecdsa-sig-formatter: 1.0.11
+      gaxios: 6.3.0(encoding@0.1.13)
+      gcp-metadata: 6.1.0(encoding@0.1.13)
+      gtoken: 7.1.0(encoding@0.1.13)
+      jws: 4.0.0
+    transitivePeerDependencies:
+      - encoding
+      - supports-color
+
   google-auth-library@9.7.0(encoding@0.1.13):
     dependencies:
       base64-js: 1.5.1
@@ -7730,7 +7749,7 @@ snapshots:
       '@types/long': 4.0.2
       abort-controller: 3.0.0
       duplexify: 4.1.3
-      google-auth-library: 9.11.0(encoding@0.1.13)
+      google-auth-library: 9.14.1(encoding@0.1.13)
       node-fetch: 2.7.0(encoding@0.1.13)
       object-hash: 3.0.0
       proto3-json-serializer: 2.0.2
@@ -7759,7 +7778,7 @@ snapshots:
     dependencies:
       extend: 3.0.2
       gaxios: 6.3.0(encoding@0.1.13)
-      google-auth-library: 9.11.0(encoding@0.1.13)
+      google-auth-library: 9.14.1(encoding@0.1.13)
       qs: 6.12.0
       url-template: 2.0.8
       uuid: 9.0.1
@@ -7777,7 +7796,7 @@ snapshots:
 
   googleapis@140.0.1(encoding@0.1.13):
     dependencies:
-      google-auth-library: 9.7.0(encoding@0.1.13)
+      google-auth-library: 9.14.1(encoding@0.1.13)
       googleapis-common: 7.2.0(encoding@0.1.13)
     transitivePeerDependencies:
       - encoding
diff --git a/js/testapps/basic-gemini/package.json b/js/testapps/basic-gemini/package.json
index ab532edb46..d9b55b6d22 100644
--- a/js/testapps/basic-gemini/package.json
+++ b/js/testapps/basic-gemini/package.json
@@ -18,6 +18,7 @@
     "@genkit-ai/dotprompt": "workspace:*",
     "@genkit-ai/flow": "workspace:*",
     "@genkit-ai/googleai": "workspace:*",
+    "@genkit-ai/vertexai": "workspace:*",
     "express": "^4.20.0",
     "zod": "^3.22.4"
   },
diff --git a/js/testapps/basic-gemini/src/index.ts b/js/testapps/basic-gemini/src/index.ts
index dab5d55e88..89ead8c20e 100644
--- a/js/testapps/basic-gemini/src/index.ts
+++ b/js/testapps/basic-gemini/src/index.ts
@@ -20,20 +20,24 @@ import * as z from 'zod';
 import { defineTool, generate } from '@genkit-ai/ai';
 import { configureGenkit } from '@genkit-ai/core';
 import { defineFlow, startFlowsServer } from '@genkit-ai/flow';
-import { googleAI } from '@genkit-ai/googleai';
+import {
+  gemini15Flash as gemini15FlashGoogleAi,
+  googleAI,
+} from '@genkit-ai/googleai';
 
-// Import models from the Google AI plugin. The Google AI API provides access to
-// several generative models. Here, we import Gemini 1.5 Flash.
-import { gemini15Flash } from '@genkit-ai/googleai';
+// Import Gemini 1.5 Flash from both the Google AI and Vertex AI plugins.
+// Each import is aliased per provider so the flow below can pick one at
+// runtime via the PROVIDER environment variable.
+import {
+  gemini15Flash as gemini15FlashVertexAi,
+  vertexAI,
+} from '@genkit-ai/vertexai';
+
+const provider = process.env.PROVIDER || 'vertexai';
+
+const plugin = provider === 'vertexai' ? vertexAI : googleAI;
 
 configureGenkit({
-  plugins: [
-    // Load the Google AI plugin. You can optionally specify your API key
-    // by passing in a config object; if you don't, the Google AI plugin uses
-    // the value from the GOOGLE_GENAI_API_KEY environment variable, which is
-    // the recommended practice.
-    googleAI(),
-  ],
-  // Log debug output to tbe console.
+  plugins: [plugin()],
+  // Log debug output to the console.
   logLevel: 'debug',
   // Perform OpenTelemetry instrumentation and enable trace collection.
@@ -62,23 +66,35 @@ export const jokeFlow = defineFlow(
   },
   async () => {
     // Construct a request and send it to the model API.
+    if (provider === 'vertexai') {
+      const llmResponse = await generate({
+        model: gemini15FlashVertexAi,
+        config: {
+          temperature: 2,
+        },
+        output: {
+          schema: z.object({ jokeSubject: z.string() }),
+        },
+        tools: [jokeSubjectGenerator],
+        prompt: `come up with a subject to joke about (using the function provided)`,
+      });
 
-    const llmResponse = await generate({
-      model: gemini15Flash,
-      config: {
-        temperature: 2,
-      },
-      output: {
-        schema: z.object({ jokeSubject: z.string() }),
-      },
-      tools: [jokeSubjectGenerator],
-      prompt: `come up with a subject to joke about (using the function provided)`,
-    });
+      return llmResponse.output();
+    } else {
+      const llmResponse = await generate({
+        model: gemini15FlashGoogleAi,
+        config: {
+          temperature: 2,
+        },
+        output: {
+          schema: z.object({ jokeSubject: z.string() }),
+        },
+        tools: [jokeSubjectGenerator],
+        prompt: `come up with a subject to joke about (using the function provided)`,
+      });
+      return llmResponse.output();
+    }
 
-    // Handle the response from the model API. In this sample, we just convert
-    // it to a string, but more complicated flows might coerce the response into
-    // structured output or chain the response into another LLM call, etc.
-    return llmResponse.output();
   }
 );
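
Note for reviewers (not part of the patch): below is a minimal sketch of how
the new strict JSON mode is expected to surface for callers, assuming the
Vertex AI plugin picks up project credentials from the environment. With an
output schema (or output.format === 'json') and no tools, the plugin now sends
responseMimeType 'application/json' and parses each returned text part into a
data part, so output() should yield structured JSON directly; as soon as tools
are supplied, JSON mode is switched off again. The flow-less script and its
prompt are illustrative only.

import { generate } from '@genkit-ai/ai';
import { configureGenkit } from '@genkit-ai/core';
import { gemini15Flash, vertexAI } from '@genkit-ai/vertexai';
import * as z from 'zod';

configureGenkit({ plugins: [vertexAI()] });

async function main() {
  // No tools here, so the new jsonMode path is taken: the request carries
  // responseMimeType: 'application/json' and the candidate's text part
  // comes back pre-parsed as a data part.
  const llmResponse = await generate({
    model: gemini15Flash,
    config: { temperature: 1 },
    output: {
      schema: z.object({ jokeSubject: z.string() }),
    },
    prompt: 'Come up with a subject to joke about.',
  });
  console.log(llmResponse.output()); // e.g. { jokeSubject: '...' }
}

main().catch(console.error);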