diff --git a/packages/types/src/tool.ts b/packages/types/src/tool.ts
index be3f49c40f2..76e03f8c803 100644
--- a/packages/types/src/tool.ts
+++ b/packages/types/src/tool.ts
@@ -37,6 +37,7 @@ export const toolNames = [
 	"update_todo_list",
 	"run_slash_command",
 	"generate_image",
+	"custom_tool",
 ] as const
 
 export const toolNamesSchema = z.enum(toolNames)
diff --git a/src/core/assistant-message/__tests__/presentAssistantMessage-custom-tool.spec.ts b/src/core/assistant-message/__tests__/presentAssistantMessage-custom-tool.spec.ts
new file mode 100644
index 00000000000..6ad8c58282c
--- /dev/null
+++ b/src/core/assistant-message/__tests__/presentAssistantMessage-custom-tool.spec.ts
@@ -0,0 +1,345 @@
+// npx vitest src/core/assistant-message/__tests__/presentAssistantMessage-custom-tool.spec.ts
+
+import { describe, it, expect, beforeEach, vi } from "vitest"
+import { presentAssistantMessage } from "../presentAssistantMessage"
+
+// Mock dependencies
+vi.mock("../../task/Task")
+vi.mock("../../tools/validateToolUse", () => ({
+	validateToolUse: vi.fn(),
+}))
+
+// Mock custom tool registry - must be done inline without external variable references
+vi.mock("@roo-code/core", () => ({
+	customToolRegistry: {
+		has: vi.fn(),
+		get: vi.fn(),
+	},
+}))
+
+vi.mock("@roo-code/telemetry", () => ({
+	TelemetryService: {
+		instance: {
+			captureToolUsage: vi.fn(),
+			captureConsecutiveMistakeError: vi.fn(),
+		},
+	},
+}))
+
+import { TelemetryService } from "@roo-code/telemetry"
+import { customToolRegistry } from "@roo-code/core"
+
+describe("presentAssistantMessage - Custom Tool Recording", () => {
+	let mockTask: any
+
+	beforeEach(() => {
+		// Reset all mocks
+		vi.clearAllMocks()
+
+		// Create a mock Task with minimal properties needed for testing
+		mockTask = {
+			taskId: "test-task-id",
+			instanceId: "test-instance",
+			abort: false,
+			presentAssistantMessageLocked: false,
+			presentAssistantMessageHasPendingUpdates: false,
+			currentStreamingContentIndex: 0,
+			assistantMessageContent: [],
+			userMessageContent: [],
+			didCompleteReadingStream: false,
+			didRejectTool: false,
+			didAlreadyUseTool: false,
+			diffEnabled: false,
+			consecutiveMistakeCount: 0,
+			clineMessages: [],
+			api: {
+				getModel: () => ({ id: "test-model", info: {} }),
+			},
+			browserSession: {
+				closeBrowser: vi.fn().mockResolvedValue(undefined),
+			},
+			recordToolUsage: vi.fn(),
+			recordToolError: vi.fn(),
+			toolRepetitionDetector: {
+				check: vi.fn().mockReturnValue({ allowExecution: true }),
+			},
+			providerRef: {
+				deref: () => ({
+					getState: vi.fn().mockResolvedValue({
+						mode: "code",
+						customModes: [],
+						experiments: {
+							customTools: true, // Enable by default
+						},
+					}),
+				}),
+			},
+			say: vi.fn().mockResolvedValue(undefined),
+			ask: vi.fn().mockResolvedValue({ response: "yesButtonClicked" }),
+		}
+	})
+
+	describe("Custom tool usage recording", () => {
+		it("should record custom tool usage as 'custom_tool' when experiment is enabled", async () => {
+			const toolCallId = "tool_call_custom_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "my_custom_tool",
+					params: { value: "test" },
+					partial: false,
+				},
+			]
+
+			// Mock customToolRegistry to recognize this as a custom tool
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+			vi.mocked(customToolRegistry.get).mockReturnValue({
+				name: "my_custom_tool",
+				description: "A custom tool",
+				execute: vi.fn().mockResolvedValue("Custom tool result"),
+			})
+
+			await presentAssistantMessage(mockTask)
+
+			// Should record as "custom_tool", not "my_custom_tool"
+			expect(mockTask.recordToolUsage).toHaveBeenCalledWith("custom_tool")
+			expect(TelemetryService.instance.captureToolUsage).toHaveBeenCalledWith(
+				mockTask.taskId,
+				"custom_tool",
+				"native",
+			)
+		})
+
+		it("should record custom tool usage as 'custom_tool' in XML protocol", async () => {
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					// No ID = XML protocol
+					name: "my_custom_tool",
+					params: { value: "test" },
+					partial: false,
+				},
+			]
+
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+			vi.mocked(customToolRegistry.get).mockReturnValue({
+				name: "my_custom_tool",
+				description: "A custom tool",
+				execute: vi.fn().mockResolvedValue("Custom tool result"),
+			})
+
+			await presentAssistantMessage(mockTask)
+
+			expect(mockTask.recordToolUsage).toHaveBeenCalledWith("custom_tool")
+			expect(TelemetryService.instance.captureToolUsage).toHaveBeenCalledWith(
+				mockTask.taskId,
+				"custom_tool",
+				"xml",
+			)
+		})
+	})
+
+	describe("Custom tool error recording", () => {
+		it("should record custom tool error as 'custom_tool'", async () => {
+			const toolCallId = "tool_call_custom_error_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "failing_custom_tool",
+					params: {},
+					partial: false,
+				},
+			]
+
+			// Mock customToolRegistry with a tool that throws an error
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+			vi.mocked(customToolRegistry.get).mockReturnValue({
+				name: "failing_custom_tool",
+				description: "A failing custom tool",
+				execute: vi.fn().mockRejectedValue(new Error("Custom tool execution failed")),
+			})
+
+			await presentAssistantMessage(mockTask)
+
+			// Should record error as "custom_tool", not "failing_custom_tool"
+			expect(mockTask.recordToolError).toHaveBeenCalledWith("custom_tool", "Custom tool execution failed")
+			expect(mockTask.consecutiveMistakeCount).toBe(1)
+		})
+	})
+
+	describe("Regular tool recording", () => {
+		it("should record regular tool usage with actual tool name", async () => {
+			const toolCallId = "tool_call_read_file_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "read_file",
+					params: { path: "test.txt" },
+					partial: false,
+				},
+			]
+
+			// read_file is not a custom tool
+			vi.mocked(customToolRegistry.has).mockReturnValue(false)
+
+			await presentAssistantMessage(mockTask)
+
+			// Should record as "read_file", not "custom_tool"
+			expect(mockTask.recordToolUsage).toHaveBeenCalledWith("read_file")
+			expect(TelemetryService.instance.captureToolUsage).toHaveBeenCalledWith(
+				mockTask.taskId,
+				"read_file",
+				"native",
+			)
+		})
+
+		it("should record MCP tool usage as 'use_mcp_tool' (not custom_tool)", async () => {
+			const toolCallId = "tool_call_mcp_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "use_mcp_tool",
+					params: {
+						server_name: "test-server",
+						tool_name: "test-tool",
+						arguments: "{}",
+					},
+					partial: false,
+				},
+			]
+
+			vi.mocked(customToolRegistry.has).mockReturnValue(false)
+
+			// Mock MCP hub for use_mcp_tool
+			mockTask.providerRef = {
+				deref: () => ({
+					getState: vi.fn().mockResolvedValue({
+						mode: "code",
+						customModes: [],
+						experiments: {
+							customTools: true,
+						},
+					}),
+					getMcpHub: () => ({
+						findServerNameBySanitizedName: () => "test-server",
+						executeToolCall: vi.fn().mockResolvedValue({ content: [{ type: "text", text: "result" }] }),
+					}),
+				}),
+			}
+
+			await presentAssistantMessage(mockTask)
+
+			// Should record as "use_mcp_tool", not "custom_tool"
+			expect(mockTask.recordToolUsage).toHaveBeenCalledWith("use_mcp_tool")
+			expect(TelemetryService.instance.captureToolUsage).toHaveBeenCalledWith(
+				mockTask.taskId,
+				"use_mcp_tool",
+				"native",
+			)
+		})
+	})
+
+	describe("Custom tool experiment gate", () => {
+		it("should treat custom tool as unknown when experiment is disabled", async () => {
+			const toolCallId = "tool_call_disabled_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "my_custom_tool",
+					params: {},
+					partial: false,
+				},
+			]
+
+			// Mock provider state with customTools experiment DISABLED
+			mockTask.providerRef = {
+				deref: () => ({
+					getState: vi.fn().mockResolvedValue({
+						mode: "code",
+						customModes: [],
+						experiments: {
+							customTools: false, // Disabled
+						},
+					}),
+				}),
+			}
+
+			// Even if registry recognizes it, experiment gate should prevent execution
+			const execute = vi.fn().mockResolvedValue("Should not execute")
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+			vi.mocked(customToolRegistry.get).mockReturnValue({
+				name: "my_custom_tool",
+				description: "A custom tool",
+				execute,
+			})
+
+			await presentAssistantMessage(mockTask)
+
+			// Should be treated as unknown tool (not executed)
+			expect(mockTask.say).toHaveBeenCalledWith("error", "unknownToolError")
+			expect(mockTask.consecutiveMistakeCount).toBe(1)
+
+			// Custom tool should NOT have been executed. Assert on the hoisted mock directly so the
+			// check always runs, even when the gate prevents customToolRegistry.get() from being called.
+			expect(execute).not.toHaveBeenCalled()
+		})
+
+		it("should not call customToolRegistry.has() when experiment is disabled", async () => {
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: "tool_call_123",
+					name: "some_tool",
+					params: {},
+					partial: false,
+				},
+			]
+
+			// Disable experiment
+			mockTask.providerRef = {
+				deref: () => ({
+					getState: vi.fn().mockResolvedValue({
+						mode: "code",
+						customModes: [],
+						experiments: {
+							customTools: false,
+						},
+					}),
+				}),
+			}
+
+			await presentAssistantMessage(mockTask)
+
+			// When experiment is off, shouldn't even check the registry
+			// (Code checks stateExperiments?.customTools before calling has())
+			expect(customToolRegistry.has).not.toHaveBeenCalled()
+		})
+	})
+
+	describe("Partial blocks", () => {
+		it("should not record usage for partial custom tool blocks", async () => {
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: "tool_call_partial_123",
+					name: "my_custom_tool",
+					params: { value: "test" },
+					partial: true, // Still streaming
+				},
+			]
+
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+
+			await presentAssistantMessage(mockTask)
+
+			// Should not record usage for partial blocks
+			expect(mockTask.recordToolUsage).not.toHaveBeenCalled()
+			expect(TelemetryService.instance.captureToolUsage).not.toHaveBeenCalled()
+		})
+	})
+})
diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts
index f604e955e43..9b1a5fcffb1 100644
--- a/src/core/assistant-message/presentAssistantMessage.ts
+++ b/src/core/assistant-message/presentAssistantMessage.ts
@@ -695,8 +695,11 @@ export async function presentAssistantMessage(cline: Task) {
 			}
 
 			if (!block.partial) {
-				cline.recordToolUsage(block.name)
-				TelemetryService.instance.captureToolUsage(cline.taskId, block.name, toolProtocol)
+				// Check if this is a custom tool - if so, record as "custom_tool" (like MCP tools)
+				const isCustomTool = stateExperiments?.customTools && customToolRegistry.has(block.name)
+				const recordName = isCustomTool ? "custom_tool" : block.name
+				cline.recordToolUsage(recordName)
+				TelemetryService.instance.captureToolUsage(cline.taskId, recordName, toolProtocol)
 			}
 
 			// Validate tool use before execution - ONLY for complete (non-partial) blocks.
@@ -1091,6 +1094,10 @@ export async function presentAssistantMessage(cline: Task) {
 							cline.consecutiveMistakeCount = 0
 						} catch (executionError: any) {
 							cline.consecutiveMistakeCount++
+							// Record custom tool error with static name. A custom tool may throw a
+							// non-Error value (string, null, ...), which has no usable `.message`.
+							const errorMessage = executionError?.message ?? String(executionError)
+							cline.recordToolError("custom_tool", errorMessage)
 							await handleError(`executing custom tool "${block.name}"`, executionError)
 						}
 
diff --git a/src/shared/tools.ts b/src/shared/tools.ts
index f2b4ec3544e..f893a3d332e 100644
--- a/src/shared/tools.ts
+++ b/src/shared/tools.ts
@@ -266,6 +266,7 @@ export const TOOL_DISPLAY_NAMES: Record<ToolName, string> = {
 	update_todo_list: "update todo list",
 	run_slash_command: "run slash command",
 	generate_image: "generate images",
+	custom_tool: "use custom tools",
 } as const
 
 // Define available tool groups.