diff --git a/packages/types/src/providers/xai.ts b/packages/types/src/providers/xai.ts
index 3189e593da3..ac328ed05b5 100644
--- a/packages/types/src/providers/xai.ts
+++ b/packages/types/src/providers/xai.ts
@@ -11,17 +11,71 @@ export const xaiModels = {
 		contextWindow: 262_144,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 0.2,
 		outputPrice: 1.5,
 		cacheWritesPrice: 0.02,
 		cacheReadsPrice: 0.02,
 		description: "xAI's Grok Code Fast model with 256K context window",
 	},
+	"grok-4-1-fast-reasoning": {
+		maxTokens: 65_536,
+		contextWindow: 2_000_000,
+		supportsImages: true,
+		supportsPromptCache: true,
+		supportsNativeTools: true,
+		inputPrice: 0.2,
+		outputPrice: 0.5,
+		cacheWritesPrice: 0.05,
+		cacheReadsPrice: 0.05,
+		description:
+			"xAI's Grok 4.1 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning",
+	},
+	"grok-4-1-fast-non-reasoning": {
+		maxTokens: 65_536,
+		contextWindow: 2_000_000,
+		supportsImages: true,
+		supportsPromptCache: true,
+		supportsNativeTools: true,
+		inputPrice: 0.2,
+		outputPrice: 0.5,
+		cacheWritesPrice: 0.05,
+		cacheReadsPrice: 0.05,
+		description:
+			"xAI's Grok 4.1 Fast model with 2M context window, optimized for high-performance agentic tool calling",
+	},
+	"grok-4-fast-reasoning": {
+		maxTokens: 65_536,
+		contextWindow: 2_000_000,
+		supportsImages: true,
+		supportsPromptCache: true,
+		supportsNativeTools: true,
+		inputPrice: 0.2,
+		outputPrice: 0.5,
+		cacheWritesPrice: 0.05,
+		cacheReadsPrice: 0.05,
+		description:
+			"xAI's Grok 4 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning",
+	},
+	"grok-4-fast-non-reasoning": {
+		maxTokens: 65_536,
+		contextWindow: 2_000_000,
+		supportsImages: true,
+		supportsPromptCache: true,
+		supportsNativeTools: true,
+		inputPrice: 0.2,
+		outputPrice: 0.5,
+		cacheWritesPrice: 0.05,
+		cacheReadsPrice: 0.05,
+		description:
+			"xAI's Grok 4 Fast model with 2M context window, optimized for high-performance agentic tool calling",
+	},
 	"grok-4": {
 		maxTokens: 8192,
 		contextWindow: 256000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 0.75,
@@ -33,6 +87,7 @@ export const xaiModels = {
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 0.75,
@@ -44,6 +99,7 @@
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 5.0,
 		outputPrice: 25.0,
 		cacheWritesPrice: 1.25,
@@ -55,6 +111,7 @@
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 0.3,
 		outputPrice: 0.5,
 		cacheWritesPrice: 0.07,
@@ -67,6 +124,7 @@
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 0.6,
 		outputPrice: 4.0,
 		cacheWritesPrice: 0.15,
@@ -79,6 +137,7 @@
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 10.0,
 		description: "xAI's Grok-2 model (version 1212) with 128K context window",
@@ -88,6 +147,7 @@
 		contextWindow: 32768,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 10.0,
 		description: "xAI's Grok-2 Vision model (version 1212) with image support and 32K context window",
diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts
index 1d3d4a15093..52e6e33689d 100644
--- a/src/api/providers/__tests__/xai.spec.ts
+++ b/src/api/providers/__tests__/xai.spec.ts
@@ -280,4 +280,220 @@ describe("XAIHandler", () => {
 			}),
 		)
 	})
+
+	describe("Native Tool Calling", () => {
+		const testTools = [
+			{
+				type: "function" as const,
+				function: {
+					name: "test_tool",
+					description: "A test tool",
+					parameters: {
+						type: "object",
+						properties: {
+							arg1: { type: "string", description: "First argument" },
+						},
+						required: ["arg1"],
+					},
+				},
+			},
+		]
+
+		it("should include tools in request when model supports native tools and tools are provided", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: expect.arrayContaining([
+						expect.objectContaining({
+							type: "function",
+							function: expect.objectContaining({
+								name: "test_tool",
+							}),
+						}),
+					]),
+					parallel_tool_calls: false,
+				}),
+			)
+		})
+
+		it("should include tool_choice when provided", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+				tool_choice: "auto",
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tool_choice: "auto",
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "xml",
+			})
+			await messageGenerator.next()
+
+			const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
+			expect(callArgs).not.toHaveProperty("tools")
+			expect(callArgs).not.toHaveProperty("tool_choice")
+		})
+
+		it("should yield tool_call_partial chunks during streaming", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														id: "call_123",
+														function: {
+															name: "test_tool",
+															arguments: '{"arg1":',
+														},
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														function: {
+															arguments: '"value"}',
+														},
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+			})
+
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks).toContainEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "call_123",
+				name: "test_tool",
+				arguments: '{"arg1":',
+			})
+
+			expect(chunks).toContainEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: undefined,
+				name: undefined,
+				arguments: '"value"}',
+			})
+		})
+
+		it("should set parallel_tool_calls based on metadata", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+				parallelToolCalls: true,
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					parallel_tool_calls: true,
+				}),
+			)
+		})
+	})
 })
diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts
index 7eb6e9866dd..cf19dc6f4c6 100644
--- a/src/api/providers/xai.ts
+++ b/src/api/providers/xai.ts
@@ -52,6 +52,11 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
 	): ApiStream {
 		const { id: modelId, info: modelInfo, reasoning } = this.getModel()
 
+		// Check if model supports native tools and tools are provided with native protocol
+		const supportsNativeTools = modelInfo.supportsNativeTools ?? false
+		const useNativeTools =
+			supportsNativeTools && metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml"
+
 		// Use the OpenAI-compatible API.
 		let stream
 		try {
@@ -63,6 +68,9 @@
 				stream: true,
 				stream_options: { include_usage: true },
 				...(reasoning && reasoning),
+				...(useNativeTools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(useNativeTools && metadata.tool_choice && { tool_choice: metadata.tool_choice }),
+				...(useNativeTools && { parallel_tool_calls: metadata?.parallelToolCalls ?? false }),
 			})
 		} catch (error) {
 			throw handleOpenAIError(error, this.providerName)
@@ -85,6 +93,19 @@
 				}
 			}
 
+			// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
+			if (delta?.tool_calls) {
+				for (const toolCall of delta.tool_calls) {
+					yield {
+						type: "tool_call_partial",
+						index: toolCall.index,
+						id: toolCall.id,
+						name: toolCall.function?.name,
+						arguments: toolCall.function?.arguments,
+					}
+				}
+			}
+
 			if (chunk.usage) {
 				// Extract detailed token information if available
 				// First check for prompt_tokens_details structure (real API response)