Skip to content

Commit d369c61

Browse files
committed
Add native tool calling support
1 parent 2831c74 commit d369c61

File tree

3 files changed

+249
-0
lines changed

3 files changed

+249
-0
lines changed

packages/types/src/providers/xai.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export const xaiModels = {
1111
contextWindow: 262_144,
1212
supportsImages: false,
1313
supportsPromptCache: true,
14+
supportsNativeTools: true,
1415
inputPrice: 0.2,
1516
outputPrice: 1.5,
1617
cacheWritesPrice: 0.02,
@@ -22,6 +23,7 @@ export const xaiModels = {
2223
contextWindow: 2_000_000,
2324
supportsImages: true,
2425
supportsPromptCache: true,
26+
supportsNativeTools: true,
2527
inputPrice: 0.2,
2628
outputPrice: 0.5,
2729
cacheWritesPrice: 0.05,
@@ -34,6 +36,7 @@ export const xaiModels = {
3436
contextWindow: 2_000_000,
3537
supportsImages: true,
3638
supportsPromptCache: true,
39+
supportsNativeTools: true,
3740
inputPrice: 0.2,
3841
outputPrice: 0.5,
3942
cacheWritesPrice: 0.05,
@@ -46,6 +49,7 @@ export const xaiModels = {
4649
contextWindow: 2_000_000,
4750
supportsImages: true,
4851
supportsPromptCache: true,
52+
supportsNativeTools: true,
4953
inputPrice: 0.2,
5054
outputPrice: 0.5,
5155
cacheWritesPrice: 0.05,
@@ -58,6 +62,7 @@ export const xaiModels = {
5862
contextWindow: 2_000_000,
5963
supportsImages: true,
6064
supportsPromptCache: true,
65+
supportsNativeTools: true,
6166
inputPrice: 0.2,
6267
outputPrice: 0.5,
6368
cacheWritesPrice: 0.05,
@@ -70,6 +75,7 @@ export const xaiModels = {
7075
contextWindow: 256000,
7176
supportsImages: true,
7277
supportsPromptCache: true,
78+
supportsNativeTools: true,
7379
inputPrice: 3.0,
7480
outputPrice: 15.0,
7581
cacheWritesPrice: 0.75,
@@ -81,6 +87,7 @@ export const xaiModels = {
8187
contextWindow: 131072,
8288
supportsImages: false,
8389
supportsPromptCache: true,
90+
supportsNativeTools: true,
8491
inputPrice: 3.0,
8592
outputPrice: 15.0,
8693
cacheWritesPrice: 0.75,
@@ -92,6 +99,7 @@ export const xaiModels = {
9299
contextWindow: 131072,
93100
supportsImages: false,
94101
supportsPromptCache: true,
102+
supportsNativeTools: true,
95103
inputPrice: 5.0,
96104
outputPrice: 25.0,
97105
cacheWritesPrice: 1.25,
@@ -103,6 +111,7 @@ export const xaiModels = {
103111
contextWindow: 131072,
104112
supportsImages: false,
105113
supportsPromptCache: true,
114+
supportsNativeTools: true,
106115
inputPrice: 0.3,
107116
outputPrice: 0.5,
108117
cacheWritesPrice: 0.07,
@@ -115,6 +124,7 @@ export const xaiModels = {
115124
contextWindow: 131072,
116125
supportsImages: false,
117126
supportsPromptCache: true,
127+
supportsNativeTools: true,
118128
inputPrice: 0.6,
119129
outputPrice: 4.0,
120130
cacheWritesPrice: 0.15,
@@ -127,6 +137,7 @@ export const xaiModels = {
127137
contextWindow: 131072,
128138
supportsImages: false,
129139
supportsPromptCache: false,
140+
supportsNativeTools: true,
130141
inputPrice: 2.0,
131142
outputPrice: 10.0,
132143
description: "xAI's Grok-2 model (version 1212) with 128K context window",
@@ -136,6 +147,7 @@ export const xaiModels = {
136147
contextWindow: 32768,
137148
supportsImages: true,
138149
supportsPromptCache: false,
150+
supportsNativeTools: true,
139151
inputPrice: 2.0,
140152
outputPrice: 10.0,
141153
description: "xAI's Grok-2 Vision model (version 1212) with image support and 32K context window",

src/api/providers/__tests__/xai.spec.ts

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,220 @@ describe("XAIHandler", () => {
280280
}),
281281
)
282282
})
283+
284+
describe("Native Tool Calling", () => {
285+
const testTools = [
286+
{
287+
type: "function" as const,
288+
function: {
289+
name: "test_tool",
290+
description: "A test tool",
291+
parameters: {
292+
type: "object",
293+
properties: {
294+
arg1: { type: "string", description: "First argument" },
295+
},
296+
required: ["arg1"],
297+
},
298+
},
299+
},
300+
]
301+
302+
it("should include tools in request when model supports native tools and tools are provided", async () => {
303+
const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
304+
305+
mockCreate.mockImplementationOnce(() => {
306+
return {
307+
[Symbol.asyncIterator]: () => ({
308+
async next() {
309+
return { done: true }
310+
},
311+
}),
312+
}
313+
})
314+
315+
const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
316+
taskId: "test-task-id",
317+
tools: testTools,
318+
toolProtocol: "native",
319+
})
320+
await messageGenerator.next()
321+
322+
expect(mockCreate).toHaveBeenCalledWith(
323+
expect.objectContaining({
324+
tools: expect.arrayContaining([
325+
expect.objectContaining({
326+
type: "function",
327+
function: expect.objectContaining({
328+
name: "test_tool",
329+
}),
330+
}),
331+
]),
332+
parallel_tool_calls: false,
333+
}),
334+
)
335+
})
336+
337+
it("should include tool_choice when provided", async () => {
338+
const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
339+
340+
mockCreate.mockImplementationOnce(() => {
341+
return {
342+
[Symbol.asyncIterator]: () => ({
343+
async next() {
344+
return { done: true }
345+
},
346+
}),
347+
}
348+
})
349+
350+
const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
351+
taskId: "test-task-id",
352+
tools: testTools,
353+
toolProtocol: "native",
354+
tool_choice: "auto",
355+
})
356+
await messageGenerator.next()
357+
358+
expect(mockCreate).toHaveBeenCalledWith(
359+
expect.objectContaining({
360+
tool_choice: "auto",
361+
}),
362+
)
363+
})
364+
365+
it("should not include tools when toolProtocol is xml", async () => {
366+
const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
367+
368+
mockCreate.mockImplementationOnce(() => {
369+
return {
370+
[Symbol.asyncIterator]: () => ({
371+
async next() {
372+
return { done: true }
373+
},
374+
}),
375+
}
376+
})
377+
378+
const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
379+
taskId: "test-task-id",
380+
tools: testTools,
381+
toolProtocol: "xml",
382+
})
383+
await messageGenerator.next()
384+
385+
const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
386+
expect(callArgs).not.toHaveProperty("tools")
387+
expect(callArgs).not.toHaveProperty("tool_choice")
388+
})
389+
390+
it("should yield tool_call_partial chunks during streaming", async () => {
391+
const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
392+
393+
mockCreate.mockImplementationOnce(() => {
394+
return {
395+
[Symbol.asyncIterator]: () => ({
396+
next: vi
397+
.fn()
398+
.mockResolvedValueOnce({
399+
done: false,
400+
value: {
401+
choices: [
402+
{
403+
delta: {
404+
tool_calls: [
405+
{
406+
index: 0,
407+
id: "call_123",
408+
function: {
409+
name: "test_tool",
410+
arguments: '{"arg1":',
411+
},
412+
},
413+
],
414+
},
415+
},
416+
],
417+
},
418+
})
419+
.mockResolvedValueOnce({
420+
done: false,
421+
value: {
422+
choices: [
423+
{
424+
delta: {
425+
tool_calls: [
426+
{
427+
index: 0,
428+
function: {
429+
arguments: '"value"}',
430+
},
431+
},
432+
],
433+
},
434+
},
435+
],
436+
},
437+
})
438+
.mockResolvedValueOnce({ done: true }),
439+
}),
440+
}
441+
})
442+
443+
const stream = handlerWithTools.createMessage("test prompt", [], {
444+
taskId: "test-task-id",
445+
tools: testTools,
446+
toolProtocol: "native",
447+
})
448+
449+
const chunks = []
450+
for await (const chunk of stream) {
451+
chunks.push(chunk)
452+
}
453+
454+
expect(chunks).toContainEqual({
455+
type: "tool_call_partial",
456+
index: 0,
457+
id: "call_123",
458+
name: "test_tool",
459+
arguments: '{"arg1":',
460+
})
461+
462+
expect(chunks).toContainEqual({
463+
type: "tool_call_partial",
464+
index: 0,
465+
id: undefined,
466+
name: undefined,
467+
arguments: '"value"}',
468+
})
469+
})
470+
471+
it("should set parallel_tool_calls based on metadata", async () => {
472+
const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
473+
474+
mockCreate.mockImplementationOnce(() => {
475+
return {
476+
[Symbol.asyncIterator]: () => ({
477+
async next() {
478+
return { done: true }
479+
},
480+
}),
481+
}
482+
})
483+
484+
const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
485+
taskId: "test-task-id",
486+
tools: testTools,
487+
toolProtocol: "native",
488+
parallelToolCalls: true,
489+
})
490+
await messageGenerator.next()
491+
492+
expect(mockCreate).toHaveBeenCalledWith(
493+
expect.objectContaining({
494+
parallel_tool_calls: true,
495+
}),
496+
)
497+
})
498+
})
283499
})

