diff --git a/packages/types/src/providers/vertex.ts b/packages/types/src/providers/vertex.ts
index e7a75c06a92..916d72afe06 100644
--- a/packages/types/src/providers/vertex.ts
+++ b/packages/types/src/providers/vertex.ts
@@ -278,6 +278,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 3.75,
@@ -289,6 +291,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 3.75,
@@ -300,6 +304,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 1.0,
 		outputPrice: 5.0,
 		cacheWritesPrice: 1.25,
@@ -311,6 +317,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 5.0,
 		outputPrice: 25.0,
 		cacheWritesPrice: 6.25,
@@ -322,6 +330,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 15.0,
 		outputPrice: 75.0,
 		cacheWritesPrice: 18.75,
@@ -333,6 +343,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 15.0,
 		outputPrice: 75.0,
 		cacheWritesPrice: 18.75,
@@ -343,6 +355,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 3.75,
@@ -355,6 +369,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 3.75,
@@ -365,6 +381,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 3.75,
@@ -375,6 +393,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 3.75,
@@ -385,6 +405,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 1.0,
 		outputPrice: 5.0,
 		cacheWritesPrice: 1.25,
@@ -395,6 +417,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 15.0,
 		outputPrice: 75.0,
 		cacheWritesPrice: 18.75,
@@ -405,6 +429,8 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 0.25,
 		outputPrice: 1.25,
 		cacheWritesPrice: 0.3,
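Every Claude entry in the Vertex model table now advertises `supportsNativeTools: true` and a `defaultToolProtocol` of `"native"`. The sketch below illustrates the kind of decision this metadata is meant to drive; the field names mirror the diff, but the precedence logic here is an assumption for illustration only — the shipped logic lives in `src/utils/resolveToolProtocol.ts` and may differ.

```typescript
// Sketch only: illustrates the intended effect of the new model metadata.
// resolveProtocolSketch is a hypothetical stand-in, not the real resolver.
type ToolProtocol = "native" | "xml"

interface ModelInfoSubset {
	supportsNativeTools?: boolean
	defaultToolProtocol?: ToolProtocol
}

function resolveProtocolSketch(info: ModelInfoSubset, userOverride?: ToolProtocol): ToolProtocol {
	// An explicit override wins.
	if (userOverride) return userOverride
	// Otherwise fall back to the model's declared default, when it supports native tools.
	if (info.supportsNativeTools && info.defaultToolProtocol) return info.defaultToolProtocol
	// Models without native tool support keep the XML protocol.
	return "xml"
}

// With this diff, Vertex Claude models resolve to "native" by default:
resolveProtocolSketch({ supportsNativeTools: true, defaultToolProtocol: "native" }) // "native"
resolveProtocolSketch({ supportsNativeTools: true, defaultToolProtocol: "native" }, "xml") // "xml"
```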
diff --git a/src/api/providers/__tests__/anthropic-vertex.spec.ts b/src/api/providers/__tests__/anthropic-vertex.spec.ts
index 02eef5c748a..a5e97ed6f3e 100644
--- a/src/api/providers/__tests__/anthropic-vertex.spec.ts
+++ b/src/api/providers/__tests__/anthropic-vertex.spec.ts
@@ -949,4 +949,299 @@ describe("VertexHandler", () => {
 			)
 		})
 	})
+
+	describe("native tool calling", () => {
+		const systemPrompt = "You are a helpful assistant"
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text" as const, text: "What's the weather in London?" }],
+			},
+		]
+
+		const mockTools = [
+			{
+				type: "function" as const,
+				function: {
+					name: "get_weather",
+					description: "Get the current weather",
+					parameters: {
+						type: "object",
+						properties: {
+							location: { type: "string" },
+						},
+						required: ["location"],
+					},
+				},
+			},
+		]
+
+		it("should include tools in request when native protocol is used", async () => {
+			handler = new AnthropicVertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 10,
+							output_tokens: 0,
+						},
+					},
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = vitest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: expect.arrayContaining([
+						expect.objectContaining({
+							name: "get_weather",
+							description: "Get the current weather",
+							input_schema: expect.objectContaining({
+								type: "object",
+								properties: expect.objectContaining({
+									location: { type: "string" },
+								}),
+							}),
+						}),
+					]),
+					tool_choice: { type: "auto", disable_parallel_tool_use: true },
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			handler = new AnthropicVertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				toolProtocol: "xml",
+			})
+
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 10,
+							output_tokens: 0,
+						},
+					},
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = vitest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+			)
+		})
+
+		it("should handle tool_use blocks in stream and emit tool_call_partial", async () => {
+			handler = new AnthropicVertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 100,
+							output_tokens: 50,
+						},
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "tool_use",
+						id: "toolu_123",
+						name: "get_weather",
+					},
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = vitest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+			})
+
+			const chunks: ApiStreamChunk[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Find the tool_call_partial chunk
+			const toolCallChunk = chunks.find((chunk) => chunk.type === "tool_call_partial")
+			expect(toolCallChunk).toBeDefined()
+			expect(toolCallChunk).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "toolu_123",
+				name: "get_weather",
+				arguments: undefined,
+			})
+		})
+
+		it("should handle input_json_delta in stream and emit tool_call_partial arguments", async () => {
+			handler = new AnthropicVertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 100,
+							output_tokens: 50,
+						},
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "tool_use",
+						id: "toolu_123",
+						name: "get_weather",
+					},
+				},
+				{
+					type: "content_block_delta",
+					index: 0,
+					delta: {
+						type: "input_json_delta",
+						partial_json: '{"location":',
+					},
+				},
+				{
+					type: "content_block_delta",
+					index: 0,
+					delta: {
+						type: "input_json_delta",
+						partial_json: '"London"}',
+					},
+				},
+				{
+					type: "content_block_stop",
+					index: 0,
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = vitest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+			})
+
+			const chunks: ApiStreamChunk[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Find the tool_call_partial chunks
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+			expect(toolCallChunks).toHaveLength(3)
+
+			// First chunk has id and name
+			expect(toolCallChunks[0]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "toolu_123",
+				name: "get_weather",
+				arguments: undefined,
+			})
+
+			// Subsequent chunks have arguments
+			expect(toolCallChunks[1]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: undefined,
+				name: undefined,
+				arguments: '{"location":',
+			})
+
+			expect(toolCallChunks[2]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: undefined,
+				name: undefined,
+				arguments: '"London"}',
+			})
+		})
+	})
 })
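The tests above pin down the `tool_call_partial` contract: the first chunk for a content block carries `id` and `name`, and subsequent chunks carry `arguments` fragments keyed by the same `index`. A hypothetical downstream consumer (not part of this PR) could reassemble complete calls like this — the chunk shape mirrors the test expectations, with the rest of the `ApiStreamChunk` union elided:

```typescript
// Illustrative consumer of tool_call_partial chunks; names are hypothetical.
interface ToolCallPartialChunk {
	type: "tool_call_partial"
	index: number
	id?: string
	name?: string
	arguments?: string
}

function assembleToolCalls(chunks: ToolCallPartialChunk[]) {
	const calls = new Map<number, { id?: string; name?: string; args: string }>()
	for (const c of chunks) {
		// Merge each partial into the accumulator for its block index.
		const entry = calls.get(c.index) ?? { args: "" }
		if (c.id) entry.id = c.id
		if (c.name) entry.name = c.name
		if (c.arguments) entry.args += c.arguments
		calls.set(c.index, entry)
	}
	// Parse the accumulated JSON argument string once the stream is done.
	return [...calls.values()].map((c) => ({ ...c, input: JSON.parse(c.args || "{}") }))
}

// Using the three chunks asserted in the input_json_delta test:
assembleToolCalls([
	{ type: "tool_call_partial", index: 0, id: "toolu_123", name: "get_weather" },
	{ type: "tool_call_partial", index: 0, arguments: '{"location":' },
	{ type: "tool_call_partial", index: 0, arguments: '"London"}' },
])
// → [{ id: "toolu_123", name: "get_weather", args: '{"location":"London"}', input: { location: "London" } }]
```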
diff --git a/src/api/providers/anthropic-vertex.ts b/src/api/providers/anthropic-vertex.ts
index f526da8fc02..6d2d93f7f55 100644
--- a/src/api/providers/anthropic-vertex.ts
+++ b/src/api/providers/anthropic-vertex.ts
@@ -8,6 +8,7 @@ import {
 	vertexDefaultModelId,
 	vertexModels,
 	ANTHROPIC_DEFAULT_MAX_TOKENS,
+	TOOL_PROTOCOL,
 } from "@roo-code/types"

 import { ApiHandlerOptions } from "../../shared/api"
@@ -17,6 +18,11 @@ import { ApiStream } from "../transform/stream"
 import { addCacheBreakpoints } from "../transform/caching/vertex"
 import { getModelParams } from "../transform/model-params"
 import { filterNonAnthropicBlocks } from "../transform/anthropic-filter"
+import { resolveToolProtocol } from "../../utils/resolveToolProtocol"
+import {
+	convertOpenAIToolsToAnthropic,
+	convertOpenAIToolChoiceToAnthropic,
+} from "../../core/prompts/tools/native-tools/converters"

 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -63,17 +69,30 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		let {
-			id,
-			info: { supportsPromptCache },
-			temperature,
-			maxTokens,
-			reasoning: thinking,
-		} = this.getModel()
+		let { id, info, temperature, maxTokens, reasoning: thinking } = this.getModel()
+
+		const { supportsPromptCache } = info

 		// Filter out non-Anthropic blocks (reasoning, thoughtSignature, etc.) before sending to the API
 		const sanitizedMessages = filterNonAnthropicBlocks(messages)

+		// Enable native tools using resolveToolProtocol (which checks model's defaultToolProtocol)
+		// This matches the approach used in AnthropicHandler
+		// Also exclude tools when tool_choice is "none" since that means "don't use tools"
+		const toolProtocol = resolveToolProtocol(this.options, info, metadata?.toolProtocol)
+		const shouldIncludeNativeTools =
+			metadata?.tools &&
+			metadata.tools.length > 0 &&
+			toolProtocol === TOOL_PROTOCOL.NATIVE &&
+			metadata?.tool_choice !== "none"
+
+		const nativeToolParams = shouldIncludeNativeTools
+			? {
+					tools: convertOpenAIToolsToAnthropic(metadata.tools!),
+					tool_choice: convertOpenAIToolChoiceToAnthropic(metadata.tool_choice, metadata.parallelToolCalls),
+				}
+			: {}
+
 		/**
 		 * Vertex API has specific limitations for prompt caching:
 		 * 1. Maximum of 4 blocks can have cache_control
@@ -98,6 +117,7 @@
 				: systemPrompt,
 			messages: supportsPromptCache ? addCacheBreakpoints(sanitizedMessages) : sanitizedMessages,
 			stream: true,
+			...nativeToolParams,
 		}

 		const stream = await this.client.messages.create(params)
@@ -144,6 +164,17 @@
 						yield { type: "reasoning", text: (chunk.content_block as any).thinking }
 						break
 					}
+					case "tool_use": {
+						// Emit initial tool call partial with id and name
+						yield {
+							type: "tool_call_partial",
+							index: chunk.index,
+							id: chunk.content_block!.id,
+							name: chunk.content_block!.name,
+							arguments: undefined,
+						}
+						break
+					}
 				}

 				break
@@ -158,12 +189,24 @@
 						yield { type: "reasoning", text: (chunk.delta as any).thinking }
 						break
 					}
+					case "input_json_delta": {
+						// Emit tool call partial chunks as arguments stream in
+						yield {
+							type: "tool_call_partial",
+							index: chunk.index,
+							id: undefined,
+							name: undefined,
+							arguments: (chunk.delta as any).partial_json,
+						}
+						break
+					}
 				}

 				break
 			}

 			case "content_block_stop": {
 				// Block complete - no action needed for now.
+				// NativeToolCallParser handles tool call completion
 				// Note: Signature for multi-turn thinking would require using stream.finalMessage()
 				// after iteration completes, which requires restructuring the streaming approach.
 				break
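One detail worth noting in the handler change: `nativeToolParams` is spread into the request rather than assigned, so the XML path produces a request with no `tools`/`tool_choice` keys at all (not keys set to `undefined`), which is exactly what the `expect.not.objectContaining` test asserts. A standalone illustration of the pattern, with stand-in values:

```typescript
// Illustration of the conditional-spread pattern; values are stand-ins.
const shouldIncludeNativeTools: boolean = false // e.g. when toolProtocol === "xml"

const nativeToolParams = shouldIncludeNativeTools
	? { tools: [{ name: "get_weather" }], tool_choice: { type: "auto" as const } }
	: {}

const params = {
	model: "claude-3-5-sonnet-v2@20241022",
	stream: true,
	...nativeToolParams, // spreading {} adds nothing
}

console.log("tools" in params) // false — the key is absent, not undefined
```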
diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts
index c17b4313ac1..e04301b678c 100644
--- a/src/api/providers/anthropic.ts
+++ b/src/api/providers/anthropic.ts
@@ -24,7 +24,10 @@ import { resolveToolProtocol } from "../../utils/resolveToolProtocol"
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { calculateApiCostAnthropic } from "../../shared/cost"
-import { convertOpenAIToolsToAnthropic } from "../../core/prompts/tools/native-tools/converters"
+import {
+	convertOpenAIToolsToAnthropic,
+	convertOpenAIToolChoiceToAnthropic,
+} from "../../core/prompts/tools/native-tools/converters"

 export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler {
 	private options: ApiHandlerOptions
@@ -85,7 +88,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 		const nativeToolParams = shouldIncludeNativeTools
 			? {
 					tools: convertOpenAIToolsToAnthropic(metadata.tools!),
-					tool_choice: this.convertOpenAIToolChoice(metadata.tool_choice, metadata.parallelToolCalls),
+					tool_choice: convertOpenAIToolChoiceToAnthropic(metadata.tool_choice, metadata.parallelToolCalls),
 				}
 			: {}

@@ -377,49 +380,6 @@
 		}
 	}

-	/**
-	 * Converts OpenAI tool_choice to Anthropic ToolChoice format
-	 * @param toolChoice - OpenAI tool_choice parameter
-	 * @param parallelToolCalls - When true, allows parallel tool calls. When false (default), disables parallel tool calls.
-	 */
-	private convertOpenAIToolChoice(
-		toolChoice: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"],
-		parallelToolCalls?: boolean,
-	): Anthropic.Messages.MessageCreateParams["tool_choice"] | undefined {
-		// Anthropic allows parallel tool calls by default. When parallelToolCalls is false or undefined,
-		// we disable parallel tool use to ensure one tool call at a time.
-		const disableParallelToolUse = !parallelToolCalls
-
-		if (!toolChoice) {
-			// Default to auto with parallel tool use control
-			return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
-		}
-
-		if (typeof toolChoice === "string") {
-			switch (toolChoice) {
-				case "none":
-					return undefined // Anthropic doesn't have "none", just omit tools
-				case "auto":
-					return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
-				case "required":
-					return { type: "any", disable_parallel_tool_use: disableParallelToolUse }
-				default:
-					return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
-			}
-		}
-
-		// Handle object form { type: "function", function: { name: string } }
-		if (typeof toolChoice === "object" && "function" in toolChoice) {
-			return {
-				type: "tool",
-				name: toolChoice.function.name,
-				disable_parallel_tool_use: disableParallelToolUse,
-			}
-		}
-
-		return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
-	}
-
 	async completePrompt(prompt: string) {
 		let { id: model, temperature } = this.getModel()
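The method removed here was a private helper on `AnthropicHandler`; it moves essentially verbatim into `converters.ts` (shown at the end of this diff) so the Vertex handler can share it. For quick reference, the mapping it implements, expressed as calls to the now-shared function (import path assumed relative to `src/api/providers/`):

```typescript
// Behavior of the shared converter, per the switch in converters.ts below.
import { convertOpenAIToolChoiceToAnthropic } from "../../core/prompts/tools/native-tools/converters"

convertOpenAIToolChoiceToAnthropic("auto") // { type: "auto", disable_parallel_tool_use: true }
convertOpenAIToolChoiceToAnthropic("required") // { type: "any", disable_parallel_tool_use: true }
convertOpenAIToolChoiceToAnthropic("none") // undefined — caller should omit tools entirely
convertOpenAIToolChoiceToAnthropic("auto", true) // { type: "auto", disable_parallel_tool_use: false }
convertOpenAIToolChoiceToAnthropic({ type: "function", function: { name: "get_weather" } })
// { type: "tool", name: "get_weather", disable_parallel_tool_use: true }
```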
diff --git a/src/core/prompts/tools/native-tools/__tests__/converters.spec.ts b/src/core/prompts/tools/native-tools/__tests__/converters.spec.ts
index 6665497a6f3..4c1f606754d 100644
--- a/src/core/prompts/tools/native-tools/__tests__/converters.spec.ts
+++ b/src/core/prompts/tools/native-tools/__tests__/converters.spec.ts
@@ -1,7 +1,11 @@
 import { describe, it, expect } from "vitest"
 import type OpenAI from "openai"
 import type Anthropic from "@anthropic-ai/sdk"
-import { convertOpenAIToolToAnthropic, convertOpenAIToolsToAnthropic } from "../converters"
+import {
+	convertOpenAIToolToAnthropic,
+	convertOpenAIToolsToAnthropic,
+	convertOpenAIToolChoiceToAnthropic,
+} from "../converters"

 describe("converters", () => {
 	describe("convertOpenAIToolToAnthropic", () => {
@@ -141,4 +145,68 @@ describe("converters", () => {
 			expect(results).toEqual([])
 		})
 	})
+
+	describe("convertOpenAIToolChoiceToAnthropic", () => {
+		it("should return auto with disabled parallel tool use when toolChoice is undefined", () => {
+			const result = convertOpenAIToolChoiceToAnthropic(undefined)
+			expect(result).toEqual({ type: "auto", disable_parallel_tool_use: true })
+		})
+
+		it("should return auto with enabled parallel tool use when parallelToolCalls is true", () => {
+			const result = convertOpenAIToolChoiceToAnthropic(undefined, true)
+			expect(result).toEqual({ type: "auto", disable_parallel_tool_use: false })
+		})
+
+		it("should return undefined for 'none' tool choice", () => {
+			const result = convertOpenAIToolChoiceToAnthropic("none")
+			expect(result).toBeUndefined()
+		})
+
+		it("should return auto for 'auto' tool choice", () => {
+			const result = convertOpenAIToolChoiceToAnthropic("auto")
+			expect(result).toEqual({ type: "auto", disable_parallel_tool_use: true })
+		})
+
+		it("should return any for 'required' tool choice", () => {
+			const result = convertOpenAIToolChoiceToAnthropic("required")
+			expect(result).toEqual({ type: "any", disable_parallel_tool_use: true })
+		})
+
+		it("should return auto for unknown string tool choice", () => {
+			const result = convertOpenAIToolChoiceToAnthropic("unknown" as any)
+			expect(result).toEqual({ type: "auto", disable_parallel_tool_use: true })
+		})
+
+		it("should convert function object form to tool type", () => {
+			const result = convertOpenAIToolChoiceToAnthropic({
+				type: "function",
+				function: { name: "get_weather" },
+			})
+			expect(result).toEqual({
+				type: "tool",
+				name: "get_weather",
+				disable_parallel_tool_use: true,
+			})
+		})
+
+		it("should handle function object form with parallel tool calls enabled", () => {
+			const result = convertOpenAIToolChoiceToAnthropic(
+				{
+					type: "function",
+					function: { name: "read_file" },
+				},
+				true,
+			)
+			expect(result).toEqual({
+				type: "tool",
+				name: "read_file",
+				disable_parallel_tool_use: false,
+			})
+		})
+
+		it("should return auto for object without function property", () => {
+			const result = convertOpenAIToolChoiceToAnthropic({ type: "something" } as any)
+			expect(result).toEqual({ type: "auto", disable_parallel_tool_use: true })
+		})
+	})
 })
diff --git a/src/core/prompts/tools/native-tools/converters.ts b/src/core/prompts/tools/native-tools/converters.ts
index b2040a0afc9..e124a642ff1 100644
--- a/src/core/prompts/tools/native-tools/converters.ts
+++ b/src/core/prompts/tools/native-tools/converters.ts
@@ -47,3 +47,63 @@ export function convertOpenAIToolToAnthropic(tool: OpenAI.Chat.ChatCompletionToo
 export function convertOpenAIToolsToAnthropic(tools: OpenAI.Chat.ChatCompletionTool[]): Anthropic.Tool[] {
 	return tools.map(convertOpenAIToolToAnthropic)
 }
+
+/**
+ * Converts OpenAI tool_choice to Anthropic ToolChoice format.
+ *
+ * Maps OpenAI's tool_choice parameter to Anthropic's equivalent format:
+ * - "none" → undefined (Anthropic doesn't have "none", just omit tools)
+ * - "auto" → { type: "auto" }
+ * - "required" → { type: "any" }
+ * - { type: "function", function: { name } } → { type: "tool", name }
+ *
+ * @param toolChoice - OpenAI tool_choice parameter
+ * @param parallelToolCalls - When true, allows parallel tool calls. When false (default), disables parallel tool calls.
+ * @returns Anthropic ToolChoice or undefined if tools should be omitted
+ *
+ * @example
+ * ```typescript
+ * convertOpenAIToolChoiceToAnthropic("auto", false)
+ * // Returns: { type: "auto", disable_parallel_tool_use: true }
+ *
+ * convertOpenAIToolChoiceToAnthropic({ type: "function", function: { name: "get_weather" } })
+ * // Returns: { type: "tool", name: "get_weather", disable_parallel_tool_use: true }
+ * ```
+ */
+export function convertOpenAIToolChoiceToAnthropic(
+	toolChoice: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"],
+	parallelToolCalls?: boolean,
+): Anthropic.Messages.MessageCreateParams["tool_choice"] | undefined {
+	// Anthropic allows parallel tool calls by default. When parallelToolCalls is false or undefined,
+	// we disable parallel tool use to ensure one tool call at a time.
+	const disableParallelToolUse = !parallelToolCalls
+
+	if (!toolChoice) {
+		// Default to auto with parallel tool use control
+		return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
+	}
+
+	if (typeof toolChoice === "string") {
+		switch (toolChoice) {
+			case "none":
+				return undefined // Anthropic doesn't have "none", just omit tools
+			case "auto":
+				return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
+			case "required":
+				return { type: "any", disable_parallel_tool_use: disableParallelToolUse }
+			default:
+				return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
+		}
+	}
+
+	// Handle object form { type: "function", function: { name: string } }
+	if (typeof toolChoice === "object" && "function" in toolChoice) {
+		return {
+			type: "tool",
+			name: toolChoice.function.name,
+			disable_parallel_tool_use: disableParallelToolUse,
+		}
+	}
+
+	return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
+}
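Taken together, the two exported converters let any Anthropic-flavored handler translate an OpenAI-style tool setup in one step. A usage sketch with the weather tool from the tests — the surrounding request fields are elided, and the import path assumes a file next to `converters.ts`:

```typescript
// End-to-end sketch: shaping the Anthropic request fragment from
// OpenAI-style inputs using both converters from this diff.
import type OpenAI from "openai"
import { convertOpenAIToolsToAnthropic, convertOpenAIToolChoiceToAnthropic } from "./converters"

const openAITools: OpenAI.Chat.ChatCompletionTool[] = [
	{
		type: "function",
		function: {
			name: "get_weather",
			description: "Get the current weather",
			parameters: {
				type: "object",
				properties: { location: { type: "string" } },
				required: ["location"],
			},
		},
	},
]

const fragment = {
	tools: convertOpenAIToolsToAnthropic(openAITools),
	tool_choice: convertOpenAIToolChoiceToAnthropic("auto"),
}
// fragment.tools[0]    → { name: "get_weather", description: "...", input_schema: { ... } }
// fragment.tool_choice → { type: "auto", disable_parallel_tool_use: true }
```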