diff --git a/packages/types/src/model.ts b/packages/types/src/model.ts index 49e8e73edff..a21b763351b 100644 --- a/packages/types/src/model.ts +++ b/packages/types/src/model.ts @@ -111,6 +111,13 @@ export const modelInfoSchema = z.object({ supportsNativeTools: z.boolean().optional(), // Default tool protocol preferred by this model (if not specified, falls back to capability/provider defaults) defaultToolProtocol: z.enum(["xml", "native"]).optional(), + // Exclude specific native tools from being available (only applies to native protocol) + // These tools will be removed from the set of tools available to the model + excludedTools: z.array(z.string()).optional(), + // Include specific native tools (only applies to native protocol) + // These tools will be added if they belong to an allowed group in the current mode + // Cannot force-add tools from groups the mode doesn't allow + includedTools: z.array(z.string()).optional(), /** * Service tiers with pricing information. * Each tier can have a name (for OpenAI service tiers) and pricing overrides. diff --git a/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts b/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts index 0eb7d506e87..60aaf14b217 100644 --- a/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts +++ b/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts @@ -1,7 +1,8 @@ -import { describe, it, expect } from "vitest" +import { describe, it, expect, beforeEach, afterEach } from "vitest" import type OpenAI from "openai" -import type { ModeConfig } from "@roo-code/types" -import { filterNativeToolsForMode, filterMcpToolsForMode } from "../filter-tools-for-mode" +import type { ModeConfig, ModelInfo } from "@roo-code/types" +import { filterNativeToolsForMode, filterMcpToolsForMode, applyModelToolCustomization } from "../filter-tools-for-mode" +import * as toolsModule from "../../../../shared/tools" describe("filterNativeToolsForMode", () => { const mockNativeTools: OpenAI.Chat.ChatCompletionTool[] = [ @@ -467,4 +468,360 @@ describe("filterMcpToolsForMode", () => { // Should include MCP tools since default mode has mcp group expect(filtered.length).toBeGreaterThan(0) }) + + describe("applyModelToolCustomization", () => { + const codeMode: ModeConfig = { + slug: "code", + name: "Code", + roleDefinition: "Test", + groups: ["read", "edit", "browser", "command", "mcp"] as const, + } + + const architectMode: ModeConfig = { + slug: "architect", + name: "Architect", + roleDefinition: "Test", + groups: ["read", "browser", "mcp"] as const, + } + + it("should return original tools when modelInfo is undefined", () => { + const tools = new Set(["read_file", "write_to_file", "apply_diff"]) + const result = applyModelToolCustomization(tools, codeMode, undefined) + expect(result).toEqual(tools) + }) + + it("should exclude tools specified in excludedTools", () => { + const tools = new Set(["read_file", "write_to_file", "apply_diff"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + excludedTools: ["apply_diff"], + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("write_to_file")).toBe(true) + expect(result.has("apply_diff")).toBe(false) + }) + + it("should exclude multiple tools", () => { + const tools = new Set(["read_file", "write_to_file", "apply_diff", "execute_command"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + excludedTools: ["apply_diff", "write_to_file"], + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("execute_command")).toBe(true) + expect(result.has("write_to_file")).toBe(false) + expect(result.has("apply_diff")).toBe(false) + }) + + it("should include tools only if they belong to allowed groups", () => { + const tools = new Set(["read_file"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + includedTools: ["write_to_file", "apply_diff"], // Both in edit group + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("write_to_file")).toBe(true) + expect(result.has("apply_diff")).toBe(true) + }) + + it("should NOT include tools from groups not allowed by mode", () => { + const tools = new Set(["read_file"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + includedTools: ["write_to_file", "apply_diff"], // Edit group tools + } + // Architect mode doesn't have edit group + const result = applyModelToolCustomization(tools, architectMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("write_to_file")).toBe(false) // Not in allowed groups + expect(result.has("apply_diff")).toBe(false) // Not in allowed groups + }) + + it("should apply both exclude and include operations", () => { + const tools = new Set(["read_file", "write_to_file", "apply_diff"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + excludedTools: ["apply_diff"], + includedTools: ["insert_content"], // Another edit tool + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("write_to_file")).toBe(true) + expect(result.has("apply_diff")).toBe(false) // Excluded + expect(result.has("insert_content")).toBe(true) // Included + }) + + it("should handle empty excludedTools and includedTools arrays", () => { + const tools = new Set(["read_file", "write_to_file"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + excludedTools: [], + includedTools: [], + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + expect(result).toEqual(tools) + }) + + it("should ignore excluded tools that are not in the original set", () => { + const tools = new Set(["read_file", "write_to_file"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + excludedTools: ["apply_diff", "nonexistent_tool"], + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("write_to_file")).toBe(true) + expect(result.size).toBe(2) + }) + + it("should NOT include customTools by default", () => { + const tools = new Set(["read_file", "write_to_file"]) + // Assume 'edit' group has a customTool defined in TOOL_GROUPS + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + // No includedTools specified + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + // customTools should not be in the result unless explicitly included + expect(result.has("read_file")).toBe(true) + expect(result.has("write_to_file")).toBe(true) + }) + + it("should NOT include tools that are not in any TOOL_GROUPS", () => { + const tools = new Set(["read_file"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + includedTools: ["my_custom_tool"], // Not in any tool group + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("my_custom_tool")).toBe(false) + }) + + it("should NOT include undefined tools even with allowed groups", () => { + const tools = new Set(["read_file"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + includedTools: ["custom_edit_tool"], // Not in any tool group + } + // Even though architect mode has read group, undefined tools are not added + const result = applyModelToolCustomization(tools, architectMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("custom_edit_tool")).toBe(false) + }) + + describe("with customTools defined in TOOL_GROUPS", () => { + const originalToolGroups = { ...toolsModule.TOOL_GROUPS } + + beforeEach(() => { + // Add a customTool to the edit group + ;(toolsModule.TOOL_GROUPS as any).edit = { + ...originalToolGroups.edit, + customTools: ["special_edit_tool"], + } + }) + + afterEach(() => { + // Restore original TOOL_GROUPS + ;(toolsModule.TOOL_GROUPS as any).edit = originalToolGroups.edit + }) + + it("should include customTools when explicitly specified in includedTools", () => { + const tools = new Set(["read_file", "write_to_file"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + includedTools: ["special_edit_tool"], // customTool from edit group + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("write_to_file")).toBe(true) + expect(result.has("special_edit_tool")).toBe(true) // customTool should be included + }) + + it("should NOT include customTools when not specified in includedTools", () => { + const tools = new Set(["read_file", "write_to_file"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + // No includedTools specified + } + const result = applyModelToolCustomization(tools, codeMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("write_to_file")).toBe(true) + expect(result.has("special_edit_tool")).toBe(false) // customTool should NOT be included by default + }) + + it("should NOT include customTools from groups not allowed by mode", () => { + const tools = new Set(["read_file"]) + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + includedTools: ["special_edit_tool"], // customTool from edit group + } + // Architect mode doesn't have edit group + const result = applyModelToolCustomization(tools, architectMode, modelInfo) + expect(result.has("read_file")).toBe(true) + expect(result.has("special_edit_tool")).toBe(false) // customTool should NOT be included + }) + }) + }) + + describe("filterNativeToolsForMode with model customization", () => { + const mockNativeTools: OpenAI.Chat.ChatCompletionTool[] = [ + { + type: "function", + function: { + name: "read_file", + description: "Read files", + parameters: {}, + }, + }, + { + type: "function", + function: { + name: "write_to_file", + description: "Write files", + parameters: {}, + }, + }, + { + type: "function", + function: { + name: "apply_diff", + description: "Apply diff", + parameters: {}, + }, + }, + { + type: "function", + function: { + name: "insert_content", + description: "Insert content", + parameters: {}, + }, + }, + { + type: "function", + function: { + name: "execute_command", + description: "Execute command", + parameters: {}, + }, + }, + ] + + it("should exclude tools when model specifies excludedTools", () => { + const codeMode: ModeConfig = { + slug: "code", + name: "Code", + roleDefinition: "Test", + groups: ["read", "edit", "browser", "command", "mcp"] as const, + } + + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + excludedTools: ["apply_diff"], + } + + const filtered = filterNativeToolsForMode(mockNativeTools, "code", [codeMode], {}, undefined, { + modelInfo, + }) + + const toolNames = filtered.map((t) => ("function" in t ? t.function.name : "")) + + expect(toolNames).toContain("read_file") + expect(toolNames).toContain("write_to_file") + expect(toolNames).toContain("insert_content") + expect(toolNames).not.toContain("apply_diff") // Excluded by model + }) + + it("should include tools when model specifies includedTools from allowed groups", () => { + const modeWithOnlyRead: ModeConfig = { + slug: "limited", + name: "Limited", + roleDefinition: "Test", + groups: ["read", "edit"] as const, + } + + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + includedTools: ["insert_content"], // Edit group tool + } + + const filtered = filterNativeToolsForMode(mockNativeTools, "limited", [modeWithOnlyRead], {}, undefined, { + modelInfo, + }) + + const toolNames = filtered.map((t) => ("function" in t ? t.function.name : "")) + + expect(toolNames).toContain("insert_content") // Included by model + }) + + it("should NOT include tools from groups not allowed by mode", () => { + const architectMode: ModeConfig = { + slug: "architect", + name: "Architect", + roleDefinition: "Test", + groups: ["read", "browser"] as const, // No edit group + } + + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + includedTools: ["write_to_file", "apply_diff"], // Edit group tools + } + + const filtered = filterNativeToolsForMode(mockNativeTools, "architect", [architectMode], {}, undefined, { + modelInfo, + }) + + const toolNames = filtered.map((t) => ("function" in t ? t.function.name : "")) + + expect(toolNames).toContain("read_file") + expect(toolNames).not.toContain("write_to_file") // Not in mode's allowed groups + expect(toolNames).not.toContain("apply_diff") // Not in mode's allowed groups + }) + + it("should combine excludedTools and includedTools", () => { + const codeMode: ModeConfig = { + slug: "code", + name: "Code", + roleDefinition: "Test", + groups: ["read", "edit", "browser", "command", "mcp"] as const, + } + + const modelInfo: ModelInfo = { + contextWindow: 100000, + supportsPromptCache: false, + excludedTools: ["apply_diff"], + includedTools: ["insert_content"], + } + + const filtered = filterNativeToolsForMode(mockNativeTools, "code", [codeMode], {}, undefined, { + modelInfo, + }) + + const toolNames = filtered.map((t) => ("function" in t ? t.function.name : "")) + + expect(toolNames).toContain("write_to_file") + expect(toolNames).toContain("insert_content") // Included + expect(toolNames).not.toContain("apply_diff") // Excluded + }) + }) }) diff --git a/src/core/prompts/tools/filter-tools-for-mode.ts b/src/core/prompts/tools/filter-tools-for-mode.ts index c386240ddd1..3eec44c643f 100644 --- a/src/core/prompts/tools/filter-tools-for-mode.ts +++ b/src/core/prompts/tools/filter-tools-for-mode.ts @@ -1,5 +1,5 @@ import type OpenAI from "openai" -import type { ModeConfig, ToolName, ToolGroup } from "@roo-code/types" +import type { ModeConfig, ToolName, ToolGroup, ModelInfo } from "@roo-code/types" import { getModeBySlug, getToolsForMode, isToolAllowedForMode } from "../../../shared/modes" import { TOOL_GROUPS, ALWAYS_AVAILABLE_TOOLS } from "../../../shared/tools" import { defaultModeSlug } from "../../../shared/modes" @@ -7,7 +7,72 @@ import type { CodeIndexManager } from "../../../services/code-index/manager" import type { McpHub } from "../../../services/mcp/McpHub" /** - * Filters native tools based on mode restrictions. + * Apply model-specific tool customization to a set of allowed tools. + * + * This function filters tools based on model configuration: + * 1. Removes tools specified in modelInfo.excludedTools + * 2. Adds tools from modelInfo.includedTools (only if they belong to allowed groups) + * + * @param allowedTools - Set of tools already allowed by mode configuration + * @param modeConfig - Current mode configuration to check tool groups + * @param modelInfo - Model configuration with tool customization + * @returns Modified set of tools after applying model customization + */ +export function applyModelToolCustomization( + allowedTools: Set, + modeConfig: ModeConfig, + modelInfo?: ModelInfo, +): Set { + if (!modelInfo) { + return allowedTools + } + + const result = new Set(allowedTools) + + // Apply excluded tools (remove from allowed set) + if (modelInfo.excludedTools && modelInfo.excludedTools.length > 0) { + modelInfo.excludedTools.forEach((tool) => { + result.delete(tool) + }) + } + + // Apply included tools (add to allowed set, but only if they belong to an allowed group) + if (modelInfo.includedTools && modelInfo.includedTools.length > 0) { + // Build a map of tool -> group for all tools in TOOL_GROUPS (including customTools) + const toolToGroup = new Map() + for (const [groupName, groupConfig] of Object.entries(TOOL_GROUPS)) { + // Add regular tools + groupConfig.tools.forEach((tool) => { + toolToGroup.set(tool, groupName as ToolGroup) + }) + // Add customTools (opt-in only tools) + if (groupConfig.customTools) { + groupConfig.customTools.forEach((tool) => { + toolToGroup.set(tool, groupName as ToolGroup) + }) + } + } + + // Get the list of allowed groups for this mode + const allowedGroups = new Set( + modeConfig.groups.map((groupEntry) => (Array.isArray(groupEntry) ? groupEntry[0] : groupEntry)), + ) + + // Add included tools only if they belong to an allowed group + // This includes both regular tools and customTools + modelInfo.includedTools.forEach((tool) => { + const toolGroup = toolToGroup.get(tool) + if (toolGroup && allowedGroups.has(toolGroup)) { + result.add(tool) + } + }) + } + + return result +} + +/** + * Filters native tools based on mode restrictions and model customization. * This ensures native tools are filtered the same way XML tools are filtered in the system prompt. * * @param nativeTools - Array of all available native tools @@ -15,7 +80,7 @@ import type { McpHub } from "../../../services/mcp/McpHub" * @param customModes - Custom mode configurations * @param experiments - Experiment flags * @param codeIndexManager - Code index manager for codebase_search feature check - * @param settings - Additional settings for tool filtering + * @param settings - Additional settings for tool filtering (includes modelInfo for model-specific customization) * @param mcpHub - MCP hub for checking available resources * @returns Filtered array of tools allowed for the mode */ @@ -43,7 +108,7 @@ export function filterNativeToolsForMode( const allToolsForMode = getToolsForMode(modeConfig.groups) // Filter to only tools that pass permission checks - const allowedToolNames = new Set( + let allowedToolNames = new Set( allToolsForMode.filter((tool) => isToolAllowedForMode( tool as ToolName, @@ -56,6 +121,10 @@ export function filterNativeToolsForMode( ), ) + // Apply model-specific tool customization + const modelInfo = settings?.modelInfo as ModelInfo | undefined + allowedToolNames = applyModelToolCustomization(allowedToolNames, modeConfig, modelInfo) + // Conditionally exclude codebase_search if feature is disabled or not configured if ( !codeIndexManager || diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 04062797b50..74e55702fd1 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -3479,6 +3479,7 @@ export class Task extends EventEmitter implements TaskLike { apiConfiguration, maxReadFileLine: state?.maxReadFileLine ?? -1, browserToolEnabled: state?.browserToolEnabled ?? true, + modelInfo, }) } diff --git a/src/core/task/build-tools.ts b/src/core/task/build-tools.ts index a9f02005f09..4586e4b546b 100644 --- a/src/core/task/build-tools.ts +++ b/src/core/task/build-tools.ts @@ -1,5 +1,5 @@ import type OpenAI from "openai" -import type { ProviderSettings, ModeConfig } from "@roo-code/types" +import type { ProviderSettings, ModeConfig, ModelInfo } from "@roo-code/types" import type { ClineProvider } from "../webview/ClineProvider" import { getNativeTools, getMcpServerTools } from "../prompts/tools/native-tools" import { filterNativeToolsForMode, filterMcpToolsForMode } from "../prompts/tools/filter-tools-for-mode" @@ -13,6 +13,7 @@ interface BuildToolsOptions { apiConfiguration: ProviderSettings | undefined maxReadFileLine: number browserToolEnabled: boolean + modelInfo?: ModelInfo } /** @@ -23,8 +24,17 @@ interface BuildToolsOptions { * @returns Array of filtered native and MCP tools */ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise { - const { provider, cwd, mode, customModes, experiments, apiConfiguration, maxReadFileLine, browserToolEnabled } = - options + const { + provider, + cwd, + mode, + customModes, + experiments, + apiConfiguration, + maxReadFileLine, + browserToolEnabled, + modelInfo, + } = options const mcpHub = provider.getMcpHub() @@ -36,6 +46,7 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise const filterSettings = { todoListEnabled: apiConfiguration?.todoListEnabled ?? true, browserToolEnabled: browserToolEnabled ?? true, + modelInfo, } // Determine if partial reads are enabled based on maxReadFileLine setting diff --git a/src/shared/modes.ts b/src/shared/modes.ts index 484e79fd445..bbde43fd0ad 100644 --- a/src/shared/modes.ts +++ b/src/shared/modes.ts @@ -47,7 +47,7 @@ export function doesFileMatchRegex(filePath: string, pattern: string): boolean { export function getToolsForMode(groups: readonly GroupEntry[]): string[] { const tools = new Set() - // Add tools from each group + // Add tools from each group (excluding customTools which are opt-in only) groups.forEach((group) => { const groupName = getGroupName(group) const groupConfig = TOOL_GROUPS[groupName] diff --git a/src/shared/tools.ts b/src/shared/tools.ts index 6aea5fd3114..7246931e63a 100644 --- a/src/shared/tools.ts +++ b/src/shared/tools.ts @@ -237,6 +237,7 @@ export interface GenerateImageToolUse extends ToolUse<"generate_image"> { export type ToolGroupConfig = { tools: readonly string[] alwaysAvailable?: boolean // Whether this group is always available and shouldn't show in prompts view + customTools?: readonly string[] // Opt-in only tools - only available when explicitly included via model's includedTools } export const TOOL_DISPLAY_NAMES: Record = {