diff --git a/packages/types/src/providers/xai.ts b/packages/types/src/providers/xai.ts index 37e0f2d12e0..2954888d733 100644 --- a/packages/types/src/providers/xai.ts +++ b/packages/types/src/providers/xai.ts @@ -30,6 +30,8 @@ export const xaiModels = { cacheReadsPrice: 0.05, description: "xAI's Grok 4.1 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning", + supportsReasoningEffort: ["low", "high"], + reasoningEffort: "low", includedTools: ["search_replace"], excludedTools: ["apply_diff"], }, @@ -58,6 +60,8 @@ export const xaiModels = { cacheReadsPrice: 0.05, description: "xAI's Grok 4 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning", + supportsReasoningEffort: ["low", "high"], + reasoningEffort: "low", includedTools: ["search_replace"], excludedTools: ["apply_diff"], }, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4cf3a3627e0..417e69a07ad 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -761,6 +761,9 @@ importers: '@ai-sdk/mistral': specifier: ^3.0.0 version: 3.0.18(zod@3.25.76) + '@ai-sdk/xai': + specifier: ^3.0.46 + version: 3.0.46(zod@3.25.76) '@anthropic-ai/bedrock-sdk': specifier: ^0.10.2 version: 0.10.4 @@ -1462,6 +1465,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.26': + resolution: {integrity: sha512-l6jdFjI1C2eDAEm7oo+dnRn0oG1EkcyqfbEZ7ozT0TnYrah6amX2JkftYMP1GRzNtAeCB3WNN8XspXdmi6ZNlQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20': resolution: {integrity: sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==} engines: {node: '>=18'} @@ -1512,6 +1521,12 @@ packages: resolution: {integrity: sha512-VkPLrutM6VdA924/mG8OS+5frbVTcu6e046D2bgDo00tehBANR1QBJ/mPcZ9tXMFOsVcm6SQArOregxePzTFPw==} engines: {node: '>=18'} + '@ai-sdk/xai@3.0.46': + resolution: {integrity: sha512-26qM/jYcFhF5krTM7bQT1CiZcdz22EQmA+r5me1hKYFM/yM20sSUMHnAcUzvzuuG9oQVKF0tziU2IcC0HX5huQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@alcalzone/ansi-tokenize@0.2.3': resolution: {integrity: sha512-jsElTJ0sQ4wHRz+C45tfect76BwbTbgkgKByOzpCN9xG61N5V6u/glvg1CsNJhq2xJIFpKHSwG3D2wPPuEYOrQ==} engines: {node: '>=18'} @@ -6530,10 +6545,6 @@ packages: resolution: {integrity: sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==} engines: {node: '>=0.8.x'} - eventsource-parser@3.0.2: - resolution: {integrity: sha512-6RxOBZ/cYgd8usLwsEl+EC09Au/9BcmCKYF2/xbml6DNczf7nv0MQb+7BA2F+li6//I+28VNlQR37XfQtcAJuA==} - engines: {node: '>=18.0.0'} - eventsource-parser@3.0.6: resolution: {integrity: sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==} engines: {node: '>=18.0.0'} @@ -11146,6 +11157,12 @@ snapshots: '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.26(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.7 + '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20(zod@3.25.76)': dependencies: '@ai-sdk/provider': 2.0.1 @@ -11202,6 +11219,13 @@ snapshots: dependencies: json-schema: 0.4.0 + '@ai-sdk/xai@3.0.46(zod@3.25.76)': + dependencies: + '@ai-sdk/openai-compatible': 2.0.26(zod@3.25.76) + '@ai-sdk/provider': 3.0.7 + '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) + zod: 3.25.76 + '@alcalzone/ansi-tokenize@0.2.3': dependencies: ansi-styles: 6.2.3 @@ -17027,13 +17051,11 @@ snapshots: events@3.3.0: {} - 
eventsource-parser@3.0.2: {}
-
   eventsource-parser@3.0.6: {}
 
   eventsource@3.0.7:
     dependencies:
-      eventsource-parser: 3.0.2
+      eventsource-parser: 3.0.6
 
   exceljs@4.4.0:
     dependencies:
diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts
index c622c9d4fcf..27e0a25f5cc 100644
--- a/src/api/providers/__tests__/xai.spec.ts
+++ b/src/api/providers/__tests__/xai.spec.ts
@@ -1,587 +1,731 @@
-// npx vitest api/providers/__tests__/xai.spec.ts
-
-// Mock TelemetryService - must come before other imports
-const mockCaptureException = vitest.hoisted(() => vitest.fn())
-vitest.mock("@roo-code/telemetry", () => ({
-	TelemetryService: {
-		instance: {
-			captureException: mockCaptureException,
-		},
-	},
-}))
-
-const mockCreate = vitest.fn()
+// npx vitest run api/providers/__tests__/xai.spec.ts
 
-vitest.mock("openai", () => {
-	const mockConstructor = vitest.fn()
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+}))
+
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
 	return {
-		__esModule: true,
-		default: mockConstructor.mockImplementation(() => ({ chat: { completions: { create: mockCreate } } })),
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
 	}
 })
 
