diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts index 31fdaa2389..4469efd4d1 100644 --- a/src/api/providers/__tests__/openai.spec.ts +++ b/src/api/providers/__tests__/openai.spec.ts @@ -295,6 +295,10 @@ describe("OpenAiHandler", () => { name: undefined, arguments: '"value"}', }) + + // Verify tool_call_end event is emitted when finish_reason is "tool_calls" + const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") + expect(toolCallEndChunks).toHaveLength(1) }) it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => { @@ -855,6 +859,10 @@ describe("OpenAiHandler", () => { name: undefined, arguments: "{}", }) + + // Verify tool_call_end event is emitted when finish_reason is "tool_calls" + const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") + expect(toolCallEndChunks).toHaveLength(1) }) it("should yield tool calls for O3 model even when finish_reason is not set (fallback behavior)", async () => { diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index b198fe11d3..d6f50d0269 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -194,9 +194,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl ) let lastUsage + const activeToolCallIds = new Set() for await (const chunk of stream) { const delta = chunk.choices?.[0]?.delta ?? {} + const finishReason = chunk.choices?.[0]?.finish_reason if (delta.content) { for (const chunk of matcher.update(delta.content)) { @@ -211,17 +213,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } } - if (delta.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } + yield* this.processToolCalls(delta, finishReason, activeToolCallIds) if (chunk.usage) { lastUsage = chunk.usage @@ -443,8 +435,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } private async *handleStreamResponse(stream: AsyncIterable): ApiStream { + const activeToolCallIds = new Set() + for await (const chunk of stream) { const delta = chunk.choices?.[0]?.delta + const finishReason = chunk.choices?.[0]?.finish_reason if (delta) { if (delta.content) { @@ -454,18 +449,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } } - // Emit raw tool call chunks - NativeToolCallParser handles state management - if (delta.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } + yield* this.processToolCalls(delta, finishReason, activeToolCallIds) } if (chunk.usage) { @@ -478,6 +462,46 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } } + /** + * Helper generator to process tool calls from a stream chunk. + * Tracks active tool call IDs and yields tool_call_partial and tool_call_end events. + * @param delta - The delta object from the stream chunk + * @param finishReason - The finish_reason from the stream chunk + * @param activeToolCallIds - Set to track active tool call IDs (mutated in place) + */ + private *processToolCalls( + delta: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta | undefined, + finishReason: string | null | undefined, + activeToolCallIds: Set, + ): Generator< + | { type: "tool_call_partial"; index: number; id?: string; name?: string; arguments?: string } + | { type: "tool_call_end"; id: string } + > { + if (delta?.tool_calls) { + for (const toolCall of delta.tool_calls) { + if (toolCall.id) { + activeToolCallIds.add(toolCall.id) + } + yield { + type: "tool_call_partial", + index: toolCall.index, + id: toolCall.id, + name: toolCall.function?.name, + arguments: toolCall.function?.arguments, + } + } + } + + // Emit tool_call_end events when finish_reason is "tool_calls" + // This ensures tool calls are finalized even if the stream doesn't properly close + if (finishReason === "tool_calls" && activeToolCallIds.size > 0) { + for (const id of activeToolCallIds) { + yield { type: "tool_call_end", id } + } + activeToolCallIds.clear() + } + } + protected _getUrlHost(baseUrl?: string): string { try { return new URL(baseUrl ?? "").host