diff --git a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
index 57ee649d6e..7d0d2548fc 100644
--- a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
+++ b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
@@ -383,4 +383,166 @@ describe("BaseOpenAiCompatibleProvider", () => {
 			expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 100, outputTokens: 50 })
 		})
 	})
+
+	describe("Tool call handling", () => {
+		it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														id: "call_123",
+														function: { name: "test_tool", arguments: '{"arg":' },
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														function: { arguments: '"value"}' },
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {},
+											finish_reason: "tool_calls",
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Should have tool_call_partial and tool_call_end
+			const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+
+			expect(partialChunks).toHaveLength(2)
+			expect(endChunks).toHaveLength(1)
+			expect(endChunks[0]).toEqual({ type: "tool_call_end", id: "call_123" })
+		})
+
+		it("should yield multiple tool_call_end events for parallel tool calls", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														id: "call_001",
+														function: { name: "tool_a", arguments: "{}" },
+													},
+													{
+														index: 1,
+														id: "call_002",
+														function: { name: "tool_b", arguments: "{}" },
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {},
+											finish_reason: "tool_calls",
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+			expect(endChunks).toHaveLength(2)
+			expect(endChunks.map((c: any) => c.id).sort()).toEqual(["call_001", "call_002"])
+		})
+
+		it("should not yield tool_call_end when finish_reason is not tool_calls", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: { content: "Some text response" },
+											finish_reason: "stop",
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+			expect(endChunks).toHaveLength(0)
+		})
+	})
 })
diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts
index 92b9558c45..5aee7267b3 100644
--- a/src/api/providers/base-openai-compatible-provider.ts
+++ b/src/api/providers/base-openai-compatible-provider.ts
@@ -129,6 +129,7 @@ export abstract class BaseOpenAiCompatibleProvider
 		)
 
 		let lastUsage: OpenAI.CompletionUsage | undefined
+		const activeToolCallIds = new Set<string>()
 
 		for await (const chunk of stream) {
 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
@@ -140,6 +141,7 @@ export abstract class BaseOpenAiCompatibleProvider
 			}
 
 			const delta = chunk.choices?.[0]?.delta
+			const finishReason = chunk.choices?.[0]?.finish_reason
 
 			if (delta?.content) {
 				for (const processedChunk of matcher.update(delta.content)) {
@@ -162,6 +164,9 @@ export abstract class BaseOpenAiCompatibleProvider
 			// Emit raw tool call chunks - NativeToolCallParser handles state management
 			if (delta?.tool_calls) {
 				for (const toolCall of delta.tool_calls) {
+					if (toolCall.id) {
+						activeToolCallIds.add(toolCall.id)
+					}
 					yield {
 						type: "tool_call_partial",
 						index: toolCall.index,
@@ -172,6 +177,15 @@ export abstract class BaseOpenAiCompatibleProvider
 				}
 			}
 
+			// Emit tool_call_end events when finish_reason is "tool_calls"
+			// This ensures tool calls are finalized even if the stream doesn't properly close
+			if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
+				for (const id of activeToolCallIds) {
+					yield { type: "tool_call_end", id }
+				}
+				activeToolCallIds.clear()
+			}
+
 			if (chunk.usage) {
 				lastUsage = chunk.usage
 			}