-import OpenAI from "openai"
+vi.mock("@ai-sdk/xai", () => ({
+	createXai: vi.fn(() => {
+		// Return a function that returns a mock language model
+		return vi.fn(() => ({
+			modelId: "grok-code-fast-1",
+			provider: "xai",
+		}))
+	}),
+}))
+
 import type { Anthropic } from "@anthropic-ai/sdk"
 
-import { xaiDefaultModelId, xaiModels } from "@roo-code/types"
+import { xaiDefaultModelId, xaiModels, type XAIModelId } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../../shared/api"
 
 import { XAIHandler } from "../xai"
 
 describe("XAIHandler", () => {
 	let handler: XAIHandler
+	let mockOptions: ApiHandlerOptions
 
 	beforeEach(() => {
-		// Reset all mocks
+		mockOptions = {
+			xaiApiKey: "test-xai-api-key",
+			apiModelId: "grok-code-fast-1",
+		}
+		handler = new XAIHandler(mockOptions)
 		vi.clearAllMocks()
-		mockCreate.mockClear()
-		mockCaptureException.mockClear()
-
-		// Create handler with mock
-		handler = new XAIHandler({})
 	})
 
-	it("should use the correct X.AI base URL", () => {
-		expect(OpenAI).toHaveBeenCalledWith(
-			expect.objectContaining({
-				baseURL: "https://api.x.ai/v1",
-			}),
-		)
-	})
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(XAIHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
+		})
 
-	it("should use the provided API key", () => {
-		// Clear mocks before this specific test
-		vi.clearAllMocks()
+		it("should use default model ID if not provided", () => {
+			const handlerWithoutModel = new XAIHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			expect(handlerWithoutModel.getModel().id).toBe(xaiDefaultModelId)
+		})
+	})
 
-		// Create a handler with our API key
-		const xaiApiKey = "test-api-key"
-		new XAIHandler({ xaiApiKey })
+	describe("getModel", () => {
+		it("should return default model when no model is specified", () => {
+			const handlerWithoutModel = new XAIHandler({
+				xaiApiKey: "test-xai-api-key",
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBe(xaiDefaultModelId)
+			expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
+		})
 
-		// Verify the OpenAI constructor
was called with our API key - expect(OpenAI).toHaveBeenCalledWith( - expect.objectContaining({ - apiKey: xaiApiKey, - }), - ) - }) + it("should return specified model when valid model is provided", () => { + const testModelId: XAIModelId = "grok-3" + const handlerWithModel = new XAIHandler({ + apiModelId: testModelId, + xaiApiKey: "test-xai-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual(xaiModels[testModelId]) + }) - it("should return default model when no model is specified", () => { - const model = handler.getModel() - expect(model.id).toBe(xaiDefaultModelId) - expect(model.info).toEqual(xaiModels[xaiDefaultModelId]) - }) + it("should return grok-3-mini model with correct configuration", () => { + const testModelId: XAIModelId = "grok-3-mini" + const handlerWithModel = new XAIHandler({ + apiModelId: testModelId, + xaiApiKey: "test-xai-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 8192, + contextWindow: 131072, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 0.3, + outputPrice: 0.5, + }), + ) + }) - test("should return specified model when valid model is provided", () => { - const testModelId = "grok-3" - const handlerWithModel = new XAIHandler({ apiModelId: testModelId }) - const model = handlerWithModel.getModel() + it("should return grok-4-0709 model with correct configuration", () => { + const testModelId: XAIModelId = "grok-4-0709" + const handlerWithModel = new XAIHandler({ + apiModelId: testModelId, + xaiApiKey: "test-xai-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 8192, + contextWindow: 256_000, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 3.0, + outputPrice: 15.0, + }), + ) + }) - expect(model.id).toBe(testModelId) - expect(model.info).toEqual(xaiModels[testModelId]) - }) + it("should fall back to default model for invalid model ID", () => { + const handlerWithInvalidModel = new XAIHandler({ + ...mockOptions, + apiModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe(xaiDefaultModelId) + expect(model.info).toBe(xaiModels[xaiDefaultModelId]) + }) - it("should include reasoning_effort parameter for mini models", async () => { - const miniModelHandler = new XAIHandler({ - apiModelId: "grok-3-mini", - reasoningEffort: "high", + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") }) + }) - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." 
+ const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", }, - }), + ], + }, + ] + + it("should handle streaming responses", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response from xAI" } } - }) - // Start generating a message - const messageGenerator = miniModelHandler.createMessage("test prompt", []) - await messageGenerator.next() // Start the generator + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) - // Check that reasoning_effort was included - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - reasoning_effort: "high", - }), - ) - }) + const mockProviderMetadata = Promise.resolve({}) - it("should not include reasoning_effort parameter for non-mini models", async () => { - const regularModelHandler = new XAIHandler({ - apiModelId: "grok-3", - reasoningEffort: "high", - }) + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response from xAI") }) - // Start generating a message - const messageGenerator = regularModelHandler.createMessage("test prompt", []) - await messageGenerator.next() // Start the generator + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } - // Check call args for reasoning_effort - const calls = mockCreate.mock.calls - const lastCall = calls[calls.length - 1][0] - expect(lastCall).not.toHaveProperty("reasoning_effort") - }) + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) - it("completePrompt method should return text from OpenAI API", async () => { - const expectedResponse = "This is a test response" - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) + const mockProviderMetadata = Promise.resolve({}) - const result = await handler.completePrompt("test prompt") - expect(result).toBe(expectedResponse) - }) + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) - it("should handle errors in completePrompt", async () => { - const errorMessage = "API error" - mockCreate.mockRejectedValueOnce(new Error(errorMessage)) + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`) - }) + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(20) + }) - it("createMessage should yield text content from stream", async () => { - const testContent = "This is test content" - - // Setup mock for streaming 
response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: { content: testContent } }], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), + it("should handle cached tokens in usage data from providerMetadata", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } } - }) - // Create and consume the stream - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + }) - // Verify the content - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ - type: "text", - text: testContent, - }) - }) + // xAI provides cache metrics via providerMetadata for supported models + const mockProviderMetadata = Promise.resolve({ + xai: { + cachedPromptTokens: 30, + }, + }) - it("createMessage should yield reasoning content from stream", async () => { - const testReasoning = "Test reasoning content" - - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: { reasoning_content: testReasoning } }], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(100) + expect(usageChunks[0].outputTokens).toBe(50) + expect(usageChunks[0].cacheReadTokens).toBe(30) }) - // Create and consume the stream - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should handle usage with details.cachedInputTokens when providerMetadata is not available", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + details: { + cachedInputTokens: 25, + }, + }) + + const mockProviderMetadata = Promise.resolve({}) - // Verify the reasoning content - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ - type: "reasoning", - text: testReasoning, + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].cacheReadTokens).toBe(25) + expect(usageChunks[0].cacheWriteTokens).toBeUndefined() }) - }) - it("createMessage should yield usage data from stream", async () => { - // Setup mock for streaming response that includes usage data - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: {} }], // Needs to have 
choices array to avoid error - usage: { - prompt_tokens: 10, - completion_tokens: 20, - cache_read_input_tokens: 5, - cache_creation_input_tokens: 15, - }, - }, - }) - .mockResolvedValueOnce({ done: true }), + it("should pass correct temperature (0 default) to streamText", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithDefaultTemp = new XAIHandler({ + xaiApiKey: "test-key", + apiModelId: "grok-code-fast-1", + }) + + const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0, }), + ) + }) + + it("should use user-specified temperature over default", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithCustomTemp = new XAIHandler({ + xaiApiKey: "test-key", + apiModelId: "grok-3", + modelTemperature: 0.7, + }) + + const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream } + + // User-specified temperature should take precedence over everything + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.7, + }), + ) }) - // Create and consume the stream - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() - - // Verify the usage data - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ - type: "usage", - inputTokens: 10, - outputTokens: 20, - cacheReadTokens: 5, - cacheWriteTokens: 15, + it("should handle stream with multiple chunks", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: " world" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const textChunks = chunks.filter((c) => c.type === "text") + expect(textChunks[0]).toEqual({ type: "text", text: "Hello" }) + expect(textChunks[1]).toEqual({ type: "text", text: " world" }) + + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 }) }) - }) - it("createMessage should pass correct parameters to OpenAI client", async () => { - // Setup a handler with specific model - const modelId = "grok-3" - const modelInfo = xaiModels[modelId] - const handlerWithModel = new XAIHandler({ apiModelId: modelId }) - - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), + it("should handle reasoning content from stream", async () => { + async function* mockFullStream() { + yield { type: "reasoning-delta", text: "Let me think about this..." 
} + yield { type: "text-delta", text: "Here is my answer" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 20 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } + + const reasoningChunks = chunks.filter((c) => c.type === "reasoning") + expect(reasoningChunks).toHaveLength(1) + expect(reasoningChunks[0].text).toBe("Let me think about this...") + + const textChunks = chunks.filter((c) => c.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Here is my answer") }) - // System prompt and messages - const systemPrompt = "Test system prompt" - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }] - - // Start generating a message - const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages) - await messageGenerator.next() // Start the generator - - // Check that all parameters were passed correctly - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: modelId, - max_tokens: modelInfo.maxTokens, - temperature: 0, - messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]), - stream: true, - stream_options: { include_usage: true }, - }), - ) - }) + it("should handle errors during streaming", async () => { + const mockError = new Error("API error") + ;(mockError as any).name = "AI_APICallError" + ;(mockError as any).status = 500 - describe("Native Tool Calling", () => { - const testTools = [ - { - type: "function" as const, - function: { - name: "test_tool", - description: "A test tool", - parameters: { - type: "object", - properties: { - arg1: { type: "string", description: "First argument" }, - }, - required: ["arg1"], - }, - }, - }, - ] + async function* mockFullStream(): AsyncGenerator { + // This yield is unreachable but needed to satisfy the require-yield lint rule + yield undefined as never + throw mockError + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) - it("should include tools in request when model supports native tools and tools are provided (native is default)", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) + const stream = handler.createMessage(systemPrompt, messages) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), + await expect(async () => { + for await (const _ of stream) { + // consume stream } - }) + }).rejects.toThrow("xAI") + }) + }) - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion from xAI", }) - await messageGenerator.next() - expect(mockCreate).toHaveBeenCalledWith( + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion from xAI") + expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - type: "function", - function: expect.objectContaining({ - name: "test_tool", - }), - }), - 
]), - parallel_tool_calls: true, + prompt: "Test prompt", }), ) }) - it("should include tool_choice when provided", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } + it("should use default temperature in completePrompt", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", }) - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - tool_choice: "auto", - }) - await messageGenerator.next() + await handler.completePrompt("Test prompt") - expect(mockCreate).toHaveBeenCalledWith( + expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - tool_choice: "auto", + temperature: 0, }), ) }) - it("should always include tools and tool_choice (tools are guaranteed to be present after ALWAYS_AVAILABLE_TOOLS)", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) + it("should handle errors in completePrompt", async () => { + const mockError = new Error("API error") + ;(mockError as any).name = "AI_APICallError" + mockGenerateText.mockRejectedValue(mockError) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), + await expect(handler.completePrompt("Test prompt")).rejects.toThrow("xAI") + }) + }) + + describe("processUsageMetrics", () => { + it("should correctly process usage metrics including cache information from providerMetadata", () => { + class TestXAIHandler extends XAIHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) } - }) + } - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - }) - await messageGenerator.next() + const testHandler = new TestXAIHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const providerMetadata = { + xai: { + cachedPromptTokens: 20, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage, providerMetadata) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheReadTokens).toBe(20) + // xAI doesn't report cache write tokens separately + expect(result.cacheWriteTokens).toBeUndefined() + }) + + it("should handle missing cache metrics gracefully", () => { + class TestXAIHandler extends XAIHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestXAIHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBeUndefined() + expect(result.cacheReadTokens).toBeUndefined() + }) + + it("should include reasoning tokens when provided", () => { + class TestXAIHandler extends XAIHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } - // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - const callArgs = 
mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0] - expect(callArgs).toHaveProperty("tools") - expect(callArgs).toHaveProperty("tool_choice") - expect(callArgs).toHaveProperty("parallel_tool_calls", true) + const testHandler = new TestXAIHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + details: { + reasoningTokens: 30, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.reasoningTokens).toBe(30) }) + }) - it("should yield tool_call_partial chunks during streaming", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_123", - function: { - name: "test_tool", - arguments: '{"arg1":', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - function: { - arguments: '"value"}', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), + describe("tool handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" }], + }, + ] + + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', + } + yield { + type: "tool-input-end", + id: "tool-call-1", } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, }) - const stream = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, + }, + }, + ], }) - const chunks = [] + const chunks: any[] = [] for await (const chunk of stream) { chunks.push(chunk) } - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: "call_123", - name: "test_tool", - arguments: '{"arg1":', + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") + }) + + it("should ignore tool-call events to prevent duplicate tools in UI", async () => { + async function* mockFullStream() { + yield { + type: "tool-call", + toolCallId: "tool-call-1", + toolName: "read_file", + 
input: { path: "test.ts" }, + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, }) - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '"value"}', + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // tool-call events should be ignored (only tool-input-start/delta/end are processed) + const toolCallChunks = chunks.filter( + (c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end", + ) + expect(toolCallChunks.length).toBe(0) }) - it("should set parallel_tool_calls based on metadata", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) + it("should pass tools to streamText when provided", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), }) - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", + const testTools = [ + { + type: "function" as const, + function: { + name: "test_tool", + description: "A test tool", + parameters: { + type: "object", + properties: { + arg1: { type: "string", description: "First argument" }, + }, + required: ["arg1"], + }, + }, + }, + ] + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", tools: testTools, - parallelToolCalls: true, + tool_choice: "auto", }) - await messageGenerator.next() - expect(mockCreate).toHaveBeenCalledWith( + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - parallel_tool_calls: true, + tools: expect.any(Object), + toolChoice: "auto", }), ) }) + }) - it("should yield tool_call_end events when finish_reason is tool_calls", async () => { - // Import NativeToolCallParser to set up state - const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser") - - // Clear any previous state - NativeToolCallParser.clearRawChunkState() - - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_xai_test", - function: { - name: "test_tool", - arguments: '{"arg1":"value"}', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: {}, - finish_reason: "tool_calls", - }, - ], - usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } + describe("reasoning effort (mini models)", () => { + it("should include reasoning effort for grok-3-mini model", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + 
mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const miniModelHandler = new XAIHandler({
+				xaiApiKey: "test-key",
+				apiModelId: "grok-3-mini",
+				reasoningEffort: "high",
+			})
+
+			const stream = miniModelHandler.createMessage("test prompt", [])
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			// Check that provider options are passed for reasoning
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					providerOptions: expect.any(Object),
+				}),
+			)
+		})
+	})
+})
diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts
index 8df9cc66eca..238dbeaf2de 100644
--- a/src/api/providers/xai.ts
+++ b/src/api/providers/xai.ts
@@ -1,166 +1,190 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
+import { createXai } from "@ai-sdk/xai"
+import { streamText, generateText, ToolSet } from "ai"
 
