🐛 fix: fix tools calling in some edge cases (lobehub#3872)
* fix tools calling edge case

arvinxx authored Sep 10, 2024
1 parent 319fd75 commit 2ed759d
Showing 2 changed files with 219 additions and 107 deletions.
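
In short, the fix hardens `transformOpenAIStream` against two edge cases: chunks that carry an empty `content` string alongside `tool_calls` (previously misrouted as a text event), and follow-up tool-call deltas that omit `id`, as LiteLLM emits. The diff threads a `StreamStack` through the transformer to cache the first delta's id. A minimal sketch of that idea, with simplified stand-in shapes assumed for illustration (not the repo's exact types):

```ts
// Sketch only: simplified stand-ins for the repo's StreamStack and delta types.
interface ToolCallDelta {
  id?: string;
  index?: number;
  type?: 'function';
  function?: { name?: string; arguments?: string };
}

interface StreamStack {
  id: string;
  tool?: { id: string; index?: number; name: string };
}

// First delta: cache the tool call's id on the stack. Later deltas
// (which may carry only `arguments`, as LiteLLM does) reuse the cached id.
const resolveToolCallId = (value: ToolCallDelta, index: number, stack: StreamStack): string => {
  if (!stack.tool && value.id) {
    stack.tool = { id: value.id, index: value.index, name: value.function?.name ?? '' };
  }
  // Last fallback is illustrative; the real code calls generateToolCallId.
  return value.id || stack.tool?.id || `hypothetical_fallback_${index}`;
};
```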
300 changes: 201 additions & 99 deletions src/libs/agent-runtime/utils/streams/openai.test.ts
@@ -78,60 +78,6 @@ describe('OpenAIStream', () => {
expect(onCompletionMock).toHaveBeenCalledTimes(1);
});

it('should handle tool calls', async () => {
const mockOpenAIStream = new ReadableStream({
start(controller) {
controller.enqueue({
choices: [
{
delta: {
tool_calls: [
{
function: { name: 'tool1', arguments: '{}' },
id: 'call_1',
index: 0,
type: 'function',
},
{
function: { name: 'tool2', arguments: '{}' },
id: 'call_2',
index: 1,
},
],
},
index: 0,
},
],
id: '2',
});

controller.close();
},
});

const onToolCallMock = vi.fn();

const protocolStream = OpenAIStream(mockOpenAIStream, {
onToolCall: onToolCallMock,
});

const decoder = new TextDecoder();
const chunks = [];

// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}

expect(chunks).toEqual([
'id: 2\n',
'event: tool_calls\n',
`data: [{"function":{"name":"tool1","arguments":"{}"},"id":"call_1","index":0,"type":"function"},{"function":{"name":"tool2","arguments":"{}"},"id":"call_2","index":1,"type":"function"}]\n\n`,
]);

expect(onToolCallMock).toHaveBeenCalledTimes(1);
});

it('should handle empty stream', async () => {
const mockStream = new ReadableStream({
start(controller) {
@@ -216,51 +162,6 @@ describe('OpenAIStream', () => {
]);
});

it('should handle tool calls without index and type', async () => {
const mockOpenAIStream = new ReadableStream({
start(controller) {
controller.enqueue({
choices: [
{
delta: {
tool_calls: [
{
function: { name: 'tool1', arguments: '{}' },
id: 'call_1',
},
{
function: { name: 'tool2', arguments: '{}' },
id: 'call_2',
},
],
},
index: 0,
},
],
id: '5',
});

controller.close();
},
});

const protocolStream = OpenAIStream(mockOpenAIStream);

const decoder = new TextDecoder();
const chunks = [];

// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}

expect(chunks).toEqual([
'id: 5\n',
'event: tool_calls\n',
`data: [{"function":{"name":"tool1","arguments":"{}"},"id":"call_1","index":0,"type":"function"},{"function":{"name":"tool2","arguments":"{}"},"id":"call_2","index":1,"type":"function"}]\n\n`,
]);
});

it('should handle error when there is no correct error', async () => {
const mockOpenAIStream = new ReadableStream({
start(controller) {
@@ -302,4 +203,205 @@
].map((i) => `${i}\n`),
);
});

describe('Tools Calling', () => {
it('should handle OpenAI official tool calls', async () => {
const mockOpenAIStream = new ReadableStream({
start(controller) {
controller.enqueue({
choices: [
{
delta: {
tool_calls: [
{
function: { name: 'tool1', arguments: '{}' },
id: 'call_1',
index: 0,
type: 'function',
},
{
function: { name: 'tool2', arguments: '{}' },
id: 'call_2',
index: 1,
},
],
},
index: 0,
},
],
id: '2',
});

controller.close();
},
});

const onToolCallMock = vi.fn();

const protocolStream = OpenAIStream(mockOpenAIStream, {
onToolCall: onToolCallMock,
});

const decoder = new TextDecoder();
const chunks = [];

// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}

expect(chunks).toEqual([
'id: 2\n',
'event: tool_calls\n',
`data: [{"function":{"name":"tool1","arguments":"{}"},"id":"call_1","index":0,"type":"function"},{"function":{"name":"tool2","arguments":"{}"},"id":"call_2","index":1,"type":"function"}]\n\n`,
]);

expect(onToolCallMock).toHaveBeenCalledTimes(1);
});

it('should handle tool calls without index and type like mistral and minimax', async () => {
const mockOpenAIStream = new ReadableStream({
start(controller) {
controller.enqueue({
choices: [
{
delta: {
tool_calls: [
{
function: { name: 'tool1', arguments: '{}' },
id: 'call_1',
},
{
function: { name: 'tool2', arguments: '{}' },
id: 'call_2',
},
],
},
index: 0,
},
],
id: '5',
});

controller.close();
},
});

const protocolStream = OpenAIStream(mockOpenAIStream);

const decoder = new TextDecoder();
const chunks = [];

// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}

expect(chunks).toEqual([
'id: 5\n',
'event: tool_calls\n',
`data: [{"function":{"name":"tool1","arguments":"{}"},"id":"call_1","index":0,"type":"function"},{"function":{"name":"tool2","arguments":"{}"},"id":"call_2","index":1,"type":"function"}]\n\n`,
]);
});

it('should handle LiteLLM tools Calling', async () => {
const streamData = [
{
id: '1',
choices: [{ index: 0, delta: { content: '为了获取杭州的天气情况', role: 'assistant' } }],
},
{
id: '1',
choices: [{ index: 0, delta: { content: '让我为您查询一下。' } }],
},
{
id: '1',
choices: [
{
index: 0,
delta: {
content: '',
tool_calls: [
{
id: 'toolu_01VQtK4W9kqxGGLHgsPPxiBj',
function: { arguments: '', name: 'realtime-weather____fetchCurrentWeather' },
type: 'function',
index: 0,
},
],
},
},
],
},
{
id: '1',
choices: [
{
index: 0,
delta: {
content: '',
tool_calls: [
{
function: { arguments: '{"city": "\u676d\u5dde"}' },
type: 'function',
index: 0,
},
],
},
},
],
},
{
id: '1',
choices: [{ finish_reason: 'tool_calls', index: 0, delta: {} }],
},
];

const mockOpenAIStream = new ReadableStream({
start(controller) {
streamData.forEach((data) => {
controller.enqueue(data);
});

controller.close();
},
});

const onToolCallMock = vi.fn();

const protocolStream = OpenAIStream(mockOpenAIStream, {
onToolCall: onToolCallMock,
});

const decoder = new TextDecoder();
const chunks = [];

// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}

expect(chunks).toEqual(
[
'id: 1',
'event: text',
`data: "为了获取杭州的天气情况"\n`,
'id: 1',
'event: text',
`data: "让我为您查询一下。"\n`,
'id: 1',
'event: tool_calls',
`data: [{"function":{"arguments":"","name":"realtime-weather____fetchCurrentWeather"},"id":"toolu_01VQtK4W9kqxGGLHgsPPxiBj","index":0,"type":"function"}]\n`,
'id: 1',
'event: tool_calls',
`data: [{"function":{"arguments":"{\\"city\\": \\"杭州\\"}"},"id":"toolu_01VQtK4W9kqxGGLHgsPPxiBj","index":0,"type":"function"}]\n`,
'id: 1',
'event: stop',
`data: "tool_calls"\n`,
].map((i) => `${i}\n`),
);

expect(onToolCallMock).toHaveBeenCalledTimes(2);
});
});
});
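
The LiteLLM test above exercises the crux of the fix: the second `tool_calls` delta carries only `arguments`, yet both emitted events reuse the cached id `toolu_01VQtK4W9kqxGGLHgsPPxiBj`. For context, a consumer of these events typically merges deltas by `index` into complete calls, concatenating argument fragments — a rough sketch under assumed shapes, not code from this repository:

```ts
interface ToolCallChunk {
  id: string;
  index: number;
  type: string;
  function?: { name?: string; arguments?: string };
}

// Fold a sequence of tool_calls events into complete calls, keyed by index.
const mergeToolCallChunks = (events: ToolCallChunk[][]): ToolCallChunk[] => {
  const calls = new Map<number, ToolCallChunk>();
  for (const event of events) {
    for (const chunk of event) {
      const existing = calls.get(chunk.index);
      if (!existing) {
        calls.set(chunk.index, { ...chunk, function: { ...chunk.function } });
      } else if (chunk.function?.arguments) {
        // Argument JSON arrives in fragments; concatenate in order.
        existing.function = {
          ...existing.function,
          arguments: (existing.function?.arguments ?? '') + chunk.function.arguments,
        };
      }
    }
  }
  return [...calls.values()];
};
```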
26 changes: 18 additions & 8 deletions src/libs/agent-runtime/utils/streams/openai.ts
@@ -8,13 +8,17 @@ import { ChatStreamCallbacks } from '../../types';
import {
StreamProtocolChunk,
StreamProtocolToolCallChunk,
StreamStack,
StreamToolCallChunkData,
createCallbacksTransformer,
createSSEProtocolTransformer,
generateToolCallId,
} from './protocol';

export const transformOpenAIStream = (chunk: OpenAI.ChatCompletionChunk): StreamProtocolChunk => {
export const transformOpenAIStream = (
chunk: OpenAI.ChatCompletionChunk,
stack?: StreamStack,
): StreamProtocolChunk => {
// may need a different structure to support multiple choices

try {
@@ -23,16 +27,20 @@ export const transformOpenAIStream = (chunk: OpenAI.ChatCompletionChunk): Stream
return { data: chunk, id: chunk.id, type: 'data' };
}

if (typeof item.delta?.content === 'string') {
if (typeof item.delta?.content === 'string' && !item.finish_reason && !item.delta?.tool_calls) {
return { data: item.delta.content, id: chunk.id, type: 'text' };
}

if (item.delta?.tool_calls) {
return {
data: item.delta.tool_calls.map(
(value, index): StreamToolCallChunkData => ({
data: item.delta.tool_calls.map((value, index): StreamToolCallChunkData => {
if (stack && !stack.tool) {
stack.tool = { id: value.id!, index: value.index, name: value.function!.name! };
}

return {
function: value.function,
id: value.id || generateToolCallId(index, value.function?.name),
id: value.id || stack?.tool?.id || generateToolCallId(index, value.function?.name),

// mistral's tool calling doesn't include the index and type fields; its data looks like:
// [{"id":"xbhnmTtY7","function":{"name":"lobe-image-designer____text2image____builtin","arguments":"{\"prompts\": [\"A photo of a small, fluffy dog with a playful expression and wagging tail.\", \"A watercolor painting of a small, energetic dog with a glossy coat and bright eyes.\", \"A vector illustration of a small, adorable dog with a short snout and perky ears.\", \"A drawing of a small, scruffy dog with a mischievous grin and a wagging tail.\"], \"quality\": \"standard\", \"seeds\": [123456, 654321, 111222, 333444], \"size\": \"1024x1024\", \"style\": \"vivid\"}"}}]
@@ -43,8 +51,8 @@ export const transformOpenAIStream = (chunk: OpenAI.ChatCompletionChunk): Stream
// so we need to add these default values
index: typeof value.index !== 'undefined' ? value.index : index,
type: value.type || 'function',
}),
),
};
}),
id: chunk.id,
type: 'tool_calls',
} as StreamProtocolToolCallChunk;
@@ -97,10 +105,12 @@ export const OpenAIStream = (
stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
callbacks?: ChatStreamCallbacks,
) => {
const streamStack: StreamStack = { id: '' };

const readableStream =
stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));

return readableStream
.pipeThrough(createSSEProtocolTransformer(transformOpenAIStream))
.pipeThrough(createSSEProtocolTransformer(transformOpenAIStream, streamStack))
.pipeThrough(createCallbacksTransformer(callbacks));
};
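
`createSSEProtocolTransformer` and `generateToolCallId` live in `./protocol` and are not part of this diff. Judging from the assertions in the tests above, the transformer plausibly maps each upstream chunk through the transform function — now with the shared `streamStack` as a second argument — and emits the `id:` / `event:` / `data:` lines the tests decode. A hedged sketch of that observable behavior, with simplified types:

```ts
// Sketch only: illustrates the observable behavior, not the repo's implementation.
interface StreamProtocolChunk {
  id: string;
  type: string; // e.g. 'text' | 'tool_calls' | 'stop' | 'data'
  data: unknown;
}

const sketchSSEProtocolTransformer = (
  transform: (chunk: unknown, stack?: unknown) => StreamProtocolChunk,
  stack?: unknown,
) => {
  const encoder = new TextEncoder();
  return new TransformStream({
    transform(chunk, controller) {
      const { id, type, data } = transform(chunk, stack);
      // Matches the three chunks per event the tests assert:
      // 'id: …\n', 'event: …\n', 'data: …\n\n'
      controller.enqueue(encoder.encode(`id: ${id}\n`));
      controller.enqueue(encoder.encode(`event: ${type}\n`));
      controller.enqueue(encoder.encode(`data: ${JSON.stringify(data)}\n\n`));
    },
  });
};
```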
