Skip to content

Commit

Permalink
🐛 fix: Gemini cannot input images when server database is enabled (lobehub#3370)

Browse files Browse the repository at this point in the history

* Update index.ts

* Update index.ts

* Update index.ts

* Update index.ts

* ♻️ refactor: refactor the google implement

* ✅ test: fix tests

* ✅ test: fix tests

---------

Co-authored-by: Arvin Xu <arvinx@foxmail.com>
  • Loading branch information
sxjeru and arvinxx authored Sep 9, 2024
1 parent 4dc3172 commit eb552d2
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 43 deletions.
70 changes: 46 additions & 24 deletions src/libs/agent-runtime/google/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import OpenAI from 'openai';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { OpenAIChatMessage } from '@/libs/agent-runtime';
import * as imageToBase64Module from '@/utils/imageToBase64';

import * as debugStreamModule from '../utils/debugStream';
import { LobeGoogleAI } from './index';
Expand Down Expand Up @@ -303,36 +304,57 @@ describe('LobeGoogleAI', () => {

describe('private method', () => {
describe('convertContentToGooglePart', () => {
it('should throw TypeError when image URL does not contain base64 data', () => {
// 提供一个不包含base64数据的图像URL
const invalidImageUrl = 'http://example.com/image.png';
it('should handle URL type images', async () => {
const imageUrl = 'http://example.com/image.png';
const mockBase64 = 'mockBase64Data';

expect(() =>
// Mock the imageUrlToBase64 function
vi.spyOn(imageToBase64Module, 'imageUrlToBase64').mockResolvedValueOnce(mockBase64);

const result = await instance['convertContentToGooglePart']({
type: 'image_url',
image_url: { url: imageUrl },
});

expect(result).toEqual({
inlineData: {
data: mockBase64,
mimeType: 'image/png',
},
});

expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(imageUrl);
});

it('should throw TypeError for unsupported image URL types', async () => {
const unsupportedImageUrl = 'unsupported://example.com/image.png';

await expect(
instance['convertContentToGooglePart']({
type: 'image_url',
image_url: { url: invalidImageUrl },
image_url: { url: unsupportedImageUrl },
}),
).toThrow(TypeError);
).rejects.toThrow(TypeError);
});
});

describe('buildGoogleMessages', () => {
it('get default result with gemini-pro', () => {
it('get default result with gemini-pro', async () => {
const messages: OpenAIChatMessage[] = [{ content: 'Hello', role: 'user' }];

const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
const contents = await instance['buildGoogleMessages'](messages, 'gemini-pro');

expect(contents).toHaveLength(1);
expect(contents).toEqual([{ parts: [{ text: 'Hello' }], role: 'user' }]);
});

it('messages should end with user if using gemini-pro', () => {
it('messages should end with user if using gemini-pro', async () => {
const messages: OpenAIChatMessage[] = [
{ content: 'Hello', role: 'user' },
{ content: 'Hi', role: 'assistant' },
];

const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
const contents = await instance['buildGoogleMessages'](messages, 'gemini-pro');

expect(contents).toHaveLength(3);
expect(contents).toEqual([
Expand All @@ -342,13 +364,13 @@ describe('LobeGoogleAI', () => {
]);
});

it('should include system role if there is a system role prompt', () => {
it('should include system role if there is a system role prompt', async () => {
const messages: OpenAIChatMessage[] = [
{ content: 'you are ChatGPT', role: 'system' },
{ content: 'Who are you', role: 'user' },
];

const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
const contents = await instance['buildGoogleMessages'](messages, 'gemini-pro');

expect(contents).toHaveLength(3);
expect(contents).toEqual([
Expand All @@ -358,13 +380,13 @@ describe('LobeGoogleAI', () => {
]);
});

it('should not modify the length if model is gemini-1.5-pro', () => {
it('should not modify the length if model is gemini-1.5-pro', async () => {
const messages: OpenAIChatMessage[] = [
{ content: 'Hello', role: 'user' },
{ content: 'Hi', role: 'assistant' },
];

const contents = instance['buildGoogleMessages'](messages, 'gemini-1.5-pro-latest');
const contents = await instance['buildGoogleMessages'](messages, 'gemini-1.5-pro-latest');

expect(contents).toHaveLength(2);
expect(contents).toEqual([
Expand All @@ -373,7 +395,7 @@ describe('LobeGoogleAI', () => {
]);
});

it('should use specified model when images are included in messages', () => {
it('should use specified model when images are included in messages', async () => {
const messages: OpenAIChatMessage[] = [
{
content: [
Expand All @@ -386,7 +408,7 @@ describe('LobeGoogleAI', () => {
const model = 'gemini-1.5-flash-latest';

// 调用 buildGoogleMessages 方法
const contents = instance['buildGoogleMessages'](messages, model);
const contents = await instance['buildGoogleMessages'](messages, model);

expect(contents).toHaveLength(1);
expect(contents).toEqual([
Expand Down Expand Up @@ -501,35 +523,35 @@ describe('LobeGoogleAI', () => {
});

describe('convertOAIMessagesToGoogleMessage', () => {
it('should correctly convert assistant message', () => {
it('should correctly convert assistant message', async () => {
const message: OpenAIChatMessage = {
role: 'assistant',
content: 'Hello',
};

const converted = instance['convertOAIMessagesToGoogleMessage'](message);
const converted = await instance['convertOAIMessagesToGoogleMessage'](message);

expect(converted).toEqual({
role: 'model',
parts: [{ text: 'Hello' }],
});
});

it('should correctly convert user message', () => {
it('should correctly convert user message', async () => {
const message: OpenAIChatMessage = {
role: 'user',
content: 'Hi',
};

const converted = instance['convertOAIMessagesToGoogleMessage'](message);
const converted = await instance['convertOAIMessagesToGoogleMessage'](message);

expect(converted).toEqual({
role: 'user',
parts: [{ text: 'Hi' }],
});
});

it('should correctly convert message with inline base64 image parts', () => {
it('should correctly convert message with inline base64 image parts', async () => {
const message: OpenAIChatMessage = {
role: 'user',
content: [
Expand All @@ -538,7 +560,7 @@ describe('LobeGoogleAI', () => {
],
};

const converted = instance['convertOAIMessagesToGoogleMessage'](message);
const converted = await instance['convertOAIMessagesToGoogleMessage'](message);

expect(converted).toEqual({
role: 'user',
Expand All @@ -548,7 +570,7 @@ describe('LobeGoogleAI', () => {
],
});
});
it.skip('should correctly convert message with image url parts', () => {
it.skip('should correctly convert message with image url parts', async () => {
const message: OpenAIChatMessage = {
role: 'user',
content: [
Expand All @@ -557,7 +579,7 @@ describe('LobeGoogleAI', () => {
],
};

const converted = instance['convertOAIMessagesToGoogleMessage'](message);
const converted = await instance['convertOAIMessagesToGoogleMessage'](message);

expect(converted).toEqual({
role: 'user',
Expand Down
49 changes: 30 additions & 19 deletions src/libs/agent-runtime/google/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {
import { JSONSchema7 } from 'json-schema';
import { transform } from 'lodash-es';

import { imageUrlToBase64 } from '@/utils/imageToBase64';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType, ILobeAgentRuntimeErrorType } from '../error';
import {
Expand Down Expand Up @@ -52,7 +54,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
try {
const model = payload.model;

const contents = this.buildGoogleMessages(payload.messages, model);
const contents = await this.buildGoogleMessages(payload.messages, model);

const geminiStreamResult = await this.client
.getGenerativeModel(
Expand Down Expand Up @@ -109,7 +111,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
}
}

private convertContentToGooglePart = (content: UserMessageContentPart): Part => {
private convertContentToGooglePart = async (content: UserMessageContentPart): Promise<Part> => {
switch (content.type) {
case 'text': {
return { text: content.text };
Expand All @@ -130,51 +132,60 @@ export class LobeGoogleAI implements LobeRuntimeAI {
};
}

// if (type === 'url') {
// return {
// fileData: {
// fileUri: content.image_url.url,
// mimeType: mimeType || 'image/png',
// },
// };
// }
if (type === 'url') {
const base64Image = await imageUrlToBase64(content.image_url.url);

return {
inlineData: {
data: base64Image,
mimeType: mimeType || 'image/png',
},
};
}

throw new TypeError(`currently we don't support image url: ${content.image_url.url}`);
}
}
};

private convertOAIMessagesToGoogleMessage = (message: OpenAIChatMessage): Content => {
private convertOAIMessagesToGoogleMessage = async (
message: OpenAIChatMessage,
): Promise<Content> => {
const content = message.content as string | UserMessageContentPart[];

return {
parts:
typeof content === 'string'
? [{ text: content }]
: content.map((c) => this.convertContentToGooglePart(c)),
: await Promise.all(content.map(async (c) => await this.convertContentToGooglePart(c))),
role: message.role === 'assistant' ? 'model' : 'user',
};
};

// convert messages from the Vercel AI SDK Format to the format
// that is expected by the Google GenAI SDK
private buildGoogleMessages = (messages: OpenAIChatMessage[], model: string): Content[] => {
private buildGoogleMessages = async (
messages: OpenAIChatMessage[],
model: string,
): Promise<Content[]> => {
// if the model is gemini-1.5-pro-latest, we don't need any special handling
if (model === 'gemini-1.5-pro-latest') {
return messages
const pools = messages
.filter((message) => message.role !== 'function')
.map((msg) => this.convertOAIMessagesToGoogleMessage(msg));
.map(async (msg) => await this.convertOAIMessagesToGoogleMessage(msg));

return Promise.all(pools);
}

const contents: Content[] = [];
let lastRole = 'model';

messages.forEach((message) => {
for (const message of messages) {
// current to filter function message
if (message.role === 'function') {
return;
continue;
}
const googleMessage = this.convertOAIMessagesToGoogleMessage(message);
const googleMessage = await this.convertOAIMessagesToGoogleMessage(message);

// if the last message is a model message and the current message is a model message
// then we need to add a user message to separate them
Expand All @@ -187,7 +198,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {

// update the last role
lastRole = googleMessage.role;
});
}

// if the last message is a user message, then we need to add a model message to separate them
if (lastRole === 'model') {
Expand Down
16 changes: 16 additions & 0 deletions src/utils/imageToBase64.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,19 @@ export const imageToBase64 = ({

return canvas.toDataURL(type);
};

/**
 * Fetch a resource from a URL and return its body as a base64-encoded string
 * (no `data:` URI prefix).
 *
 * Works in both Node (via `Buffer`) and browser (via `btoa`) environments.
 *
 * @param imageUrl - URL of the image to fetch (http(s) or data: URL)
 * @returns base64 encoding of the response body
 * @throws Error when the fetch fails or the server responds with a non-2xx
 *         status; the error is logged before being rethrown
 */
export const imageUrlToBase64 = async (imageUrl: string): Promise<string> => {
  try {
    const res = await fetch(imageUrl);

    // Without this check a 404/500 error page would be silently encoded as
    // if it were valid image data.
    if (!res.ok) {
      throw new Error(`Failed to fetch image (${res.status} ${res.statusText}): ${imageUrl}`);
    }

    const arrayBuffer = await res.arrayBuffer();

    // Prefer Buffer when available (Node): one native call. The btoa branch
    // is the browser fallback; its per-byte string concatenation is slow for
    // large payloads, and in Node `btoa` exists globally, so checking for
    // Buffer first keeps Node on the fast path.
    return typeof Buffer === 'function'
      ? Buffer.from(arrayBuffer).toString('base64')
      : btoa(
          new Uint8Array(arrayBuffer).reduce((data, byte) => data + String.fromCharCode(byte), ''),
        );
  } catch (error) {
    console.error('Error converting image to base64:', error);
    throw error;
  }
};

0 comments on commit eb552d2

Please sign in to comment.