src/api/providers/xai.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
5252
): ApiStream {
5353
const { id: modelId, info: modelInfo, reasoning } = this.getModel()
5454

55+
// Check if model supports native tools and tools are provided with native protocol
56+
const supportsNativeTools = modelInfo.supportsNativeTools ?? false
57+
const useNativeTools =
58+
supportsNativeTools && metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml"
59+
5560
// Use the OpenAI-compatible API.
5661
let stream
5762
try {
@@ -63,6 +68,9 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
6368
stream: true,
6469
stream_options: { include_usage: true },
6570
...(reasoning && reasoning),
71+
...(useNativeTools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
72+
...(useNativeTools && metadata.tool_choice && { tool_choice: metadata.tool_choice }),
73+
...(useNativeTools && { parallel_tool_calls: metadata?.parallelToolCalls ?? false }),
6674
})
6775
} catch (error) {
6876
throw handleOpenAIError(error, this.providerName)
@@ -85,6 +93,19 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
8593
}
8694
}
8795

96+
// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
97+
if (delta?.tool_calls) {
98+
for (const toolCall of delta.tool_calls) {
99+
yield {
100+
type: "tool_call_partial",
101+
index: toolCall.index,
102+
id: toolCall.id,
103+
name: toolCall.function?.name,
104+
arguments: toolCall.function?.arguments,
105+
}
106+
}
107+
}
108+
88109
if (chunk.usage) {
89110
// Extract detailed token information if available
90111
// First check for prompt_tokens_details structure (real API response)

0 commit comments

Comments
 (0)