Skip to content

Commit

Permalink
feat: support custom parts and code execution in googleai (#642)
Browse files Browse the repository at this point in the history
* feat: support custom parts and code execution in googleai

* refactor: update googleai code exec and testapp

* refactor: update testapp
  • Loading branch information
cabljac authored Jul 23, 2024
1 parent f954f6d commit 877fd71
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ js/testapps/firebase-functions-sample1/public/bundle.js
js/testapps/firebase-functions-sample1/public/config.js
ui-debug.log
firebase-debug.log

**/*.env
# RAG sample files
!js/testapps/rag/package.json
js/testapps/rag/*.json
Expand Down
11 changes: 8 additions & 3 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ const EmptyPartSchema = z.object({
toolResponse: z.never().optional(),
data: z.unknown().optional(),
metadata: z.record(z.unknown()).optional(),
custom: z.record(z.unknown()).optional(),
});

export const TextPartSchema = EmptyPartSchema.extend({
Expand Down Expand Up @@ -94,13 +95,20 @@ export const DataPartSchema = EmptyPartSchema.extend({

export type DataPart = z.infer<typeof DataPartSchema>;

export const CustomPartSchema = EmptyPartSchema.extend({
custom: z.record(z.any()),
});
export type CustomPart = z.infer<typeof CustomPartSchema>;

export const PartSchema = z.union([
TextPartSchema,
MediaPartSchema,
ToolRequestPartSchema,
ToolResponsePartSchema,
DataPartSchema,
CustomPartSchema,
]);

export type Part = z.infer<typeof PartSchema>;

export const RoleSchema = z.enum(['system', 'user', 'model', 'tool']);
Expand Down Expand Up @@ -363,9 +371,6 @@ export function getBasicUsageStats(
input: MessageData[],
candidates: CandidateData[]
): GenerationUsage {
const responseCandidateParts = candidates.flatMap(
(candidate) => candidate.message.content
);
const inputCounts = getPartCounts(input.flatMap((md) => md.content));
const outputCounts = getPartCounts(
candidates.flatMap((c) => c.message.content)
Expand Down
2 changes: 1 addition & 1 deletion js/plugins/googleai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"dependencies": {
"@genkit-ai/ai": "workspace:*",
"@genkit-ai/core": "workspace:*",
"@google/generative-ai": "^0.14.1",
"@google/generative-ai": "^0.15.0",
"google-auth-library": "^9.6.3",
"node-fetch": "^3.3.2",
"zod": "^3.22.4"
Expand Down
68 changes: 65 additions & 3 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import {
InlineDataPart,
RequestOptions,
StartChatParams,
Tool,
} from '@google/generative-ai';
import process from 'process';
import z from 'zod';
Expand All @@ -71,6 +72,7 @@ const SafetySettingsSchema = z.object({

const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
safetySettings: z.array(SafetySettingsSchema).optional(),
codeExecution: z.union([z.boolean(), z.object({}).strict()]).optional(),
});

export const geminiPro = modelRef({
Expand Down Expand Up @@ -316,11 +318,53 @@ function fromFunctionResponse(part: FunctionResponsePart): ToolResponsePart {
};
}

function fromExecutableCode(part: GeminiPart): Part {
if (!part.executableCode) {
throw new Error('Invalid GeminiPart: missing executableCode');
}
return {
custom: {
executableCode: {
language: part.executableCode.language,
code: part.executableCode.code,
},
},
};
}

function fromCodeExecutionResult(part: GeminiPart): Part {
if (!part.codeExecutionResult) {
throw new Error('Invalid GeminiPart: missing codeExecutionResult');
}
return {
custom: {
codeExecutionResult: {
outcome: part.codeExecutionResult.outcome,
output: part.codeExecutionResult.output,
},
},
};
}

function toCustomPart(part: Part): GeminiPart {
if (!part.custom) {
throw new Error('Invalid GeminiPart: missing custom');
}
if (part.custom.codeExecutionResult) {
return { codeExecutionResult: part.custom.codeExecutionResult };
}
if (part.custom.executableCode) {
return { executableCode: part.custom.executableCode };
}
throw new Error('Unsupported Custom Part type');
}

function toGeminiPart(part: Part): GeminiPart {
if (part.text !== undefined) return { text: part.text };
if (part.media) return toInlineData(part);
if (part.toolRequest) return toFunctionCall(part);
if (part.toolResponse) return toFunctionResponse(part);
if (part.custom) return toCustomPart(part);
throw new Error('Unsupported Part type');
}

Expand All @@ -332,6 +376,8 @@ function fromGeminiPart(part: GeminiPart, jsonMode: boolean): Part {
if (part.inlineData) return fromInlineData(part);
if (part.functionCall) return fromFunctionCall(part);
if (part.functionResponse) return fromFunctionResponse(part);
if (part.executableCode) return fromExecutableCode(part);
if (part.codeExecutionResult) return fromCodeExecutionResult(part);
throw new Error('Unsupported GeminiPart type');
}

Expand Down Expand Up @@ -473,12 +519,28 @@ export function googleAIModel(
? 'application/json'
: undefined,
};

const tools: Tool[] = [];

if (request.tools?.length) {
tools.push({
functionDeclarations: request.tools.map(toGeminiTool),
});
}

if (request.config?.codeExecution) {
tools.push({
codeExecution:
request.config.codeExecution === true
? {}
: request.config.codeExecution,
});
}

const chatRequest = {
systemInstruction,
generationConfig,
tools: request.tools?.length
? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
: [],
tools,
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
Expand Down
57 changes: 52 additions & 5 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions js/testapps/google-ai-code-execution/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"name": "google-ai-code-execution",
"version": "1.0.0",
"description": "",
"main": "lib/index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1",
"start": "node lib/index.js",
"build": "tsc",
"build:watch": "tsc --watch"
},
"keywords": [],
"author": "",
"license": "ISC",
"dependencies": {
"@genkit-ai/ai": "workspace:*",
"@genkit-ai/core": "workspace:^",
"@genkit-ai/dotprompt": "workspace:*",
"@genkit-ai/flow": "workspace:*",
"@genkit-ai/google-cloud": "workspace:*",
"@genkit-ai/googleai": "workspace:*",
"dotenv": "^16.4.5",
"express": "^4.19.2",
"zod": "3.22.4"
},
"devDependencies": {
"typescript": "^5.5.3"
}
}
Loading

0 comments on commit 877fd71

Please sign in to comment.