diff --git a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts index f7c1d9d8851..63e2f93d0f5 100644 --- a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts +++ b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts @@ -274,6 +274,217 @@ describe("VercelAiGatewayHandler", () => { totalCost: 0.005, }) }) + + describe("native tool calling", () => { + const testTools = [ + { + type: "function" as const, + function: { + name: "test_tool", + description: "A test tool", + parameters: { + type: "object", + properties: { + arg1: { type: "string" }, + }, + required: ["arg1"], + }, + }, + }, + ] + + beforeEach(() => { + mockCreate.mockImplementation(async () => ({ + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + } + }, + })) + }) + + it("should include tools when provided", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + + const messageGenerator = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + tools: testTools, + toolProtocol: "native", + }) + await messageGenerator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + tools: expect.arrayContaining([ + expect.objectContaining({ + type: "function", + function: expect.objectContaining({ + name: "test_tool", + }), + }), + ]), + }), + ) + }) + + it("should include tool_choice when provided", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + + const messageGenerator = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + tools: testTools, + toolProtocol: "native", + tool_choice: "auto", + }) + await messageGenerator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + tool_choice: "auto", + }), + ) + }) + + it("should set parallel_tool_calls when toolProtocol is native", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + + const messageGenerator = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + tools: testTools, + toolProtocol: "native", + parallelToolCalls: true, + }) + await messageGenerator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + parallel_tool_calls: true, + }), + ) + }) + + it("should default parallel_tool_calls to false", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + + const messageGenerator = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + tools: testTools, + toolProtocol: "native", + }) + await messageGenerator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + parallel_tool_calls: false, + }), + ) + }) + + it("should yield tool_call_partial chunks when streaming tool calls", async () => { + mockCreate.mockImplementation(async () => ({ + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_123", + function: { + name: "test_tool", + arguments: '{"arg1":', + }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + function: { + arguments: '"value"}', + }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + }, + } + }, + })) + + const handler = new VercelAiGatewayHandler(mockOptions) + + const stream = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + tools: testTools, + toolProtocol: "native", + }) + + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") + expect(toolCallChunks).toHaveLength(2) + expect(toolCallChunks[0]).toEqual({ + type: "tool_call_partial", + index: 0, + id: "call_123", + name: "test_tool", + arguments: '{"arg1":', + }) + expect(toolCallChunks[1]).toEqual({ + type: "tool_call_partial", + index: 0, + id: undefined, + name: undefined, + arguments: '"value"}', + }) + }) + + it("should include stream_options with include_usage", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + + const messageGenerator = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + }) + await messageGenerator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + stream_options: { include_usage: true }, + }), + ) + }) + }) }) describe("completePrompt", () => { diff --git a/src/api/providers/fetchers/__tests__/vercel-ai-gateway.spec.ts b/src/api/providers/fetchers/__tests__/vercel-ai-gateway.spec.ts index 30ad2f41d5b..ed4baa33b8b 100644 --- a/src/api/providers/fetchers/__tests__/vercel-ai-gateway.spec.ts +++ b/src/api/providers/fetchers/__tests__/vercel-ai-gateway.spec.ts @@ -176,6 +176,7 @@ describe("Vercel AI Gateway Fetchers", () => { maxTokens: 8000, contextWindow: 100000, supportsImages: false, + supportsNativeTools: true, supportsPromptCache: false, inputPrice: 2500000, outputPrice: 10000000, diff --git a/src/api/providers/fetchers/vercel-ai-gateway.ts b/src/api/providers/fetchers/vercel-ai-gateway.ts index 3b6852c28d5..646def5ec84 100644 --- a/src/api/providers/fetchers/vercel-ai-gateway.ts +++ b/src/api/providers/fetchers/vercel-ai-gateway.ts @@ -108,6 +108,7 @@ export const parseVercelAiGatewayModel = ({ id, model }: { id: string; model: Ve contextWindow: model.context_window, supportsImages, supportsPromptCache, + supportsNativeTools: true, inputPrice: parseApiPrice(model.pricing?.input), outputPrice: parseApiPrice(model.pricing?.output), cacheWritesPrice, diff --git a/src/api/providers/vercel-ai-gateway.ts b/src/api/providers/vercel-ai-gateway.ts index be77d35986b..96863ac1ea9 100644 --- a/src/api/providers/vercel-ai-gateway.ts +++ b/src/api/providers/vercel-ai-gateway.ts @@ -60,6 +60,12 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp : undefined, max_completion_tokens: info.maxTokens, stream: true, + stream_options: { include_usage: true }, + ...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }), + ...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }), + ...(metadata?.toolProtocol === "native" && { + parallel_tool_calls: metadata.parallelToolCalls ?? false, + }), } const completion = await this.client.chat.completions.create(body) @@ -73,6 +79,19 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp } } + // Emit raw tool call chunks - NativeToolCallParser handles state management + if (delta?.tool_calls) { + for (const toolCall of delta.tool_calls) { + yield { + type: "tool_call_partial", + index: toolCall.index, + id: toolCall.id, + name: toolCall.function?.name, + arguments: toolCall.function?.arguments, + } + } + } + if (chunk.usage) { const usage = chunk.usage as VercelAiGatewayUsage yield {