-import { type XAIModelId, xaiDefaultModelId, xaiModels, ApiProviderError } from "@roo-code/types"
-import { TelemetryService } from "@roo-code/telemetry"
+import { type XAIModelId, xaiDefaultModelId, xaiModels, type ModelInfo } from "@roo-code/types"
 
-import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
 import type { ApiHandlerOptions } from "../../shared/api"
-import { ApiStream } from "../transform/stream"
-import { convertToOpenAiMessages } from "../transform/openai-format"
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	mapToolChoice,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { getModelParams } from "../transform/model-params"
 import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
-import { handleOpenAIError } from "./utils/openai-error-handler"
 
 const XAI_DEFAULT_TEMPERATURE = 0
 
+/**
+ * xAI provider using the dedicated @ai-sdk/xai package.
+ * Provides native support for Grok models including reasoning models.
+ */
 export class XAIHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
-	private client: OpenAI
-	private readonly providerName = "xAI"
+	protected provider: ReturnType<typeof createXai>
 
 	constructor(options: ApiHandlerOptions) {
 		super()
		this.options = options
-		const apiKey = this.options.xaiApiKey ?? "not-provided"
-
-		this.client = new OpenAI({
+		// Create the xAI provider using AI SDK
+		this.provider = createXai({
 			baseURL: "https://api.x.ai/v1",
-			apiKey: apiKey,
-			defaultHeaders: DEFAULT_HEADERS,
+			apiKey: options.xaiApiKey ?? "not-provided",
+			headers: DEFAULT_HEADERS,
 		})
 	}
 
-	override getModel() {
+	override getModel(): {
+		id: XAIModelId
+		info: ModelInfo
+		maxTokens?: number
+		temperature?: number
+		reasoning?: any
+	} {
 		const id =
 			this.options.apiModelId && this.options.apiModelId in xaiModels
 				? (this.options.apiModelId as XAIModelId)
 				: xaiDefaultModelId
 		const info = xaiModels[id]
-		const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options })
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+			defaultTemperature: XAI_DEFAULT_TEMPERATURE,
+		})
 		return { id, info, ...params }
 	}
 
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
+	}
+
+	/**
+	 * Process usage metrics from the AI SDK response.
+	 */
+	protected processUsageMetrics(
+		usage: {
+			inputTokens?: number
+			outputTokens?: number
+			details?: {
+				cachedInputTokens?: number
+				reasoningTokens?: number
+			}
+		},
+		providerMetadata?: {
+			xai?: {
+				cachedPromptTokens?: number
+			}
+		},
+	): ApiStreamUsageChunk {
+		// Extract cache metrics from xAI's providerMetadata if available
+		// xAI supports prompt caching through prompt_tokens_details.cached_tokens
+		const cacheReadTokens = providerMetadata?.xai?.cachedPromptTokens ?? usage.details?.cachedInputTokens
+
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			cacheReadTokens,
+			cacheWriteTokens: undefined, // xAI doesn't report cache write tokens separately
+			reasoningTokens: usage.details?.reasoningTokens,
+		}
+	}
+
+	/**
+	 * Get the max tokens parameter to include in the request.
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 */
 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const { id: modelId, info: modelInfo, reasoning } = this.getModel()
