Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions packages/types/src/providers/xai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,71 @@ export const xaiModels = {
contextWindow: 262_144,
supportsImages: false,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 0.2,
outputPrice: 1.5,
cacheWritesPrice: 0.02,
cacheReadsPrice: 0.02,
description: "xAI's Grok Code Fast model with 256K context window",
},
"grok-4-1-fast-reasoning": {
maxTokens: 65_536,
contextWindow: 2_000_000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 0.2,
outputPrice: 0.5,
cacheWritesPrice: 0.05,
cacheReadsPrice: 0.05,
description:
"xAI's Grok 4.1 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning",
},
"grok-4-1-fast-non-reasoning": {
maxTokens: 65_536,
contextWindow: 2_000_000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 0.2,
outputPrice: 0.5,
cacheWritesPrice: 0.05,
cacheReadsPrice: 0.05,
description:
"xAI's Grok 4.1 Fast model with 2M context window, optimized for high-performance agentic tool calling",
},
"grok-4-fast-reasoning": {
maxTokens: 65_536,
contextWindow: 2_000_000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 0.2,
outputPrice: 0.5,
cacheWritesPrice: 0.05,
cacheReadsPrice: 0.05,
description:
"xAI's Grok 4 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning",
},
"grok-4-fast-non-reasoning": {
maxTokens: 65_536,
contextWindow: 2_000_000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 0.2,
outputPrice: 0.5,
cacheWritesPrice: 0.05,
cacheReadsPrice: 0.05,
description:
"xAI's Grok 4 Fast model with 2M context window, optimized for high-performance agentic tool calling",
},
"grok-4": {
maxTokens: 8192,
contextWindow: 256000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 3.0,
outputPrice: 15.0,
cacheWritesPrice: 0.75,
Expand All @@ -33,6 +87,7 @@ export const xaiModels = {
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 3.0,
outputPrice: 15.0,
cacheWritesPrice: 0.75,
Expand All @@ -44,6 +99,7 @@ export const xaiModels = {
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 5.0,
outputPrice: 25.0,
cacheWritesPrice: 1.25,
Expand All @@ -55,6 +111,7 @@ export const xaiModels = {
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 0.3,
outputPrice: 0.5,
cacheWritesPrice: 0.07,
Expand All @@ -67,6 +124,7 @@ export const xaiModels = {
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 0.6,
outputPrice: 4.0,
cacheWritesPrice: 0.15,
Expand All @@ -79,6 +137,7 @@ export const xaiModels = {
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
supportsNativeTools: true,
inputPrice: 2.0,
outputPrice: 10.0,
description: "xAI's Grok-2 model (version 1212) with 128K context window",
Expand All @@ -88,6 +147,7 @@ export const xaiModels = {
contextWindow: 32768,
supportsImages: true,
supportsPromptCache: false,
supportsNativeTools: true,
inputPrice: 2.0,
outputPrice: 10.0,
description: "xAI's Grok-2 Vision model (version 1212) with image support and 32K context window",
Expand Down
216 changes: 216 additions & 0 deletions src/api/providers/__tests__/xai.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,220 @@ describe("XAIHandler", () => {
}),
)
})

	describe("Native Tool Calling", () => {
		// Single OpenAI-style function-tool definition reused as the fixture for every
		// test in this suite; `as const` keeps `type` as the literal "function".
		const testTools = [
			{
				type: "function" as const,
				function: {
					name: "test_tool",
					description: "A test tool",
					parameters: {
						type: "object",
						properties: {
							arg1: { type: "string", description: "First argument" },
						},
						required: ["arg1"],
					},
				},
			},
		]

		it("should include tools in request when model supports native tools and tools are provided", async () => {
			// grok-3 is a model with supportsNativeTools enabled in the model catalog.
			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })

			// Mock a stream that completes immediately: these request-shaping tests only
			// inspect the arguments passed to the client, not the streamed chunks.
			mockCreate.mockImplementationOnce(() => {
				return {
					[Symbol.asyncIterator]: () => ({
						async next() {
							return { done: true }
						},
					}),
				}
			})

			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
				taskId: "test-task-id",
				tools: testTools,
				toolProtocol: "native",
			})
			// Advance the generator once so createMessage actually issues the API call.
			await messageGenerator.next()

			// Tool definitions must be forwarded, and parallel_tool_calls must default
			// to false when metadata does not explicitly request parallel calls.
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					tools: expect.arrayContaining([
						expect.objectContaining({
							type: "function",
							function: expect.objectContaining({
								name: "test_tool",
							}),
						}),
					]),
					parallel_tool_calls: false,
				}),
			)
		})

		it("should include tool_choice when provided", async () => {
			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })

			// Immediately-done stream; see first test for rationale.
			mockCreate.mockImplementationOnce(() => {
				return {
					[Symbol.asyncIterator]: () => ({
						async next() {
							return { done: true }
						},
					}),
				}
			})

			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
				taskId: "test-task-id",
				tools: testTools,
				toolProtocol: "native",
				tool_choice: "auto",
			})
			await messageGenerator.next()

			// A caller-supplied tool_choice should be passed through verbatim.
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					tool_choice: "auto",
				}),
			)
		})

		it("should not include tools when toolProtocol is xml", async () => {
			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })

			// Immediately-done stream; see first test for rationale.
			mockCreate.mockImplementationOnce(() => {
				return {
					[Symbol.asyncIterator]: () => ({
						async next() {
							return { done: true }
						},
					}),
				}
			})

			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
				taskId: "test-task-id",
				tools: testTools,
				toolProtocol: "xml",
			})
			await messageGenerator.next()

			// With the XML protocol the request must omit native-tool fields entirely
			// (absent keys, not undefined values) — inspect the most recent call's args.
			const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
			expect(callArgs).not.toHaveProperty("tools")
			expect(callArgs).not.toHaveProperty("tool_choice")
		})

		it("should yield tool_call_partial chunks during streaming", async () => {
			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })

			// Simulate a two-delta streamed tool call: the first delta carries the call
			// id/name and the opening of the JSON arguments, the second carries only the
			// remaining argument text (id/name omitted, as in real OpenAI-style streams).
			mockCreate.mockImplementationOnce(() => {
				return {
					[Symbol.asyncIterator]: () => ({
						next: vi
							.fn()
							.mockResolvedValueOnce({
								done: false,
								value: {
									choices: [
										{
											delta: {
												tool_calls: [
													{
														index: 0,
														id: "call_123",
														function: {
															name: "test_tool",
															arguments: '{"arg1":',
														},
													},
												],
											},
										},
									],
								},
							})
							.mockResolvedValueOnce({
								done: false,
								value: {
									choices: [
										{
											delta: {
												tool_calls: [
													{
														index: 0,
														function: {
															arguments: '"value"}',
														},
													},
												],
											},
										},
									],
								},
							})
							.mockResolvedValueOnce({ done: true }),
					}),
				}
			})

			const stream = handlerWithTools.createMessage("test prompt", [], {
				taskId: "test-task-id",
				tools: testTools,
				toolProtocol: "native",
			})

			// Drain the whole stream so both deltas are translated into chunks.
			const chunks = []
			for await (const chunk of stream) {
				chunks.push(chunk)
			}

			// First delta: full metadata plus the leading argument fragment.
			expect(chunks).toContainEqual({
				type: "tool_call_partial",
				index: 0,
				id: "call_123",
				name: "test_tool",
				arguments: '{"arg1":',
			})

			// Second delta: continuation chunk — id/name are undefined, only the
			// trailing argument fragment is carried (downstream parser stitches them).
			expect(chunks).toContainEqual({
				type: "tool_call_partial",
				index: 0,
				id: undefined,
				name: undefined,
				arguments: '"value"}',
			})
		})

		it("should set parallel_tool_calls based on metadata", async () => {
			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })

			// Immediately-done stream; see first test for rationale.
			mockCreate.mockImplementationOnce(() => {
				return {
					[Symbol.asyncIterator]: () => ({
						async next() {
							return { done: true }
						},
					}),
				}
			})

			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
				taskId: "test-task-id",
				tools: testTools,
				toolProtocol: "native",
				parallelToolCalls: true,
			})
			await messageGenerator.next()

			// Explicit opt-in via metadata must override the false default asserted above.
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					parallel_tool_calls: true,
				}),
			)
		})
	})
})
21 changes: 21 additions & 0 deletions src/api/providers/xai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
): ApiStream {
const { id: modelId, info: modelInfo, reasoning } = this.getModel()

// Check if model supports native tools and tools are provided with native protocol
const supportsNativeTools = modelInfo.supportsNativeTools ?? false
const useNativeTools =
supportsNativeTools && metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml"

// Use the OpenAI-compatible API.
let stream
try {
Expand All @@ -63,6 +68,9 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
stream: true,
stream_options: { include_usage: true },
...(reasoning && reasoning),
...(useNativeTools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
...(useNativeTools && metadata.tool_choice && { tool_choice: metadata.tool_choice }),
...(useNativeTools && { parallel_tool_calls: metadata?.parallelToolCalls ?? false }),
})
} catch (error) {
throw handleOpenAIError(error, this.providerName)
Expand All @@ -85,6 +93,19 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
}
}

// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
}

if (chunk.usage) {
// Extract detailed token information if available
// First check for prompt_tokens_details structure (real API response)
Expand Down
Loading