diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts
index 52e6e33689d..6bdf7fa044c 100644
--- a/src/api/providers/__tests__/xai.spec.ts
+++ b/src/api/providers/__tests__/xai.spec.ts
@@ -495,5 +495,87 @@ describe("XAIHandler", () => {
 			}),
 		)
 	})
+
+	it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
+		// Import NativeToolCallParser to set up state
+		const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser")
+
+		// Clear any previous state
+		NativeToolCallParser.clearRawChunkState()
+
+		const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					next: vi
+						.fn()
+						.mockResolvedValueOnce({
+							done: false,
+							value: {
+								choices: [
+									{
+										delta: {
+											tool_calls: [
+												{
+													index: 0,
+													id: "call_xai_test",
+													function: {
+														name: "test_tool",
+														arguments: '{"arg1":"value"}',
+													},
+												},
+											],
+										},
+									},
+								],
+							},
+						})
+						.mockResolvedValueOnce({
+							done: false,
+							value: {
+								choices: [
+									{
+										delta: {},
+										finish_reason: "tool_calls",
+									},
+								],
+								usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+							},
+						})
+						.mockResolvedValueOnce({ done: true }),
+				}),
+			}
+		})
+
+		const stream = handlerWithTools.createMessage("test prompt", [], {
+			taskId: "test-task-id",
+			tools: testTools,
+			toolProtocol: "native",
+		})
+
+		const chunks = []
+		for await (const chunk of stream) {
+			// Simulate what Task.ts does: when we receive tool_call_partial,
+			// process it through NativeToolCallParser to populate rawChunkTracker
+			if (chunk.type === "tool_call_partial") {
+				NativeToolCallParser.processRawChunk({
+					index: chunk.index,
+					id: chunk.id,
+					name: chunk.name,
+					arguments: chunk.arguments,
+				})
+			}
+			chunks.push(chunk)
+		}
+
+		// Should have tool_call_partial and tool_call_end
+		const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+		const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+
+		expect(partialChunks).toHaveLength(1)
+		expect(endChunks).toHaveLength(1)
+		expect(endChunks[0].id).toBe("call_xai_test")
+	})
 })
 })
diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts
index 50256c03a0d..36c1ab17dcb 100644
--- a/src/api/providers/xai.ts
+++ b/src/api/providers/xai.ts
@@ -3,6 +3,7 @@ import OpenAI from "openai"
 
 import { type XAIModelId, xaiDefaultModelId, xaiModels } from "@roo-code/types"
 
+import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
 import type { ApiHandlerOptions } from "../../shared/api"
 
 import { ApiStream } from "../transform/stream"
@@ -83,6 +84,7 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
 
 		for await (const chunk of stream) {
 			const delta = chunk.choices[0]?.delta
+			const finishReason = chunk.choices[0]?.finish_reason
 
 			if (delta?.content) {
 				yield {
@@ -111,6 +113,15 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
 				}
 			}
 
+			// Process finish_reason to emit tool_call_end events
+			// This ensures tool calls are finalized even if the stream doesn't properly close
+			if (finishReason) {
+				const endEvents = NativeToolCallParser.processFinishReason(finishReason)
+				for (const event of endEvents) {
+					yield event
+				}
+			}
+
 			if (chunk.usage) {
 				// Extract detailed token information if available
 				// First check for prompt_tokens_details structure (real API response)