-
-		// Use the OpenAI-compatible API.
-		const requestOptions = {
-			model: modelId,
-			max_tokens: modelInfo.maxTokens,
-			temperature: this.options.modelTemperature ?? XAI_DEFAULT_TEMPERATURE,
-			messages: [
-				{ role: "system", content: systemPrompt },
-				...convertToOpenAiMessages(messages),
-			] as OpenAI.Chat.ChatCompletionMessageParam[],
-			stream: true as const,
-			stream_options: { include_usage: true },
-			...(reasoning && reasoning),
-			tools: this.convertToolsForOpenAI(metadata?.tools),
-			tool_choice: metadata?.tool_choice,
-			parallel_tool_calls: metadata?.parallelToolCalls ?? true,
+		const { temperature, reasoning } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		// Convert messages to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		// Convert tools to OpenAI format first, then to AI SDK format
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		// Build the request options
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? XAI_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: mapToolChoice(metadata?.tool_choice),
+			...(reasoning && { providerOptions: { xai: reasoning } }),
 		}
 
-		let stream
-		try {
-			stream = await this.client.chat.completions.create(requestOptions)
-		} catch (error) {
-			const errorMessage = error instanceof Error ? error.message : String(error)
-			const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "createMessage")
-			TelemetryService.instance.captureException(apiError)
-			throw handleOpenAIError(error, this.providerName)
-		}
+		// Use streamText for streaming responses
+		const result = streamText(requestOptions)
 
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-			const finishReason = chunk.choices[0]?.finish_reason
-
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
-				}
-			}
-
-			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
-				yield {
-					type: "reasoning",
-					text: delta.reasoning_content as string,
-				}
-			}
-
-			// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
-			if (delta?.tool_calls) {
-				for (const toolCall of delta.tool_calls) {
-					yield {
-						type: "tool_call_partial",
-						index: toolCall.index,
-						id: toolCall.id,
-						name: toolCall.function?.name,
-						arguments: toolCall.function?.arguments,
-					}
-				}
-			}
-
-			// Process finish_reason to emit tool_call_end events
-			// This ensures tool calls are finalized even if the stream doesn't properly close
-			if (finishReason) {
-				const endEvents = NativeToolCallParser.processFinishReason(finishReason)
-				for (const event of endEvents) {
-					yield event
+		try {
+			// Process the full stream to get all events including reasoning
+			for await (const part of result.fullStream) {
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
 				}
 			}
 
-			if (chunk.usage) {
-				// Extract detailed token information if available
-				// First check for prompt_tokens_details structure (real API response)
-				const promptDetails = "prompt_tokens_details" in chunk.usage ? chunk.usage.prompt_tokens_details : null
-				const cachedTokens = promptDetails && "cached_tokens" in promptDetails ? promptDetails.cached_tokens : 0
-
-				// Fall back to direct fields in usage (used in test mocks)
-				const readTokens =
-					cachedTokens ||
-					("cache_read_input_tokens" in chunk.usage ? (chunk.usage as any).cache_read_input_tokens : 0)
-				const writeTokens =
-					"cache_creation_input_tokens" in chunk.usage ? (chunk.usage as any).cache_creation_input_tokens : 0
-
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-					cacheReadTokens: readTokens,
-					cacheWriteTokens: writeTokens,
-				}
+			// Yield usage metrics at the end, including cache metrics from providerMetadata
+			const usage = await result.usage
+			const providerMetadata = await result.providerMetadata
+			if (usage) {
+				yield this.processUsageMetrics(usage, providerMetadata as any)
 			}
+		} catch (error) {
+			// Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.)
+			throw handleAiSdkError(error, "xAI")
 		}
 	}
 
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
 	async completePrompt(prompt: string): Promise<string> {
-		const { id: modelId, reasoning } = this.getModel()
+		const { temperature, reasoning } = this.getModel()
+		const languageModel = this.getLanguageModel()
 
 		try {
-			const response = await this.client.chat.completions.create({
-				model: modelId,
-				messages: [{ role: "user", content: prompt }],
-				...(reasoning && reasoning),
+			const { text } = await generateText({
+				model: languageModel,
+				prompt,
+				maxOutputTokens: this.getMaxOutputTokens(),
+				temperature: this.options.modelTemperature ?? temperature ?? XAI_DEFAULT_TEMPERATURE,
+				...(reasoning && { providerOptions: { xai: reasoning } }),
 			})
-			return response.choices[0]?.message.content || ""
+			return text
 		} catch (error) {
-			const errorMessage = error instanceof Error ? error.message : String(error)
-			const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "completePrompt")
-			TelemetryService.instance.captureException(apiError)
-			throw handleOpenAIError(error, this.providerName)
+			throw handleAiSdkError(error, "xAI")
 		}
 	}
 }
diff --git a/src/package.json b/src/package.json
index 6292fd15940..04402de28af 100644
--- a/src/package.json
+++ b/src/package.json
@@ -455,6 +455,7 @@
 		"@ai-sdk/fireworks": "^2.0.26",
 		"@ai-sdk/groq": "^3.0.19",
 		"@ai-sdk/mistral": "^3.0.0",
+		"@ai-sdk/xai": "^3.0.46",
 		"sambanova-ai-provider": "^1.2.2",
 		"@anthropic-ai/bedrock-sdk": "^0.10.2",
 		"@anthropic-ai/sdk": "^0.37.0",