From 447105a6dc9035cfe27ca5ee9ce7115ce48aa50b Mon Sep 17 00:00:00 2001 From: Parham Saidi Date: Wed, 15 May 2024 11:33:55 +0200 Subject: [PATCH] fix: Gemini text chat - prevent sending broken messageContent and history (#822) --- .changeset/sharp-knives-ring.md | 5 +++ packages/core/src/Prompt.ts | 1 + packages/core/src/llm/gemini.ts | 74 +++++++++++++-------------------- packages/core/src/llm/index.ts | 2 +- 4 files changed, 37 insertions(+), 45 deletions(-) create mode 100644 .changeset/sharp-knives-ring.md diff --git a/.changeset/sharp-knives-ring.md b/.changeset/sharp-knives-ring.md new file mode 100644 index 0000000000..2d8502fdaa --- /dev/null +++ b/.changeset/sharp-knives-ring.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Improve Gemini message and context preparation diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index 3d0cdffefd..ef344bd17f 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -371,6 +371,7 @@ export function messagesToHistoryStr(messages: ChatMessage[]) { } export const defaultContextSystemPrompt = ({ context = "" }) => { + if (!context) return ""; return `Context information is below. --------------------- ${context} diff --git a/packages/core/src/llm/gemini.ts b/packages/core/src/llm/gemini.ts index 9239d44d17..9b98fcd243 100644 --- a/packages/core/src/llm/gemini.ts +++ b/packages/core/src/llm/gemini.ts @@ -139,7 +139,7 @@ class GeminiHelper { > = { user: "user", system: "user", - assistant: "user", + assistant: "model", memory: "user", }; @@ -152,38 +152,26 @@ class GeminiHelper { }; public static mergeNeighboringSameRoleMessages( - messages: ChatMessage[], - ): ChatMessage[] { - // Gemini does not support multiple messages of the same role in a row, so we merge them - const mergedMessages: ChatMessage[] = []; - let i: number = 0; - - while (i < messages.length) { - const currentMessage: ChatMessage = messages[i]; - // Initialize merged content with current message content - const mergedContent: MessageContent[] = [currentMessage.content]; - - // Check if the next message exists and has the same role - while ( - i + 1 < messages.length && - this.ROLES_TO_GEMINI[messages[i + 1].role] === - this.ROLES_TO_GEMINI[currentMessage.role] - ) { - i++; - const nextMessage: ChatMessage = messages[i]; - mergedContent.push(nextMessage.content); - } - - // Create a new ChatMessage object with merged content - const mergedMessage: ChatMessage = { - role: currentMessage.role, - content: mergedContent.join("\n"), - }; - mergedMessages.push(mergedMessage); - i++; - } - - return mergedMessages; + messages: GeminiMessageContent[], + ): GeminiMessageContent[] { + return messages.reduce( + ( + result: GeminiMessageContent[], + current: GeminiMessageContent, + index: number, + ) => { + if (index > 0 && messages[index - 1].role === current.role) { + result[result.length - 1].parts = [ + ...result[result.length - 1].parts, + ...current.parts, + ]; + } else { + result.push(current); + } + return result; + }, + [], + ); } public static messageContentToGeminiParts(content: MessageContent): Part[] { @@ -214,8 +202,8 @@ class GeminiHelper { message: ChatMessage, ): GeminiMessageContent { return { - role: this.ROLES_TO_GEMINI[message.role], - parts: this.messageContentToGeminiParts(message.content), + role: GeminiHelper.ROLES_TO_GEMINI[message.role], + parts: GeminiHelper.messageContentToGeminiParts(message.content), }; } } @@ -260,22 +248,20 @@ export class Gemini extends ToolCallLLM { chat: ChatSession; messageContent: Part[]; } { - const { messages } = params; - const mergedMessages = - GeminiHelper.mergeNeighboringSameRoleMessages(messages); - const history = mergedMessages.slice(0, -1); - const nextMessage = mergedMessages[mergedMessages.length - 1]; - const messageContent = GeminiHelper.chatMessageToGemini(nextMessage).parts; + const messages = GeminiHelper.mergeNeighboringSameRoleMessages( + params.messages.map(GeminiHelper.chatMessageToGemini), + ); + + const history = messages.slice(0, -1); const client = this.session.gemini.getGenerativeModel(this.metadata); const chat = client.startChat({ - history: history.map(GeminiHelper.chatMessageToGemini), + history, }); - return { chat, - messageContent, + messageContent: messages[messages.length - 1].parts, }; } diff --git a/packages/core/src/llm/index.ts b/packages/core/src/llm/index.ts index 478ccb715b..3430bd7860 100644 --- a/packages/core/src/llm/index.ts +++ b/packages/core/src/llm/index.ts @@ -5,7 +5,7 @@ export { Anthropic, } from "./anthropic.js"; export { FireworksLLM } from "./fireworks.js"; -export { GEMINI_MODEL, Gemini } from "./gemini.js"; +export { GEMINI_MODEL, Gemini, GeminiSession } from "./gemini.js"; export { Groq } from "./groq.js"; export { HuggingFaceInferenceAPI } from "./huggingface.js"; export {