Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -383,4 +383,166 @@ describe("BaseOpenAiCompatibleProvider", () => {
expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 100, outputTokens: 50 })
})
})

describe("Tool call handling", () => {
	/**
	 * Builds a minimal OpenAI-style streaming response: an object whose
	 * async iterator yields each provided chunk in order, then reports done.
	 * Replaces the triplicated vi.fn().mockResolvedValueOnce(...) chains the
	 * three tests below previously each constructed by hand.
	 */
	const mockStreamOf = (...streamChunks: unknown[]) => ({
		[Symbol.asyncIterator]: () => {
			let index = 0
			return {
				next: () =>
					Promise.resolve(
						index < streamChunks.length
							? { done: false, value: streamChunks[index++] }
							: { done: true },
					),
			}
		},
	})

	// Drains the provider's message stream into an array for assertions.
	const collect = async (stream: AsyncIterable<any>) => {
		const collected: any[] = []
		for await (const chunk of stream) {
			collected.push(chunk)
		}
		return collected
	}

	it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
		mockCreate.mockImplementationOnce(() =>
			mockStreamOf(
				{
					choices: [
						{
							delta: {
								tool_calls: [
									{
										index: 0,
										id: "call_123",
										function: { name: "test_tool", arguments: '{"arg":' },
									},
								],
							},
						},
					],
				},
				{
					choices: [
						{
							delta: {
								tool_calls: [
									{
										index: 0,
										function: { arguments: '"value"}' },
									},
								],
							},
						},
					],
				},
				{
					choices: [
						{
							delta: {},
							finish_reason: "tool_calls",
						},
					],
				},
			),
		)

		const chunks = await collect(handler.createMessage("system prompt", []))

		// Two streamed deltas surface as tool_call_partial; the finish_reason
		// chunk finalizes the single active call with one tool_call_end.
		const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
		const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")

		expect(partialChunks).toHaveLength(2)
		expect(endChunks).toHaveLength(1)
		expect(endChunks[0]).toEqual({ type: "tool_call_end", id: "call_123" })
	})

	it("should yield multiple tool_call_end events for parallel tool calls", async () => {
		mockCreate.mockImplementationOnce(() =>
			mockStreamOf(
				{
					choices: [
						{
							delta: {
								tool_calls: [
									{
										index: 0,
										id: "call_001",
										function: { name: "tool_a", arguments: "{}" },
									},
									{
										index: 1,
										id: "call_002",
										function: { name: "tool_b", arguments: "{}" },
									},
								],
							},
						},
					],
				},
				{
					choices: [
						{
							delta: {},
							finish_reason: "tool_calls",
						},
					],
				},
			),
		)

		const chunks = await collect(handler.createMessage("system prompt", []))

		// Every active tool call id must be finalized when the stream finishes.
		const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
		expect(endChunks).toHaveLength(2)
		expect(endChunks.map((c: any) => c.id).sort()).toEqual(["call_001", "call_002"])
	})

	it("should not yield tool_call_end when finish_reason is not tool_calls", async () => {
		mockCreate.mockImplementationOnce(() =>
			mockStreamOf({
				choices: [
					{
						delta: { content: "Some text response" },
						finish_reason: "stop",
					},
				],
			}),
		)

		const chunks = await collect(handler.createMessage("system prompt", []))

		// A plain "stop" completion must not emit any tool finalization events.
		const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
		expect(endChunks).toHaveLength(0)
	})
})
})
14 changes: 14 additions & 0 deletions src/api/providers/base-openai-compatible-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
)

let lastUsage: OpenAI.CompletionUsage | undefined
const activeToolCallIds = new Set<string>()

for await (const chunk of stream) {
// Check for provider-specific error responses (e.g., MiniMax base_resp)
Expand All @@ -140,6 +141,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}

const delta = chunk.choices?.[0]?.delta
const finishReason = chunk.choices?.[0]?.finish_reason

if (delta?.content) {
for (const processedChunk of matcher.update(delta.content)) {
Expand All @@ -162,6 +164,9 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
// Emit raw tool call chunks - NativeToolCallParser handles state management
if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
if (toolCall.id) {
activeToolCallIds.add(toolCall.id)
}
yield {
type: "tool_call_partial",
index: toolCall.index,
Expand All @@ -172,6 +177,15 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}
}

// Emit tool_call_end events when finish_reason is "tool_calls"
// This ensures tool calls are finalized even if the stream doesn't properly close
if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
for (const id of activeToolCallIds) {
yield { type: "tool_call_end", id }
}
activeToolCallIds.clear()
}

if (chunk.usage) {
lastUsage = chunk.usage
}
Expand Down
Loading