8 changes: 4 additions & 4 deletions js/plugins/vertexai/package.json
@@ -38,8 +38,8 @@
     "@anthropic-ai/sdk": "^0.24.3",
     "@anthropic-ai/vertex-sdk": "^0.4.0",
     "@google-cloud/aiplatform": "^3.23.0",
-    "@google-cloud/vertexai": "^1.1.0",
-    "google-auth-library": "^9.6.3",
+    "@google-cloud/vertexai": "^1.7.0",
+    "google-auth-library": "^9.14.1",
     "googleapis": "^140.0.1",
     "node-fetch": "^3.3.2",
     "openai": "^4.52.7",
@@ -51,8 +51,8 @@
     "@genkit-ai/flow": "workspace:*"
   },
   "optionalDependencies": {
-    "firebase-admin": "^12.1.0",
-    "@google-cloud/bigquery": "^7.8.0"
+    "@google-cloud/bigquery": "^7.8.0",
+    "firebase-admin": "^12.1.0"
   },
   "devDependencies": {
     "@types/node": "^20.11.16",
34 changes: 25 additions & 9 deletions js/plugins/vertexai/src/gemini.ts
@@ -399,7 +399,10 @@ function fromGeminiFunctionResponsePart(part: GeminiPart): Part {
 }
 
 // Converts vertex part to genkit part
-function fromGeminiPart(part: GeminiPart): Part {
+function fromGeminiPart(part: GeminiPart, jsonMode: boolean): Part {
+  if (jsonMode && part.text !== undefined) {
+    return { data: JSON.parse(part.text) };
+  }
   if (part.text !== undefined) return { text: part.text };
   if (part.functionCall) return fromGeminiFunctionCallPart(part);
   if (part.functionResponse) return fromGeminiFunctionResponsePart(part);
@@ -411,14 +414,15 @@ function fromGeminiPart(part: GeminiPart): Part {
 }
 
 export function fromGeminiCandidate(
-  candidate: GenerateContentCandidate
+  candidate: GenerateContentCandidate,
+  jsonMode: boolean
 ): CandidateData {
   const parts = candidate.content.parts || [];
   const genkitCandidate: CandidateData = {
     index: candidate.index || 0, // reasonable default?
     message: {
       role: 'model',
-      content: parts.map(fromGeminiPart),
+      content: parts.map((p) => fromGeminiPart(p, jsonMode)),
     },
     finishReason: fromGeminiFinishReason(candidate.finishReason),
     finishMessage: candidate.finishMessage,
@@ -518,11 +522,18 @@ export function geminiModel(
       }
     }
 
+    const tools = request.tools?.length
+      ? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
+      : [];
+
+    // Cannot use tools and JSON output mode at the same time
+    const jsonMode =
+      (request.output?.format === 'json' || !!request.output?.schema) &&
+      tools.length === 0;
+
     const chatRequest: StartChatParams = {
       systemInstruction,
-      tools: request.tools?.length
-        ? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
-        : [],
+      tools,
       history: messages
         .slice(0, -1)
         .map((message) => toGeminiMessage(message, model)),
@@ -532,6 +543,7 @@
         maxOutputTokens: request.config?.maxOutputTokens,
         topK: request.config?.topK,
         topP: request.config?.topP,
+        responseMimeType: jsonMode ? 'application/json' : undefined,
         stopSequences: request.config?.stopSequences,
       },
       safetySettings: request.config?.safetySettings,
@@ -566,7 +578,7 @@ export function geminiModel(
           .sendMessageStream(msg.parts);
         for await (const item of result.stream) {
           (item as GenerateContentResponse).candidates?.forEach((candidate) => {
-            const c = fromGeminiCandidate(candidate);
+            const c = fromGeminiCandidate(candidate, jsonMode);
             streamingCallback({
               index: c.index,
               content: c.message.content,
@@ -578,7 +590,9 @@
         throw new Error('No valid candidates returned.');
       }
       return {
-        candidates: response.candidates?.map(fromGeminiCandidate) || [],
+        candidates:
+          response.candidates?.map((c) => fromGeminiCandidate(c, jsonMode)) ||
+          [],
         custom: response,
       };
     } else {
@@ -592,7 +606,9 @@
         throw new Error('No valid candidates returned.');
       }
       const responseCandidates =
-        result.response.candidates?.map(fromGeminiCandidate) || [];
+        result.response.candidates?.map((c) =>
+          fromGeminiCandidate(c, jsonMode)
+        ) || [];
       return {
         candidates: responseCandidates,
         custom: result.response,
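
Taken together, the gemini.ts changes mean that when a request asks for JSON output (output.format === 'json' or an output schema) and declares no tools, the plugin sets Vertex AI's responseMimeType to 'application/json' and parses each returned text part into a structured data part. A minimal usage sketch (hypothetical application code, not part of this PR; assumes a configured vertexAI plugin):

import { generate } from '@genkit-ai/ai';
import { gemini15Flash } from '@genkit-ai/vertexai';
import * as z from 'zod';

async function subjectAsJson() {
  // An output schema is present and no tools are passed, so the plugin
  // computes jsonMode === true: the request goes out with
  // responseMimeType: 'application/json', and the response text is
  // JSON.parse()d into a data part that output() can return directly.
  const llmResponse = await generate({
    model: gemini15Flash,
    output: { schema: z.object({ jokeSubject: z.string() }) },
    prompt: 'come up with a subject to joke about',
  });
  return llmResponse.output(); // => { jokeSubject: string }
}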
49 changes: 34 additions & 15 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions js/testapps/basic-gemini/package.json
@@ -18,6 +18,7 @@
     "@genkit-ai/dotprompt": "workspace:*",
     "@genkit-ai/flow": "workspace:*",
     "@genkit-ai/googleai": "workspace:*",
+    "@genkit-ai/vertexai": "workspace:*",
     "express": "^4.20.0",
     "zod": "^3.22.4"
   },
64 changes: 41 additions & 23 deletions js/testapps/basic-gemini/src/index.ts
@@ -20,20 +20,24 @@ import * as z from 'zod';
 import { defineTool, generate } from '@genkit-ai/ai';
 import { configureGenkit } from '@genkit-ai/core';
 import { defineFlow, startFlowsServer } from '@genkit-ai/flow';
-import { googleAI } from '@genkit-ai/googleai';
+import {
+  gemini15Flash as gemini15FlashGoogleAi,
+  googleAI,
+} from '@genkit-ai/googleai';
 
-// Import models from the Google AI plugin. The Google AI API provides access to
-// several generative models. Here, we import Gemini 1.5 Flash.
-import { gemini15Flash } from '@genkit-ai/googleai';
+import {
+  gemini15Flash as gemini15FlashVertexAi,
+  vertexAI,
+} from '@genkit-ai/vertexai';
+
+const provider = process.env.PROVIDER || 'vertexai';
+
+const plugin = provider === 'vertexai' ? vertexAI : googleAI;
 
 configureGenkit({
-  plugins: [
-    // Load the Google AI plugin. You can optionally specify your API key
-    // by passing in a config object; if you don't, the Google AI plugin uses
-    // the value from the GOOGLE_GENAI_API_KEY environment variable, which is
-    // the recommended practice.
-    googleAI(),
-  ],
+  plugins: [plugin()],
   // Log debug output to the console.
   logLevel: 'debug',
   // Perform OpenTelemetry instrumentation and enable trace collection.
@@ -62,23 +66,37 @@ export const jokeFlow = defineFlow(
   },
   async () => {
     // Construct a request and send it to the model API.
+    if (provider === 'vertexai') {
+      const llmResponse = await generate({
+        model: gemini15FlashVertexAi,
+        config: {
+          temperature: 2,
+        },
+        output: {
+          schema: z.object({ jokeSubject: z.string() }),
+        },
+        tools: [jokeSubjectGenerator],
+        prompt: `come up with a subject to joke about (using the function provided)`,
+      });
+
-    const llmResponse = await generate({
-      model: gemini15Flash,
-      config: {
-        temperature: 2,
-      },
-      output: {
-        schema: z.object({ jokeSubject: z.string() }),
-      },
-      tools: [jokeSubjectGenerator],
-      prompt: `come up with a subject to joke about (using the function provided)`,
-    });
+      return llmResponse.output();
+    } else {
+      const llmResponse = await generate({
+        model: gemini15FlashGoogleAi,
+        config: {
+          temperature: 2,
+        },
+        output: {
+          schema: z.object({ jokeSubject: z.string() }),
+        },
+        tools: [jokeSubjectGenerator],
+        prompt: `come up with a subject to joke about (using the function provided)`,
+      });
+      return llmResponse.output();
+    }
 
-    // Handle the response from the model API. In this sample, we just convert
-    // it to a string, but more complicated flows might coerce the response into
-    // structured output or chain the response into another LLM call, etc.
-    return llmResponse.output();
   }
 );
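
The two branches of jokeFlow are identical except for the model reference, so a possible tightening (a sketch, not part of this PR) is to select the model up front the same way the plugin is selected:

// Equivalent, shorter flow body: pick the model once, then issue a
// single generate() call.
const model =
  provider === 'vertexai' ? gemini15FlashVertexAi : gemini15FlashGoogleAi;

const llmResponse = await generate({
  model,
  config: {
    temperature: 2,
  },
  output: {
    schema: z.object({ jokeSubject: z.string() }),
  },
  tools: [jokeSubjectGenerator],
  prompt: `come up with a subject to joke about (using the function provided)`,
});
return llmResponse.output();

Either way, the test app defaults to Vertex AI; set PROVIDER=googleai in the environment before starting the flows server to exercise the Google AI path instead.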
