Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/api/providers/__tests__/xai.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -495,5 +495,87 @@ describe("XAIHandler", () => {
}),
)
})

it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
// Import NativeToolCallParser to set up state
const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser")

// Clear any previous state
NativeToolCallParser.clearRawChunkState()

const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: "call_xai_test",
function: {
name: "test_tool",
arguments: '{"arg1":"value"}',
},
},
],
},
},
],
},
})
.mockResolvedValueOnce({
done: false,
value: {
choices: [
{
delta: {},
finish_reason: "tool_calls",
},
],
usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handlerWithTools.createMessage("test prompt", [], {
taskId: "test-task-id",
tools: testTools,
toolProtocol: "native",
})

const chunks = []
for await (const chunk of stream) {
// Simulate what Task.ts does: when we receive tool_call_partial,
// process it through NativeToolCallParser to populate rawChunkTracker
if (chunk.type === "tool_call_partial") {
NativeToolCallParser.processRawChunk({
index: chunk.index,
id: chunk.id,
name: chunk.name,
arguments: chunk.arguments,
})
}
chunks.push(chunk)
}

// Should have tool_call_partial and tool_call_end
const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")

expect(partialChunks).toHaveLength(1)
expect(endChunks).toHaveLength(1)
expect(endChunks[0].id).toBe("call_xai_test")
})
})
})
11 changes: 11 additions & 0 deletions src/api/providers/xai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import OpenAI from "openai"

import { type XAIModelId, xaiDefaultModelId, xaiModels } from "@roo-code/types"

import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
import type { ApiHandlerOptions } from "../../shared/api"

import { ApiStream } from "../transform/stream"
Expand Down Expand Up @@ -83,6 +84,7 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta
const finishReason = chunk.choices[0]?.finish_reason

if (delta?.content) {
yield {
Expand Down Expand Up @@ -111,6 +113,15 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
}
}

// Process finish_reason to emit tool_call_end events
// This ensures tool calls are finalized even if the stream doesn't properly close
if (finishReason) {
const endEvents = NativeToolCallParser.processFinishReason(finishReason)
for (const event of endEvents) {
yield event
}
}

if (chunk.usage) {
// Extract detailed token information if available
// First check for prompt_tokens_details structure (real API response)
Expand Down
Loading