Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/api/providers/__tests__/openai.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ describe("OpenAiHandler", () => {
name: undefined,
arguments: '"value"}',
})

// Verify tool_call_end event is emitted when finish_reason is "tool_calls"
const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
expect(toolCallEndChunks).toHaveLength(1)
})

it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => {
Expand Down Expand Up @@ -855,6 +859,10 @@ describe("OpenAiHandler", () => {
name: undefined,
arguments: "{}",
})

// Verify tool_call_end event is emitted when finish_reason is "tool_calls"
const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
expect(toolCallEndChunks).toHaveLength(1)
})

it("should yield tool calls for O3 model even when finish_reason is not set (fallback behavior)", async () => {
Expand Down
70 changes: 47 additions & 23 deletions src/api/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
)

let lastUsage
const activeToolCallIds = new Set<string>()

for await (const chunk of stream) {
const delta = chunk.choices?.[0]?.delta ?? {}
const finishReason = chunk.choices?.[0]?.finish_reason

if (delta.content) {
for (const chunk of matcher.update(delta.content)) {
Expand All @@ -211,17 +213,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
}
}

if (delta.tool_calls) {
for (const toolCall of delta.tool_calls) {
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
}
yield* this.processToolCalls(delta, finishReason, activeToolCallIds)

if (chunk.usage) {
lastUsage = chunk.usage
Expand Down Expand Up @@ -443,8 +435,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
}

private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
const activeToolCallIds = new Set<string>()

for await (const chunk of stream) {
const delta = chunk.choices?.[0]?.delta
const finishReason = chunk.choices?.[0]?.finish_reason

if (delta) {
if (delta.content) {
Expand All @@ -454,18 +449,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
}
}

// Emit raw tool call chunks - NativeToolCallParser handles state management
if (delta.tool_calls) {
for (const toolCall of delta.tool_calls) {
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
}
yield* this.processToolCalls(delta, finishReason, activeToolCallIds)
}

if (chunk.usage) {
Expand All @@ -478,6 +462,46 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
}
}

/**
* Helper generator to process tool calls from a stream chunk.
* Tracks active tool call IDs and yields tool_call_partial and tool_call_end events.
* @param delta - The delta object from the stream chunk
* @param finishReason - The finish_reason from the stream chunk
* @param activeToolCallIds - Set to track active tool call IDs (mutated in place)
*/
private *processToolCalls(
delta: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta | undefined,
finishReason: string | null | undefined,
activeToolCallIds: Set<string>,
): Generator<
| { type: "tool_call_partial"; index: number; id?: string; name?: string; arguments?: string }
| { type: "tool_call_end"; id: string }
> {
if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
if (toolCall.id) {
activeToolCallIds.add(toolCall.id)
}
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
}

// Emit tool_call_end events when finish_reason is "tool_calls"
// This ensures tool calls are finalized even if the stream doesn't properly close
if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
for (const id of activeToolCallIds) {
yield { type: "tool_call_end", id }
}
activeToolCallIds.clear()
}
}

protected _getUrlHost(baseUrl?: string): string {
try {
return new URL(baseUrl ?? "").host
Expand Down
Loading