Skip to content

Commit

Permalink
fix: disable External Filters for Gemini (run-llama#994)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Yang <himself65@outlook.com>
  • Loading branch information
PeronGH and himself65 authored Jul 2, 2024
1 parent d4c1482 commit 7dce3d2
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 12 deletions.
5 changes: 5 additions & 0 deletions .changeset/fluffy-knives-glow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---

fix: disable External Filters for Gemini
11 changes: 9 additions & 2 deletions packages/llamaindex/src/llm/gemini/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
type IGeminiSession,
} from "./types.js";
import {
DEFAULT_SAFETY_SETTINGS,
GeminiHelper,
getChatContext,
getPartsText,
Expand Down Expand Up @@ -87,7 +88,10 @@ export class GeminiSession implements IGeminiSession {
}

getGenerativeModel(metadata: GoogleModelParams): GoogleGenerativeModel {
return this.gemini.getGenerativeModel(metadata);
return this.gemini.getGenerativeModel({
safetySettings: DEFAULT_SAFETY_SETTINGS,
...metadata,
});
}

getResponseText(response: EnhancedGenerateContentResponse): string {
Expand Down Expand Up @@ -143,8 +147,9 @@ export class GeminiSessionStore {
}> = [];

private static getSessionId(options: GeminiSessionOptions): string {
if (options.backend === GEMINI_BACKENDS.GOOGLE)
if (options.backend === GEMINI_BACKENDS.GOOGLE) {
return options?.apiKey ?? "";
}
return "";
}
private static sessionMatched(
Expand Down Expand Up @@ -223,6 +228,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
),
},
],
safetySettings: DEFAULT_SAFETY_SETTINGS,
});
const { response } = await chat.sendMessage(context.message);
const topCandidate = response.candidates![0];
Expand Down Expand Up @@ -258,6 +264,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
),
},
],
safetySettings: DEFAULT_SAFETY_SETTINGS,
});
const result = await chat.sendMessageStream(context.message);
yield* this.session.getChatStream(result);
Expand Down
46 changes: 40 additions & 6 deletions packages/llamaindex/src/llm/gemini/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import {
type FunctionCall,
type Content as GeminiMessageContent,
HarmBlockThreshold,
HarmCategory,
type SafetySetting,
} from "@google/generative-ai";

import { type GenerateContentResponse } from "@google-cloud/vertexai";
Expand Down Expand Up @@ -53,10 +56,13 @@ const getImageParts = (
const { mimeType, base64: data } = extractDataUrlComponents(
message.image_url.url,
);
if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType))
if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType)) {
throw new Error(
`Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join("\n")}`,
`Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join(
"\n",
)}`,
);
}
return {
inlineData: {
mimeType,
Expand All @@ -65,10 +71,13 @@ const getImageParts = (
};
}
const mimeType = getFileURLMimeType(message.image_url.url);
if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType))
if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType)) {
throw new Error(
`Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join("\n")}`,
`Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join(
"\n",
)}`,
);
}
return {
fileData: { mimeType, fileUri: message.image_url.url },
};
Expand Down Expand Up @@ -124,10 +133,11 @@ export const getChatContext = (
// 2. Parts that have empty text
const fnMap = params.messages.reduce(
(result, message) => {
if (message.options && "toolCall" in message.options)
if (message.options && "toolCall" in message.options) {
message.options.toolCall.forEach((call) => {
result[call.id] = call.name;
});
}

return result;
},
Expand Down Expand Up @@ -224,10 +234,11 @@ export class GeminiHelper {
if (options && "toolResult" in options) {
if (!fnMap) throw Error("fnMap must be set");
const name = fnMap[options.toolResult.id];
if (!name)
if (!name) {
throw Error(
`Could not find the name for fn call with id ${options.toolResult.id}`,
);
}

return [
{
Expand Down Expand Up @@ -299,3 +310,26 @@ export function getFunctionCalls(
return undefined;
}
}

/**
* Safety settings to disable external filters
* Documentation: https://ai.google.dev/gemini-api/docs/safety-settings
*/
export const DEFAULT_SAFETY_SETTINGS: SafetySetting[] = [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: HarmBlockThreshold.BLOCK_NONE,
},
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: HarmBlockThreshold.BLOCK_NONE,
},
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: HarmBlockThreshold.BLOCK_NONE,
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: HarmBlockThreshold.BLOCK_NONE,
},
];
16 changes: 12 additions & 4 deletions packages/llamaindex/src/llm/gemini/vertex.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import {
type GenerateContentResponse,
VertexAI,
GenerativeModel as VertexGenerativeModel,
GenerativeModelPreview as VertexGenerativeModelPreview,
type GenerateContentResponse,
type ModelParams as VertexModelParams,
type StreamGenerateContentResult as VertexStreamGenerateContentResult,
} from "@google-cloud/vertexai";
Expand All @@ -21,7 +21,7 @@ import type {
ToolCallLLMMessageOptions,
} from "../types.js";
import { streamConverter } from "../utils.js";
import { getFunctionCalls, getText } from "./utils.js";
import { DEFAULT_SAFETY_SETTINGS, getFunctionCalls, getText } from "./utils.js";

/* To use Google's Vertex AI backend, it doesn't use api key authentication.
*
Expand Down Expand Up @@ -59,8 +59,16 @@ export class GeminiVertexSession implements IGeminiSession {
getGenerativeModel(
metadata: VertexModelParams,
): VertexGenerativeModelPreview | VertexGenerativeModel {
if (this.preview) return this.vertex.preview.getGenerativeModel(metadata);
return this.vertex.getGenerativeModel(metadata);
if (this.preview) {
return this.vertex.preview.getGenerativeModel({
safetySettings: DEFAULT_SAFETY_SETTINGS,
...metadata,
});
}
return this.vertex.getGenerativeModel({
safetySettings: DEFAULT_SAFETY_SETTINGS,
...metadata,
});
}

getResponseText(response: GenerateContentResponse): string {
Expand Down

0 comments on commit 7dce3d2

Please sign in to comment.