From eb552d2a694efc499bab24be9e8dba2531b1e9cb Mon Sep 17 00:00:00 2001
From: sxjeru
Date: Tue, 10 Sep 2024 02:42:44 +0800
Subject: [PATCH] 🐛 fix: Gemini cannot input images when server database is
 enabled (#3370)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Update index.ts

* Update index.ts

* Update index.ts

* Update index.ts

* ♻️ refactor: refactor the google implement

* ✅ test: fix tests

* ✅ test: fix tests

---------

Co-authored-by: Arvin Xu
---
 src/libs/agent-runtime/google/index.test.ts | 70 ++++++++++++++-------
 src/libs/agent-runtime/google/index.ts      | 49 +++++++++------
 src/utils/imageToBase64.ts                  | 16 +++++
 3 files changed, 92 insertions(+), 43 deletions(-)

diff --git a/src/libs/agent-runtime/google/index.test.ts b/src/libs/agent-runtime/google/index.test.ts
index 9a161e6d31c7..17298e085c3f 100644
--- a/src/libs/agent-runtime/google/index.test.ts
+++ b/src/libs/agent-runtime/google/index.test.ts
@@ -5,6 +5,7 @@ import OpenAI from 'openai';
 import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
 
 import { OpenAIChatMessage } from '@/libs/agent-runtime';
+import * as imageToBase64Module from '@/utils/imageToBase64';
 
 import * as debugStreamModule from '../utils/debugStream';
 import { LobeGoogleAI } from './index';
@@ -303,36 +304,57 @@ describe('LobeGoogleAI', () => {
 
   describe('private method', () => {
     describe('convertContentToGooglePart', () => {
-      it('should throw TypeError when image URL does not contain base64 data', () => {
-        // 提供一个不包含base64数据的图像URL
-        const invalidImageUrl = 'http://example.com/image.png';
+      it('should handle URL type images', async () => {
+        const imageUrl = 'http://example.com/image.png';
+        const mockBase64 = 'mockBase64Data';
 
-        expect(() =>
+        // Mock the imageUrlToBase64 function
+        vi.spyOn(imageToBase64Module, 'imageUrlToBase64').mockResolvedValueOnce(mockBase64);
+
+        const result = await instance['convertContentToGooglePart']({
+          type: 'image_url',
+          image_url: { url: imageUrl },
+        });
+
+        expect(result).toEqual({
+          inlineData: {
+            data: mockBase64,
+            mimeType: 'image/png',
+          },
+        });
+
+        expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(imageUrl);
+      });
+
+      it('should throw TypeError for unsupported image URL types', async () => {
+        const unsupportedImageUrl = 'unsupported://example.com/image.png';
+
+        await expect(
           instance['convertContentToGooglePart']({
             type: 'image_url',
-            image_url: { url: invalidImageUrl },
+            image_url: { url: unsupportedImageUrl },
           }),
-        ).toThrow(TypeError);
+        ).rejects.toThrow(TypeError);
       });
     });
 
     describe('buildGoogleMessages', () => {
-      it('get default result with gemini-pro', () => {
+      it('get default result with gemini-pro', async () => {
         const messages: OpenAIChatMessage[] = [{ content: 'Hello', role: 'user' }];
 
-        const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
+        const contents = await instance['buildGoogleMessages'](messages, 'gemini-pro');
 
         expect(contents).toHaveLength(1);
         expect(contents).toEqual([{ parts: [{ text: 'Hello' }], role: 'user' }]);
       });
 
-      it('messages should end with user if using gemini-pro', () => {
+      it('messages should end with user if using gemini-pro', async () => {
         const messages: OpenAIChatMessage[] = [
           { content: 'Hello', role: 'user' },
           { content: 'Hi', role: 'assistant' },
         ];
 
-        const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
+        const contents = await instance['buildGoogleMessages'](messages, 'gemini-pro');
 
         expect(contents).toHaveLength(3);
         expect(contents).toEqual([
@@ -342,13 +364,13 @@ describe('LobeGoogleAI', () => {
         ]);
       });
 
-      it('should include system role if there is a system role prompt', () => {
+      it('should include system role if there is a system role prompt', async () => {
         const messages: OpenAIChatMessage[] = [
           { content: 'you are ChatGPT', role: 'system' },
           { content: 'Who are you', role: 'user' },
         ];
 
-        const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
+        const contents = await instance['buildGoogleMessages'](messages, 'gemini-pro');
 
         expect(contents).toHaveLength(3);
         expect(contents).toEqual([
@@ -358,13 +380,13 @@ describe('LobeGoogleAI', () => {
         ]);
       });
 
-      it('should not modify the length if model is gemini-1.5-pro', () => {
+      it('should not modify the length if model is gemini-1.5-pro', async () => {
         const messages: OpenAIChatMessage[] = [
           { content: 'Hello', role: 'user' },
           { content: 'Hi', role: 'assistant' },
         ];
 
-        const contents = instance['buildGoogleMessages'](messages, 'gemini-1.5-pro-latest');
+        const contents = await instance['buildGoogleMessages'](messages, 'gemini-1.5-pro-latest');
 
         expect(contents).toHaveLength(2);
         expect(contents).toEqual([
@@ -373,7 +395,7 @@ describe('LobeGoogleAI', () => {
         ]);
       });
 
-      it('should use specified model when images are included in messages', () => {
+      it('should use specified model when images are included in messages', async () => {
         const messages: OpenAIChatMessage[] = [
           {
             content: [
@@ -386,7 +408,7 @@ describe('LobeGoogleAI', () => {
         const model = 'gemini-1.5-flash-latest';
 
         // 调用 buildGoogleMessages 方法
-        const contents = instance['buildGoogleMessages'](messages, model);
+        const contents = await instance['buildGoogleMessages'](messages, model);
 
         expect(contents).toHaveLength(1);
         expect(contents).toEqual([
@@ -501,13 +523,13 @@ describe('LobeGoogleAI', () => {
     });
 
     describe('convertOAIMessagesToGoogleMessage', () => {
-      it('should correctly convert assistant message', () => {
+      it('should correctly convert assistant message', async () => {
         const message: OpenAIChatMessage = {
           role: 'assistant',
           content: 'Hello',
         };
 
-        const converted = instance['convertOAIMessagesToGoogleMessage'](message);
+        const converted = await instance['convertOAIMessagesToGoogleMessage'](message);
 
         expect(converted).toEqual({
           role: 'model',
@@ -515,13 +537,13 @@ describe('LobeGoogleAI', () => {
         });
       });
 
-      it('should correctly convert user message', () => {
+      it('should correctly convert user message', async () => {
         const message: OpenAIChatMessage = {
           role: 'user',
           content: 'Hi',
         };
 
-        const converted = instance['convertOAIMessagesToGoogleMessage'](message);
+        const converted = await instance['convertOAIMessagesToGoogleMessage'](message);
 
         expect(converted).toEqual({
           role: 'user',
@@ -529,7 +551,7 @@ describe('LobeGoogleAI', () => {
         });
       });
 
-      it('should correctly convert message with inline base64 image parts', () => {
+      it('should correctly convert message with inline base64 image parts', async () => {
         const message: OpenAIChatMessage = {
           role: 'user',
           content: [
@@ -538,7 +560,7 @@ describe('LobeGoogleAI', () => {
           ],
         };
 
-        const converted = instance['convertOAIMessagesToGoogleMessage'](message);
+        const converted = await instance['convertOAIMessagesToGoogleMessage'](message);
 
         expect(converted).toEqual({
           role: 'user',
@@ -548,7 +570,7 @@ describe('LobeGoogleAI', () => {
           ],
         });
       });
-      it.skip('should correctly convert message with image url parts', () => {
+      it.skip('should correctly convert message with image url parts', async () => {
         const message: OpenAIChatMessage = {
           role: 'user',
           content: [
@@ -557,7 +579,7 @@ describe('LobeGoogleAI', () => {
           ],
         };
 
-        const converted = instance['convertOAIMessagesToGoogleMessage'](message);
+        const converted = await instance['convertOAIMessagesToGoogleMessage'](message);
 
         expect(converted).toEqual({
           role: 'user',
diff --git a/src/libs/agent-runtime/google/index.ts b/src/libs/agent-runtime/google/index.ts
index 1222ad764d14..2f26139f97be 100644
--- a/src/libs/agent-runtime/google/index.ts
+++ b/src/libs/agent-runtime/google/index.ts
@@ -10,6 +10,8 @@ import {
 import { JSONSchema7 } from 'json-schema';
 import { transform } from 'lodash-es';
 
+import { imageUrlToBase64 } from '@/utils/imageToBase64';
+
 import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType, ILobeAgentRuntimeErrorType } from '../error';
 import {
@@ -52,7 +54,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
     try {
       const model = payload.model;
 
-      const contents = this.buildGoogleMessages(payload.messages, model);
+      const contents = await this.buildGoogleMessages(payload.messages, model);
 
       const geminiStreamResult = await this.client
         .getGenerativeModel(
@@ -109,7 +111,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
     }
   }
 
-  private convertContentToGooglePart = (content: UserMessageContentPart): Part => {
+  private convertContentToGooglePart = async (content: UserMessageContentPart): Promise<Part> => {
     switch (content.type) {
       case 'text': {
         return { text: content.text };
@@ -130,51 +132,60 @@ export class LobeGoogleAI implements LobeRuntimeAI {
           };
         }
 
-        // if (type === 'url') {
-        //   return {
-        //     fileData: {
-        //       fileUri: content.image_url.url,
-        //       mimeType: mimeType || 'image/png',
-        //     },
-        //   };
-        // }
+        if (type === 'url') {
+          const base64Image = await imageUrlToBase64(content.image_url.url);
+
+          return {
+            inlineData: {
+              data: base64Image,
+              mimeType: mimeType || 'image/png',
+            },
+          };
+        }
 
         throw new TypeError(`currently we don't support image url: ${content.image_url.url}`);
       }
     }
   };
 
-  private convertOAIMessagesToGoogleMessage = (message: OpenAIChatMessage): Content => {
+  private convertOAIMessagesToGoogleMessage = async (
+    message: OpenAIChatMessage,
+  ): Promise<Content> => {
     const content = message.content as string | UserMessageContentPart[];
 
     return {
       parts:
         typeof content === 'string'
          ? [{ text: content }]
-          : content.map((c) => this.convertContentToGooglePart(c)),
+          : await Promise.all(content.map(async (c) => await this.convertContentToGooglePart(c))),
       role: message.role === 'assistant' ? 'model' : 'user',
     };
   };
 
   // convert messages from the Vercel AI SDK Format to the format
   // that is expected by the Google GenAI SDK
-  private buildGoogleMessages = (messages: OpenAIChatMessage[], model: string): Content[] => {
+  private buildGoogleMessages = async (
+    messages: OpenAIChatMessage[],
+    model: string,
+  ): Promise<Content[]> => {
     // if the model is gemini-1.5-pro-latest, we don't need any special handling
     if (model === 'gemini-1.5-pro-latest') {
-      return messages
+      const pools = messages
         .filter((message) => message.role !== 'function')
-        .map((msg) => this.convertOAIMessagesToGoogleMessage(msg));
+        .map(async (msg) => await this.convertOAIMessagesToGoogleMessage(msg));
+
+      return Promise.all(pools);
     }
 
     const contents: Content[] = [];
     let lastRole = 'model';
 
-    messages.forEach((message) => {
+    for (const message of messages) {
       // current to filter function message
       if (message.role === 'function') {
-        return;
+        continue;
       }
 
-      const googleMessage = this.convertOAIMessagesToGoogleMessage(message);
+      const googleMessage = await this.convertOAIMessagesToGoogleMessage(message);
 
       // if the last message is a model message and the current message is a model message
       // then we need to add a user message to separate them
@@ -187,7 +198,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
 
       // update the last role
       lastRole = googleMessage.role;
-    });
+    }
 
     // if the last message is a user message, then we need to add a model message to separate them
     if (lastRole === 'model') {
diff --git a/src/utils/imageToBase64.ts b/src/utils/imageToBase64.ts
index 3ab277c9ed09..63a4302215d5 100644
--- a/src/utils/imageToBase64.ts
+++ b/src/utils/imageToBase64.ts
@@ -35,3 +35,19 @@ export const imageToBase64 = ({
 
   return canvas.toDataURL(type);
 };
+
+export const imageUrlToBase64 = async (imageUrl: string): Promise<string> => {
+  try {
+    const res = await fetch(imageUrl);
+    const arrayBuffer = await res.arrayBuffer();
+
+    return typeof btoa === 'function'
+      ? btoa(
+          new Uint8Array(arrayBuffer).reduce((data, byte) => data + String.fromCharCode(byte), ''),
+        )
+      : Buffer.from(arrayBuffer).toString('base64');
+  } catch (error) {
+    console.error('Error converting image to base64:', error);
+    throw error;
+  }
+};
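
Note (reviewer sketch, not part of the patch): the new `imageUrlToBase64` helper is what lets the URL branch above fetch a remote image and inline it as base64 for Gemini. A minimal TypeScript usage sketch, assuming a runtime with a global `fetch` (Node 18+ or the browser); the `toInlineImagePart` name is illustrative and not part of the codebase, and the real code derives the mime type from `parseDataUri` with 'image/png' only as a fallback:

    import { imageUrlToBase64 } from '@/utils/imageToBase64';

    // Builds the same `inlineData` shape that convertContentToGooglePart
    // now produces for `image_url` parts that are plain http(s) URLs.
    const toInlineImagePart = async (url: string) => ({
      inlineData: {
        data: await imageUrlToBase64(url),
        mimeType: 'image/png',
      },
    });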