diff --git a/packages/types/src/providers/gemini.ts b/packages/types/src/providers/gemini.ts index 3f69dfdb59c..17aa16db272 100644 --- a/packages/types/src/providers/gemini.ts +++ b/packages/types/src/providers/gemini.ts @@ -16,6 +16,7 @@ export const geminiModels = { supportsReasoningEffort: ["low", "high"], reasoningEffort: "low", includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], supportsTemperature: true, defaultTemperature: 1, inputPrice: 4.0, @@ -43,6 +44,7 @@ export const geminiModels = { supportsReasoningEffort: ["minimal", "low", "medium", "high"], reasoningEffort: "medium", includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], supportsTemperature: true, defaultTemperature: 1, inputPrice: 0.3, @@ -59,6 +61,7 @@ export const geminiModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 2.5, // This is the pricing for prompts above 200k tokens. outputPrice: 15, cacheReadsPrice: 0.625, @@ -89,6 +92,7 @@ export const geminiModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 2.5, // This is the pricing for prompts above 200k tokens. outputPrice: 15, cacheReadsPrice: 0.625, @@ -118,6 +122,7 @@ export const geminiModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 2.5, // This is the pricing for prompts above 200k tokens. outputPrice: 15, cacheReadsPrice: 0.625, @@ -145,6 +150,7 @@ export const geminiModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 2.5, // This is the pricing for prompts above 200k tokens. outputPrice: 15, cacheReadsPrice: 0.625, @@ -176,6 +182,7 @@ export const geminiModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.3, outputPrice: 2.5, cacheReadsPrice: 0.075, @@ -191,6 +198,7 @@ export const geminiModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.3, outputPrice: 2.5, cacheReadsPrice: 0.075, @@ -206,6 +214,7 @@ export const geminiModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.3, outputPrice: 2.5, cacheReadsPrice: 0.075, @@ -223,6 +232,7 @@ export const geminiModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.1, outputPrice: 0.4, cacheReadsPrice: 0.025, @@ -238,6 +248,7 @@ export const geminiModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.1, outputPrice: 0.4, cacheReadsPrice: 0.025, diff --git a/packages/types/src/providers/vertex.ts b/packages/types/src/providers/vertex.ts index db010b6c682..e7a75c06a92 100644 --- a/packages/types/src/providers/vertex.ts +++ b/packages/types/src/providers/vertex.ts @@ -16,6 +16,7 @@ export const vertexModels = { supportsReasoningEffort: ["low", "high"], reasoningEffort: "low", includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], supportsTemperature: true, defaultTemperature: 1, inputPrice: 4.0, @@ -43,6 +44,7 @@ export const vertexModels = { supportsReasoningEffort: ["minimal", "low", "medium", "high"], reasoningEffort: "medium", includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], supportsTemperature: true, defaultTemperature: 1, inputPrice: 0.3, @@ -58,6 +60,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.15, outputPrice: 3.5, maxThinkingTokens: 24_576, @@ -72,6 +75,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.15, outputPrice: 0.6, }, @@ -83,6 +87,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.3, outputPrice: 2.5, cacheReadsPrice: 0.075, @@ -98,6 +103,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: false, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.15, outputPrice: 3.5, maxThinkingTokens: 24_576, @@ -112,6 +118,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: false, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.15, outputPrice: 0.6, }, @@ -123,6 +130,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 2.5, outputPrice: 15, }, @@ -134,6 +142,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 2.5, outputPrice: 15, }, @@ -145,6 +154,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 2.5, outputPrice: 15, maxThinkingTokens: 32_768, @@ -158,6 +168,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 2.5, outputPrice: 15, maxThinkingTokens: 32_768, @@ -186,6 +197,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: false, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0, outputPrice: 0, }, @@ -197,6 +209,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: false, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0, outputPrice: 0, }, @@ -208,6 +221,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.15, outputPrice: 0.6, }, @@ -219,6 +233,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: false, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.075, outputPrice: 0.3, }, @@ -230,6 +245,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: false, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0, outputPrice: 0, }, @@ -241,6 +257,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.075, outputPrice: 0.3, }, @@ -252,6 +269,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: false, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 1.25, outputPrice: 5, }, @@ -400,6 +418,7 @@ export const vertexModels = { defaultToolProtocol: "native", supportsPromptCache: true, includedTools: ["write_file", "edit_file"], + excludedTools: ["apply_diff"], inputPrice: 0.1, outputPrice: 0.4, cacheReadsPrice: 0.025, diff --git a/packages/types/src/tool.ts b/packages/types/src/tool.ts index 0a9afb3162b..658ebf4450b 100644 --- a/packages/types/src/tool.ts +++ b/packages/types/src/tool.ts @@ -21,6 +21,7 @@ export const toolNames = [ "apply_diff", "search_and_replace", "search_replace", + "edit_file", "apply_patch", "search_files", "list_files", diff --git a/src/api/providers/utils/router-tool-preferences.ts b/src/api/providers/utils/router-tool-preferences.ts index 40f8518e3c7..bb5ece3b96b 100644 --- a/src/api/providers/utils/router-tool-preferences.ts +++ b/src/api/providers/utils/router-tool-preferences.ts @@ -32,6 +32,7 @@ export function applyRouterToolPreferences(modelId: string, info: ModelInfo): Mo if (modelId.includes("gemini")) { result = { ...result, + excludedTools: [...new Set([...(result.excludedTools || []), "apply_diff"])], includedTools: [...new Set([...(result.includedTools || []), "write_file", "edit_file"])], } } diff --git a/src/core/assistant-message/NativeToolCallParser.ts b/src/core/assistant-message/NativeToolCallParser.ts index 250afdc3890..89bc1a9c340 100644 --- a/src/core/assistant-message/NativeToolCallParser.ts +++ b/src/core/assistant-message/NativeToolCallParser.ts @@ -525,6 +525,21 @@ export class NativeToolCallParser { } break + case "edit_file": + if ( + partialArgs.file_path !== undefined || + partialArgs.old_string !== undefined || + partialArgs.new_string !== undefined + ) { + nativeArgs = { + file_path: partialArgs.file_path, + old_string: partialArgs.old_string, + new_string: partialArgs.new_string, + expected_replacements: partialArgs.expected_replacements, + } + } + break + default: break } @@ -562,7 +577,7 @@ export class NativeToolCallParser { return this.parseDynamicMcpTool(toolCall) } - // Resolve tool alias to canonical name (e.g., "edit_file" -> "apply_diff", "temp_edit_file" -> "search_and_replace") + // Resolve tool alias to canonical name const resolvedName = resolveToolAlias(toolCall.name as string) as TName // Validate tool name (after alias resolution) @@ -785,6 +800,21 @@ export class NativeToolCallParser { } break + case "edit_file": + if ( + args.file_path !== undefined && + args.old_string !== undefined && + args.new_string !== undefined + ) { + nativeArgs = { + file_path: args.file_path, + old_string: args.old_string, + new_string: args.new_string, + expected_replacements: args.expected_replacements, + } as NativeArgsFor + } + break + default: break } diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts index 2e8b791b349..52886ed76f6 100644 --- a/src/core/assistant-message/presentAssistantMessage.ts +++ b/src/core/assistant-message/presentAssistantMessage.ts @@ -23,6 +23,7 @@ import { writeToFileTool } from "../tools/WriteToFileTool" import { applyDiffTool } from "../tools/MultiApplyDiffTool" import { searchAndReplaceTool } from "../tools/SearchAndReplaceTool" import { searchReplaceTool } from "../tools/SearchReplaceTool" +import { editFileTool } from "../tools/EditFileTool" import { applyPatchTool } from "../tools/ApplyPatchTool" import { searchFilesTool } from "../tools/SearchFilesTool" import { browserActionTool } from "../tools/BrowserActionTool" @@ -403,6 +404,8 @@ export async function presentAssistantMessage(cline: Task) { return `[${block.name} for '${block.params.path}']` case "search_replace": return `[${block.name} for '${block.params.file_path}']` + case "edit_file": + return `[${block.name} for '${block.params.file_path}']` case "apply_patch": return `[${block.name}]` case "list_files": @@ -872,6 +875,16 @@ export async function presentAssistantMessage(cline: Task) { toolProtocol, }) break + case "edit_file": + await checkpointSaveAndMark(cline) + await editFileTool.handle(cline, block as ToolUse<"edit_file">, { + askApproval, + handleError, + pushToolResult, + removeClosingTag, + toolProtocol, + }) + break case "apply_patch": await checkpointSaveAndMark(cline) await applyPatchTool.handle(cline, block as ToolUse<"apply_patch">, { diff --git a/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts b/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts index d189b999150..50db6984f22 100644 --- a/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts +++ b/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts @@ -722,6 +722,14 @@ describe("filterMcpToolsForMode", () => { parameters: {}, }, }, + { + type: "function", + function: { + name: "edit_file", + description: "Edit file", + parameters: {}, + }, + }, ] it("should exclude tools when model specifies excludedTools", () => { @@ -823,7 +831,7 @@ describe("filterMcpToolsForMode", () => { expect(toolNames).not.toContain("apply_diff") // Excluded }) - it("should rename tools to alias names when model includes aliases", () => { + it("should honor included aliases while respecting exclusions", () => { const codeMode: ModeConfig = { slug: "code", name: "Code", @@ -834,6 +842,7 @@ describe("filterMcpToolsForMode", () => { const modelInfo: ModelInfo = { contextWindow: 100000, supportsPromptCache: false, + excludedTools: ["apply_diff"], includedTools: ["edit_file", "write_file"], } diff --git a/src/core/prompts/tools/native-tools/edit_file.ts b/src/core/prompts/tools/native-tools/edit_file.ts new file mode 100644 index 00000000000..ed6a59f3e1a --- /dev/null +++ b/src/core/prompts/tools/native-tools/edit_file.ts @@ -0,0 +1,70 @@ +import type OpenAI from "openai" + +const EDIT_FILE_DESCRIPTION = `Use this tool to replace text in an existing file, or create a new file. + +This tool performs literal string replacement with support for multiple occurrences. + +USAGE PATTERNS: + +1. MODIFY EXISTING FILE (default): + - Provide file_path, old_string (text to find), and new_string (replacement) + - By default, expects exactly 1 occurrence of old_string + - Use expected_replacements to replace multiple occurrences + +2. CREATE NEW FILE: + - Set old_string to empty string "" + - new_string becomes the entire file content + - File must not already exist + +CRITICAL REQUIREMENTS: + +1. EXACT MATCHING: The old_string must match the file contents EXACTLY, including: + - All whitespace (spaces, tabs, newlines) + - All indentation + - All punctuation and special characters + +2. CONTEXT FOR UNIQUENESS: For single replacements (default), include at least 3 lines of context BEFORE and AFTER the target text to ensure uniqueness. + +3. MULTIPLE REPLACEMENTS: If you need to replace multiple identical occurrences: + - Set expected_replacements to the exact count you expect to replace + - ALL occurrences will be replaced + +4. NO ESCAPING: Provide the literal text - do not escape special characters.` + +const edit_file = { + type: "function", + function: { + name: "edit_file", + description: EDIT_FILE_DESCRIPTION, + parameters: { + type: "object", + properties: { + file_path: { + type: "string", + description: + "The path to the file to modify or create. You can use either a relative path in the workspace or an absolute path. If an absolute path is provided, it will be preserved as is.", + }, + old_string: { + type: "string", + description: + "The exact literal text to replace (must match the file contents exactly, including all whitespace and indentation). For single replacements (default), include at least 3 lines of context BEFORE and AFTER the target text. Use empty string to create a new file.", + }, + new_string: { + type: "string", + description: + "The exact literal text to replace old_string with. When creating a new file (old_string is empty), this becomes the file content.", + }, + expected_replacements: { + type: "number", + description: + "Number of replacements expected. Defaults to 1 if not specified. Use when you want to replace multiple occurrences of the same text.", + minimum: 1, + }, + }, + required: ["file_path", "old_string", "new_string"], + additionalProperties: false, + }, + }, +} satisfies OpenAI.Chat.ChatCompletionTool + +export default edit_file diff --git a/src/core/prompts/tools/native-tools/index.ts b/src/core/prompts/tools/native-tools/index.ts index 760d987b47b..79302a39f31 100644 --- a/src/core/prompts/tools/native-tools/index.ts +++ b/src/core/prompts/tools/native-tools/index.ts @@ -15,6 +15,7 @@ import { createReadFileTool } from "./read_file" import runSlashCommand from "./run_slash_command" import searchAndReplace from "./search_and_replace" import searchReplace from "./search_replace" +import edit_file from "./edit_file" import searchFiles from "./search_files" import switchMode from "./switch_mode" import updateTodoList from "./update_todo_list" @@ -47,6 +48,7 @@ export function getNativeTools(partialReadsEnabled: boolean = true): OpenAI.Chat runSlashCommand, searchAndReplace, searchReplace, + edit_file, searchFiles, switchMode, updateTodoList, diff --git a/src/core/tools/EditFileTool.ts b/src/core/tools/EditFileTool.ts new file mode 100644 index 00000000000..8d04fe23016 --- /dev/null +++ b/src/core/tools/EditFileTool.ts @@ -0,0 +1,373 @@ +import fs from "fs/promises" +import path from "path" + +import { getReadablePath } from "../../utils/path" +import { isPathOutsideWorkspace } from "../../utils/pathUtils" +import { Task } from "../task/Task" +import { formatResponse } from "../prompts/responses" +import { ClineSayTool } from "../../shared/ExtensionMessage" +import { RecordSource } from "../context-tracking/FileContextTrackerTypes" +import { fileExistsAtPath } from "../../utils/fs" +import { DEFAULT_WRITE_DELAY_MS } from "@roo-code/types" +import { EXPERIMENT_IDS, experiments } from "../../shared/experiments" +import { sanitizeUnifiedDiff, computeDiffStats } from "../diff/stats" +import { BaseTool, ToolCallbacks } from "./BaseTool" +import type { ToolUse } from "../../shared/tools" + +interface EditFileParams { + file_path: string + old_string: string + new_string: string + expected_replacements?: number +} + +/** + * Count occurrences of a substring in a string. + * @param str The string to search in + * @param substr The substring to count + * @returns Number of non-overlapping occurrences + */ +function countOccurrences(str: string, substr: string): number { + if (substr === "") return 0 + let count = 0 + let pos = str.indexOf(substr) + while (pos !== -1) { + count++ + pos = str.indexOf(substr, pos + substr.length) + } + return count +} + +/** + * Safely replace all occurrences of a literal string, handling $ escape sequences. + * Standard String.replaceAll treats $ specially in the replacement string. + * This function ensures literal replacement. + * + * @param str The original string + * @param oldString The string to replace + * @param newString The replacement string + * @returns The string with all occurrences replaced + */ +function safeLiteralReplace(str: string, oldString: string, newString: string): string { + if (oldString === "" || !str.includes(oldString)) { + return str + } + + // If newString doesn't contain $, we can use replaceAll directly + if (!newString.includes("$")) { + return str.replaceAll(oldString, newString) + } + + // Escape $ to prevent ECMAScript GetSubstitution issues + // $$ becomes a single $ in the output, so we double-escape + const escapedNewString = newString.replaceAll("$", "$$$$") + return str.replaceAll(oldString, escapedNewString) +} + +/** + * Apply a replacement operation. + * + * @param currentContent The current file content (null if file doesn't exist) + * @param oldString The string to replace + * @param newString The replacement string + * @param isNewFile Whether this is creating a new file + * @returns The resulting content + */ +function applyReplacement( + currentContent: string | null, + oldString: string, + newString: string, + isNewFile: boolean, +): string { + if (isNewFile) { + return newString + } + // If oldString is empty and it's not a new file, do not modify the content + if (oldString === "" || currentContent === null) { + return currentContent ?? "" + } + + return safeLiteralReplace(currentContent, oldString, newString) +} + +export class EditFileTool extends BaseTool<"edit_file"> { + readonly name = "edit_file" as const + + parseLegacy(params: Partial>): EditFileParams { + return { + file_path: params.file_path || "", + old_string: params.old_string || "", + new_string: params.new_string || "", + expected_replacements: params.expected_replacements + ? parseInt(params.expected_replacements, 10) + : undefined, + } + } + + async execute(params: EditFileParams, task: Task, callbacks: ToolCallbacks): Promise { + const { file_path, old_string, new_string, expected_replacements = 1 } = params + const { askApproval, handleError, pushToolResult, toolProtocol } = callbacks + + try { + // Validate required parameters + if (!file_path) { + task.consecutiveMistakeCount++ + task.recordToolError("edit_file") + pushToolResult(await task.sayAndCreateMissingParamError("edit_file", "file_path")) + return + } + + // Determine relative path - file_path can be absolute or relative + let relPath: string + if (path.isAbsolute(file_path)) { + relPath = path.relative(task.cwd, file_path) + } else { + relPath = file_path + } + + const accessAllowed = task.rooIgnoreController?.validateAccess(relPath) + + if (!accessAllowed) { + await task.say("rooignore_error", relPath) + pushToolResult(formatResponse.rooIgnoreError(relPath, toolProtocol)) + return + } + + // Check if file is write-protected + const isWriteProtected = task.rooProtectedController?.isWriteProtected(relPath) || false + + const absolutePath = path.resolve(task.cwd, relPath) + const fileExists = await fileExistsAtPath(absolutePath) + + let currentContent: string | null = null + let isNewFile = false + + // Read file or determine if creating new + if (fileExists) { + try { + currentContent = await fs.readFile(absolutePath, "utf8") + // Normalize line endings to LF + currentContent = currentContent.replace(/\r\n/g, "\n") + } catch (error) { + task.consecutiveMistakeCount++ + task.recordToolError("edit_file") + const errorMessage = `Failed to read file '${relPath}'. Please verify file permissions and try again.` + await task.say("error", errorMessage) + pushToolResult(formatResponse.toolError(errorMessage, toolProtocol)) + return + } + + // Check if trying to create a file that already exists + if (old_string === "") { + task.consecutiveMistakeCount++ + task.recordToolError("edit_file") + const errorMessage = `File '${relPath}' already exists. Cannot create a new file with empty old_string when file exists.` + await task.say("error", errorMessage) + pushToolResult(formatResponse.toolError(errorMessage, toolProtocol)) + return + } + } else { + // File doesn't exist + if (old_string === "") { + // Creating a new file + isNewFile = true + } else { + // Trying to replace in non-existent file + task.consecutiveMistakeCount++ + task.recordToolError("edit_file") + const errorMessage = `File not found: ${relPath}. Cannot perform replacement on a non-existent file. Use an empty old_string to create a new file.` + await task.say("error", errorMessage) + pushToolResult(formatResponse.toolError(errorMessage, toolProtocol)) + return + } + } + + // Validate replacement operation + if (!isNewFile && currentContent !== null) { + // Check occurrence count + const occurrences = countOccurrences(currentContent, old_string) + + if (occurrences === 0) { + task.consecutiveMistakeCount++ + task.recordToolError("edit_file", "no_match") + pushToolResult( + formatResponse.toolError( + `No match found for the specified 'old_string'. Please ensure it matches the file contents exactly, including all whitespace and indentation.`, + toolProtocol, + ), + ) + return + } + + if (occurrences !== expected_replacements) { + task.consecutiveMistakeCount++ + task.recordToolError("edit_file", "occurrence_mismatch") + pushToolResult( + formatResponse.toolError( + `Expected ${expected_replacements} occurrence(s) but found ${occurrences}. Please adjust your old_string to match exactly ${expected_replacements} occurrence(s), or set expected_replacements to ${occurrences}.`, + toolProtocol, + ), + ) + return + } + + // Validate that old_string and new_string are different + if (old_string === new_string) { + task.consecutiveMistakeCount++ + task.recordToolError("edit_file") + pushToolResult( + formatResponse.toolError( + "No changes to apply. The old_string and new_string are identical.", + toolProtocol, + ), + ) + return + } + } + + // Apply the replacement + const newContent = applyReplacement(currentContent, old_string, new_string, isNewFile) + + // Check if any changes were made + if (!isNewFile && newContent === currentContent) { + pushToolResult(`No changes needed for '${relPath}'`) + return + } + + task.consecutiveMistakeCount = 0 + + // Initialize diff view + task.diffViewProvider.editType = isNewFile ? "create" : "modify" + task.diffViewProvider.originalContent = currentContent || "" + + // Generate and validate diff + const diff = formatResponse.createPrettyPatch(relPath, currentContent || "", newContent) + if (!diff && !isNewFile) { + pushToolResult(`No changes needed for '${relPath}'`) + await task.diffViewProvider.reset() + return + } + + // Check if preventFocusDisruption experiment is enabled + const provider = task.providerRef.deref() + const state = await provider?.getState() + const diagnosticsEnabled = state?.diagnosticsEnabled ?? true + const writeDelayMs = state?.writeDelayMs ?? DEFAULT_WRITE_DELAY_MS + const isPreventFocusDisruptionEnabled = experiments.isEnabled( + state?.experiments ?? {}, + EXPERIMENT_IDS.PREVENT_FOCUS_DISRUPTION, + ) + + const sanitizedDiff = sanitizeUnifiedDiff(diff || "") + const diffStats = computeDiffStats(sanitizedDiff) || undefined + const isOutsideWorkspace = isPathOutsideWorkspace(absolutePath) + + const sharedMessageProps: ClineSayTool = { + tool: isNewFile ? "newFileCreated" : "appliedDiff", + path: getReadablePath(task.cwd, relPath), + diff: sanitizedDiff, + isOutsideWorkspace, + } + + const completeMessage = JSON.stringify({ + ...sharedMessageProps, + content: sanitizedDiff, + isProtected: isWriteProtected, + diffStats, + } satisfies ClineSayTool) + + // Show diff view if focus disruption prevention is disabled + if (!isPreventFocusDisruptionEnabled) { + await task.diffViewProvider.open(relPath) + await task.diffViewProvider.update(newContent, true) + task.diffViewProvider.scrollToFirstDiff() + } + + const didApprove = await askApproval("tool", completeMessage, undefined, isWriteProtected) + + if (!didApprove) { + // Revert changes if diff view was shown + if (!isPreventFocusDisruptionEnabled) { + await task.diffViewProvider.revertChanges() + } + pushToolResult("Changes were rejected by the user.") + await task.diffViewProvider.reset() + return + } + + // Save the changes + if (isPreventFocusDisruptionEnabled) { + // Direct file write without diff view or opening the file + await task.diffViewProvider.saveDirectly( + relPath, + newContent, + isNewFile, + diagnosticsEnabled, + writeDelayMs, + ) + } else { + // Call saveChanges to update the DiffViewProvider properties + await task.diffViewProvider.saveChanges(diagnosticsEnabled, writeDelayMs) + } + + // Track file edit operation + if (relPath) { + await task.fileContextTracker.trackFileContext(relPath, "roo_edited" as RecordSource) + } + + task.didEditFile = true + + // Get the formatted response message + const replacementInfo = + !isNewFile && expected_replacements > 1 ? ` (${expected_replacements} replacements)` : "" + const message = await task.diffViewProvider.pushToolWriteResult(task, task.cwd, isNewFile) + + pushToolResult(message + replacementInfo) + + // Record successful tool usage and cleanup + task.recordToolUsage("edit_file") + await task.diffViewProvider.reset() + + // Process any queued messages after file edit completes + task.processQueuedMessages() + } catch (error) { + await handleError("edit_file", error as Error) + await task.diffViewProvider.reset() + } + } + + override async handlePartial(task: Task, block: ToolUse<"edit_file">): Promise { + const filePath: string | undefined = block.params.file_path + const oldString: string | undefined = block.params.old_string + + let operationPreview: string | undefined + if (oldString !== undefined) { + if (oldString === "") { + operationPreview = "creating new file" + } else { + const preview = oldString.length > 50 ? oldString.substring(0, 50) + "..." : oldString + operationPreview = `replacing: "${preview}"` + } + } + + // Determine relative path for display + let relPath = filePath || "" + if (filePath && path.isAbsolute(filePath)) { + relPath = path.relative(task.cwd, filePath) + } + + const absolutePath = relPath ? path.resolve(task.cwd, relPath) : "" + const isOutsideWorkspace = absolutePath ? isPathOutsideWorkspace(absolutePath) : false + + const sharedMessageProps: ClineSayTool = { + tool: "appliedDiff", + path: getReadablePath(task.cwd, relPath), + diff: operationPreview, + isOutsideWorkspace, + } + + await task.ask("tool", JSON.stringify(sharedMessageProps), block.partial).catch(() => {}) + } +} + +export const editFileTool = new EditFileTool() diff --git a/src/core/tools/__tests__/editFileTool.spec.ts b/src/core/tools/__tests__/editFileTool.spec.ts new file mode 100644 index 00000000000..ab632252dff --- /dev/null +++ b/src/core/tools/__tests__/editFileTool.spec.ts @@ -0,0 +1,455 @@ +import * as path from "path" +import fs from "fs/promises" + +import type { MockedFunction } from "vitest" + +import { fileExistsAtPath } from "../../../utils/fs" +import { isPathOutsideWorkspace } from "../../../utils/pathUtils" +import { getReadablePath } from "../../../utils/path" +import { ToolUse, ToolResponse } from "../../../shared/tools" +import { editFileTool } from "../EditFileTool" + +vi.mock("fs/promises", () => ({ + default: { + readFile: vi.fn().mockResolvedValue(""), + }, +})) + +vi.mock("path", async () => { + const originalPath = await vi.importActual("path") + return { + ...originalPath, + resolve: vi.fn().mockImplementation((...args) => { + const separator = process.platform === "win32" ? "\\" : "/" + return args.join(separator) + }), + isAbsolute: vi.fn().mockReturnValue(false), + relative: vi.fn().mockImplementation((from, to) => to), + } +}) + +vi.mock("delay", () => ({ + default: vi.fn(), +})) + +vi.mock("../../../utils/fs", () => ({ + fileExistsAtPath: vi.fn().mockResolvedValue(true), +})) + +vi.mock("../../prompts/responses", () => ({ + formatResponse: { + toolError: vi.fn((msg) => `Error: ${msg}`), + rooIgnoreError: vi.fn((path) => `Access denied: ${path}`), + createPrettyPatch: vi.fn(() => "mock-diff"), + }, +})) + +vi.mock("../../../utils/pathUtils", () => ({ + isPathOutsideWorkspace: vi.fn().mockReturnValue(false), +})) + +vi.mock("../../../utils/path", () => ({ + getReadablePath: vi.fn().mockReturnValue("test/path.txt"), +})) + +vi.mock("../../diff/stats", () => ({ + sanitizeUnifiedDiff: vi.fn((diff) => diff), + computeDiffStats: vi.fn(() => ({ additions: 1, deletions: 1 })), +})) + +vi.mock("vscode", () => ({ + window: { + showWarningMessage: vi.fn().mockResolvedValue(undefined), + }, + env: { + openExternal: vi.fn(), + }, + Uri: { + parse: vi.fn(), + }, +})) + +describe("editFileTool", () => { + // Test data + const testFilePath = "test/file.txt" + const absoluteFilePath = process.platform === "win32" ? "C:\\test\\file.txt" : "/test/file.txt" + const testFileContent = "Line 1\nLine 2\nLine 3\nLine 4" + const testOldString = "Line 2" + const testNewString = "Modified Line 2" + + // Mocked functions + const mockedFileExistsAtPath = fileExistsAtPath as MockedFunction + const mockedFsReadFile = fs.readFile as unknown as MockedFunction< + (path: string, encoding: string) => Promise + > + const mockedIsPathOutsideWorkspace = isPathOutsideWorkspace as MockedFunction + const mockedGetReadablePath = getReadablePath as MockedFunction + const mockedPathResolve = path.resolve as MockedFunction + const mockedPathIsAbsolute = path.isAbsolute as MockedFunction + + const mockTask: any = {} + let mockAskApproval: ReturnType + let mockHandleError: ReturnType + let mockPushToolResult: ReturnType + let mockRemoveClosingTag: ReturnType + let toolResult: ToolResponse | undefined + + beforeEach(() => { + vi.clearAllMocks() + + mockedPathResolve.mockReturnValue(absoluteFilePath) + mockedPathIsAbsolute.mockReturnValue(false) + mockedFileExistsAtPath.mockResolvedValue(true) + mockedFsReadFile.mockResolvedValue(testFileContent) + mockedIsPathOutsideWorkspace.mockReturnValue(false) + mockedGetReadablePath.mockReturnValue("test/path.txt") + + mockTask.cwd = "/" + mockTask.consecutiveMistakeCount = 0 + mockTask.didEditFile = false + mockTask.providerRef = { + deref: vi.fn().mockReturnValue({ + getState: vi.fn().mockResolvedValue({ + diagnosticsEnabled: true, + writeDelayMs: 1000, + experiments: {}, + }), + }), + } + mockTask.rooIgnoreController = { + validateAccess: vi.fn().mockReturnValue(true), + } + mockTask.rooProtectedController = { + isWriteProtected: vi.fn().mockReturnValue(false), + } + mockTask.diffViewProvider = { + editType: undefined, + isEditing: false, + originalContent: "", + open: vi.fn().mockResolvedValue(undefined), + update: vi.fn().mockResolvedValue(undefined), + reset: vi.fn().mockResolvedValue(undefined), + revertChanges: vi.fn().mockResolvedValue(undefined), + saveChanges: vi.fn().mockResolvedValue({ + newProblemsMessage: "", + userEdits: null, + finalContent: "final content", + }), + saveDirectly: vi.fn().mockResolvedValue(undefined), + scrollToFirstDiff: vi.fn(), + pushToolWriteResult: vi.fn().mockResolvedValue("Tool result message"), + } + mockTask.fileContextTracker = { + trackFileContext: vi.fn().mockResolvedValue(undefined), + } + mockTask.say = vi.fn().mockResolvedValue(undefined) + mockTask.ask = vi.fn().mockResolvedValue(undefined) + mockTask.recordToolError = vi.fn() + mockTask.recordToolUsage = vi.fn() + mockTask.processQueuedMessages = vi.fn() + mockTask.sayAndCreateMissingParamError = vi.fn().mockResolvedValue("Missing param error") + + mockAskApproval = vi.fn().mockResolvedValue(true) + mockHandleError = vi.fn().mockResolvedValue(undefined) + mockRemoveClosingTag = vi.fn((tag, content) => content) + + toolResult = undefined + }) + + /** + * Helper function to execute the edit_file tool with different parameters + */ + async function executeEditFileTool( + params: Partial = {}, + options: { + fileExists?: boolean + fileContent?: string + isPartial?: boolean + accessAllowed?: boolean + } = {}, + ): Promise { + const fileExists = options.fileExists ?? true + const fileContent = options.fileContent ?? testFileContent + const isPartial = options.isPartial ?? false + const accessAllowed = options.accessAllowed ?? true + + mockedFileExistsAtPath.mockResolvedValue(fileExists) + mockedFsReadFile.mockResolvedValue(fileContent) + mockTask.rooIgnoreController.validateAccess.mockReturnValue(accessAllowed) + + const toolUse: ToolUse = { + type: "tool_use", + name: "edit_file", + params: { + file_path: testFilePath, + old_string: testOldString, + new_string: testNewString, + ...params, + }, + partial: isPartial, + } + + mockPushToolResult = vi.fn((result: ToolResponse) => { + toolResult = result + }) + + await editFileTool.handle(mockTask, toolUse as ToolUse<"edit_file">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "native", + }) + + return toolResult + } + + describe("parameter validation", () => { + it("returns error when file_path is missing", async () => { + const result = await executeEditFileTool({ file_path: undefined }) + + expect(result).toBe("Missing param error") + expect(mockTask.consecutiveMistakeCount).toBe(1) + expect(mockTask.recordToolError).toHaveBeenCalledWith("edit_file") + }) + + it("treats undefined new_string as empty string (deletion)", async () => { + await executeEditFileTool( + { old_string: "Line 2", new_string: undefined }, + { fileContent: "Line 1\nLine 2\nLine 3" }, + ) + + expect(mockAskApproval).toHaveBeenCalled() + }) + + it("allows empty new_string for deletion", async () => { + await executeEditFileTool( + { old_string: "Line 2", new_string: "" }, + { fileContent: "Line 1\nLine 2\nLine 3" }, + ) + + expect(mockAskApproval).toHaveBeenCalled() + }) + + it("returns error when old_string equals new_string", async () => { + const result = await executeEditFileTool({ + old_string: "same", + new_string: "same", + }) + + expect(result).toContain("Error:") + expect(mockTask.consecutiveMistakeCount).toBe(1) + }) + }) + + describe("file access", () => { + it("returns error when file does not exist and old_string is not empty", async () => { + const result = await executeEditFileTool({}, { fileExists: false }) + + expect(result).toContain("Error:") + expect(result).toContain("File not found") + expect(mockTask.consecutiveMistakeCount).toBe(1) + }) + + it("returns error when access is denied", async () => { + const result = await executeEditFileTool({}, { accessAllowed: false }) + + expect(result).toContain("Access denied") + }) + }) + + describe("edit_file logic", () => { + it("returns error when no match is found", async () => { + const result = await executeEditFileTool( + { old_string: "NonExistent" }, + { fileContent: "Line 1\nLine 2\nLine 3" }, + ) + + expect(result).toContain("Error:") + expect(result).toContain("No match found") + expect(mockTask.consecutiveMistakeCount).toBe(1) + expect(mockTask.recordToolError).toHaveBeenCalledWith("edit_file", "no_match") + }) + + it("returns error when occurrence count does not match expected_replacements", async () => { + const result = await executeEditFileTool( + { old_string: "Line", expected_replacements: "1" }, + { fileContent: "Line 1\nLine 2\nLine 3" }, + ) + + expect(result).toContain("Error:") + expect(result).toContain("Expected 1 occurrence(s) but found 3") + expect(mockTask.consecutiveMistakeCount).toBe(1) + expect(mockTask.recordToolError).toHaveBeenCalledWith("edit_file", "occurrence_mismatch") + }) + + it("succeeds when occurrence count matches expected_replacements", async () => { + await executeEditFileTool( + { old_string: "Line", new_string: "Row", expected_replacements: "4" }, + { fileContent: "Line 1\nLine 2\nLine 3\nLine 4" }, + ) + + expect(mockTask.consecutiveMistakeCount).toBe(0) + expect(mockTask.diffViewProvider.editType).toBe("modify") + expect(mockAskApproval).toHaveBeenCalled() + }) + + it("successfully replaces single unique match", async () => { + await executeEditFileTool( + { + old_string: "Line 2", + new_string: "Modified Line 2", + }, + { fileContent: "Line 1\nLine 2\nLine 3" }, + ) + + expect(mockTask.consecutiveMistakeCount).toBe(0) + expect(mockTask.diffViewProvider.editType).toBe("modify") + expect(mockAskApproval).toHaveBeenCalled() + }) + + it("defaults expected_replacements to 1", async () => { + const result = await executeEditFileTool( + { old_string: "Line" }, + { fileContent: "Line 1\nLine 2\nLine 3\nLine 4" }, + ) + + expect(result).toContain("Error:") + expect(result).toContain("Expected 1 occurrence(s) but found 4") + }) + }) + + describe("file creation", () => { + it("creates new file when old_string is empty and file does not exist", async () => { + await executeEditFileTool({ old_string: "", new_string: "New file content" }, { fileExists: false }) + + expect(mockTask.consecutiveMistakeCount).toBe(0) + expect(mockTask.diffViewProvider.editType).toBe("create") + expect(mockAskApproval).toHaveBeenCalled() + }) + + it("returns error when trying to create file that already exists", async () => { + const result = await executeEditFileTool( + { old_string: "", new_string: "Content" }, + { fileExists: true, fileContent: "Existing content" }, + ) + + expect(result).toContain("Error:") + expect(result).toContain("already exists") + expect(mockTask.consecutiveMistakeCount).toBe(1) + }) + }) + + describe("approval workflow", () => { + it("saves changes when user approves", async () => { + mockAskApproval.mockResolvedValue(true) + + await executeEditFileTool() + + expect(mockTask.diffViewProvider.saveChanges).toHaveBeenCalled() + expect(mockTask.didEditFile).toBe(true) + expect(mockTask.recordToolUsage).toHaveBeenCalledWith("edit_file") + }) + + it("reverts changes when user rejects", async () => { + mockAskApproval.mockResolvedValue(false) + + const result = await executeEditFileTool() + + expect(mockTask.diffViewProvider.revertChanges).toHaveBeenCalled() + expect(mockTask.diffViewProvider.saveChanges).not.toHaveBeenCalled() + expect(result).toContain("rejected") + }) + }) + + describe("partial block handling", () => { + it("handles partial block without errors", async () => { + await executeEditFileTool({}, { isPartial: true }) + + expect(mockTask.ask).toHaveBeenCalled() + }) + + it("shows creating new file preview when old_string is empty", async () => { + await executeEditFileTool({ old_string: "" }, { isPartial: true }) + + expect(mockTask.ask).toHaveBeenCalled() + }) + }) + + describe("error handling", () => { + it("handles file read errors gracefully", async () => { + mockedFsReadFile.mockRejectedValueOnce(new Error("Read failed")) + + const toolUse: ToolUse = { + type: "tool_use", + name: "edit_file", + params: { + file_path: testFilePath, + old_string: testOldString, + new_string: testNewString, + }, + partial: false, + } + + let capturedResult: ToolResponse | undefined + const localPushToolResult = vi.fn((result: ToolResponse) => { + capturedResult = result + }) + + await editFileTool.handle(mockTask, toolUse as ToolUse<"edit_file">, { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: localPushToolResult, + removeClosingTag: mockRemoveClosingTag, + toolProtocol: "native", + }) + + expect(capturedResult).toContain("Error:") + expect(capturedResult).toContain("Failed to read file") + expect(mockTask.consecutiveMistakeCount).toBe(1) + }) + + it("handles general errors and resets diff view", async () => { + mockTask.diffViewProvider.open.mockRejectedValueOnce(new Error("General error")) + + await executeEditFileTool() + + expect(mockHandleError).toHaveBeenCalledWith("edit_file", expect.any(Error)) + expect(mockTask.diffViewProvider.reset).toHaveBeenCalled() + }) + }) + + describe("file tracking", () => { + it("tracks file context after successful edit", async () => { + await executeEditFileTool() + + expect(mockTask.fileContextTracker.trackFileContext).toHaveBeenCalledWith(testFilePath, "roo_edited") + }) + }) + + describe("CRLF normalization", () => { + it("normalizes CRLF to LF when reading file", async () => { + const contentWithCRLF = "Line 1\r\nLine 2\r\nLine 3" + + await executeEditFileTool( + { old_string: "Line 2", new_string: "Modified Line 2" }, + { fileContent: contentWithCRLF }, + ) + + expect(mockTask.consecutiveMistakeCount).toBe(0) + expect(mockAskApproval).toHaveBeenCalled() + }) + }) + + describe("dollar sign handling", () => { + it("handles $ in new_string correctly", async () => { + await executeEditFileTool( + { old_string: "Line 2", new_string: "Cost: $100" }, + { fileContent: "Line 1\nLine 2\nLine 3" }, + ) + + expect(mockTask.consecutiveMistakeCount).toBe(0) + expect(mockAskApproval).toHaveBeenCalled() + }) + }) +}) diff --git a/src/shared/tools.ts b/src/shared/tools.ts index de7a65bfb79..f2b4ec3544e 100644 --- a/src/shared/tools.ts +++ b/src/shared/tools.ts @@ -72,9 +72,10 @@ export const toolParamNames = [ "files", // Native protocol parameter for read_file "operations", // search_and_replace parameter for multiple operations "patch", // apply_patch parameter - "file_path", // search_replace parameter - "old_string", // search_replace parameter - "new_string", // search_replace parameter + "file_path", // search_replace and edit_file parameter + "old_string", // search_replace and edit_file parameter + "new_string", // search_replace and edit_file parameter + "expected_replacements", // edit_file parameter for multiple occurrences ] as const export type ToolParamName = (typeof toolParamNames)[number] @@ -93,6 +94,7 @@ export type NativeToolArgs = { apply_diff: { path: string; diff: string } search_and_replace: { path: string; operations: Array<{ search: string; replace: string }> } search_replace: { file_path: string; old_string: string; new_string: string } + edit_file: { file_path: string; old_string: string; new_string: string; expected_replacements?: number } apply_patch: { patch: string } ask_followup_question: { question: string @@ -249,6 +251,7 @@ export const TOOL_DISPLAY_NAMES: Record = { apply_diff: "apply changes", search_and_replace: "apply changes using search and replace", search_replace: "apply single search and replace", + edit_file: "edit files using search and replace", apply_patch: "apply patches using codex format", search_files: "search files", list_files: "list files", @@ -272,7 +275,7 @@ export const TOOL_GROUPS: Record = { }, edit: { tools: ["apply_diff", "write_to_file", "generate_image"], - customTools: ["search_and_replace", "search_replace", "apply_patch"], + customTools: ["search_and_replace", "search_replace", "edit_file", "apply_patch"], }, browser: { tools: ["browser_action"], @@ -310,9 +313,7 @@ export const ALWAYS_AVAILABLE_TOOLS: ToolName[] = [ * To add a new alias, simply add an entry here. No other files need to be modified. */ export const TOOL_ALIASES: Record = { - edit_file: "apply_diff", write_file: "write_to_file", - temp_edit_file: "search_and_replace", } as const export type DiffResult =