Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
// npx vitest src/core/assistant-message/__tests__/presentAssistantMessage-unknown-tool.spec.ts

import { describe, it, expect, beforeEach, vi } from "vitest"
import { presentAssistantMessage } from "../presentAssistantMessage"

// Mock dependencies
vi.mock("../../task/Task")
vi.mock("../../tools/validateToolUse", () => ({
validateToolUse: vi.fn(),
}))
vi.mock("@roo-code/telemetry", () => ({
TelemetryService: {
instance: {
captureToolUsage: vi.fn(),
captureConsecutiveMistakeError: vi.fn(),
},
},
}))

describe("presentAssistantMessage - Unknown Tool Handling", () => {
let mockTask: any

beforeEach(() => {
// Create a mock Task with minimal properties needed for testing
mockTask = {
taskId: "test-task-id",
instanceId: "test-instance",
abort: false,
presentAssistantMessageLocked: false,
presentAssistantMessageHasPendingUpdates: false,
currentStreamingContentIndex: 0,
assistantMessageContent: [],
userMessageContent: [],
didCompleteReadingStream: false,
didRejectTool: false,
didAlreadyUseTool: false,
diffEnabled: false,
consecutiveMistakeCount: 0,
clineMessages: [],
api: {
getModel: () => ({ id: "test-model", info: {} }),
},
browserSession: {
closeBrowser: vi.fn().mockResolvedValue(undefined),
},
recordToolUsage: vi.fn(),
recordToolError: vi.fn(),
toolRepetitionDetector: {
check: vi.fn().mockReturnValue({ allowExecution: true }),
},
providerRef: {
deref: () => ({
getState: vi.fn().mockResolvedValue({
mode: "code",
customModes: [],
}),
}),
},
say: vi.fn().mockResolvedValue(undefined),
ask: vi.fn().mockResolvedValue({ response: "yesButtonClicked" }),
}
})

it("should return error for unknown tool in native protocol", async () => {
// Set up a tool_use block with an unknown tool name and an ID (native protocol)
const toolCallId = "tool_call_unknown_123"
mockTask.assistantMessageContent = [
{
type: "tool_use",
id: toolCallId, // ID indicates native protocol
name: "nonexistent_tool",
params: { some: "param" },
partial: false,
},
]

// Execute presentAssistantMessage
await presentAssistantMessage(mockTask)

// Verify that a tool_result with error was pushed
const toolResult = mockTask.userMessageContent.find(
(item: any) => item.type === "tool_result" && item.tool_use_id === toolCallId,
)

expect(toolResult).toBeDefined()
expect(toolResult.tool_use_id).toBe(toolCallId)
// The error is wrapped in JSON by formatResponse.toolError
expect(toolResult.content).toContain("nonexistent_tool")
expect(toolResult.content).toContain("does not exist")
expect(toolResult.content).toContain("error")

// Verify consecutiveMistakeCount was incremented
expect(mockTask.consecutiveMistakeCount).toBe(1)

// Verify recordToolError was called
expect(mockTask.recordToolError).toHaveBeenCalledWith(
"nonexistent_tool",
expect.stringContaining("Unknown tool"),
)

// Verify error message was shown to user (uses i18n key)
expect(mockTask.say).toHaveBeenCalledWith("error", "unknownToolError")
})

it("should return error for unknown tool in XML protocol", async () => {
// Set up a tool_use block with an unknown tool name WITHOUT an ID (XML protocol)
mockTask.assistantMessageContent = [
{
type: "tool_use",
// No ID = XML protocol
name: "fake_tool_that_does_not_exist",
params: { param1: "value1" },
partial: false,
},
]

// Execute presentAssistantMessage
await presentAssistantMessage(mockTask)

// For XML protocol, error is pushed as text blocks
const textBlocks = mockTask.userMessageContent.filter((item: any) => item.type === "text")

// There should be text blocks with error message
expect(textBlocks.length).toBeGreaterThan(0)
const hasErrorMessage = textBlocks.some(
(block: any) =>
block.text?.includes("fake_tool_that_does_not_exist") && block.text?.includes("does not exist"),
)
expect(hasErrorMessage).toBe(true)

// Verify consecutiveMistakeCount was incremented
expect(mockTask.consecutiveMistakeCount).toBe(1)

// Verify recordToolError was called
expect(mockTask.recordToolError).toHaveBeenCalled()

// Verify error message was shown to user (uses i18n key)
expect(mockTask.say).toHaveBeenCalledWith("error", "unknownToolError")
})

it("should handle unknown tool without freezing (native protocol)", async () => {
// This test ensures the extension doesn't freeze when an unknown tool is called
const toolCallId = "tool_call_freeze_test"
mockTask.assistantMessageContent = [
{
type: "tool_use",
id: toolCallId, // Native protocol
name: "this_tool_definitely_does_not_exist",
params: {},
partial: false,
},
]

// The test will timeout if the extension freezes
const timeoutPromise = new Promise<boolean>((_, reject) => {
setTimeout(() => reject(new Error("Test timed out - extension likely froze")), 5000)
})

const resultPromise = presentAssistantMessage(mockTask).then(() => true)

// Race between the function completing and the timeout
const completed = await Promise.race([resultPromise, timeoutPromise])
expect(completed).toBe(true)

// Verify a tool_result was pushed (critical for API not to freeze)
const toolResult = mockTask.userMessageContent.find(
(item: any) => item.type === "tool_result" && item.tool_use_id === toolCallId,
)
expect(toolResult).toBeDefined()
})

it("should increment consecutiveMistakeCount for unknown tools", async () => {
// Test with multiple unknown tools to ensure mistake count increments
const toolCallId = "tool_call_mistake_test"
mockTask.assistantMessageContent = [
{
type: "tool_use",
id: toolCallId,
name: "unknown_tool_1",
params: {},
partial: false,
},
]

expect(mockTask.consecutiveMistakeCount).toBe(0)

await presentAssistantMessage(mockTask)

expect(mockTask.consecutiveMistakeCount).toBe(1)
})

it("should set userMessageContentReady after handling unknown tool", async () => {
const toolCallId = "tool_call_ready_test"
mockTask.assistantMessageContent = [
{
type: "tool_use",
id: toolCallId,
name: "unknown_tool",
params: {},
partial: false,
},
]

mockTask.didCompleteReadingStream = true
mockTask.userMessageContentReady = false

await presentAssistantMessage(mockTask)

// userMessageContentReady should be set after processing
expect(mockTask.userMessageContentReady).toBe(true)
})

it("should still work with didAlreadyUseTool flag for unknown tool", async () => {
const toolCallId = "tool_call_already_used_test"
mockTask.assistantMessageContent = [
{
type: "tool_use",
id: toolCallId,
name: "unknown_tool",
params: {},
partial: false,
},
]

mockTask.didAlreadyUseTool = true

await presentAssistantMessage(mockTask)

// When didAlreadyUseTool is true, should send error tool_result
const toolResult = mockTask.userMessageContent.find(
(item: any) => item.type === "tool_result" && item.tool_use_id === toolCallId,
)

expect(toolResult).toBeDefined()
expect(toolResult.is_error).toBe(true)
expect(toolResult.content).toContain("was not executed because a tool has already been used")
})

it("should still work with didRejectTool flag for unknown tool", async () => {
const toolCallId = "tool_call_rejected_test"
mockTask.assistantMessageContent = [
{
type: "tool_use",
id: toolCallId,
name: "unknown_tool",
params: {},
partial: false,
},
]

mockTask.didRejectTool = true

await presentAssistantMessage(mockTask)

// When didRejectTool is true, should send error tool_result
const toolResult = mockTask.userMessageContent.find(
(item: any) => item.type === "tool_result" && item.tool_use_id === toolCallId,
)

expect(toolResult).toBeDefined()
expect(toolResult.is_error).toBe(true)
expect(toolResult.content).toContain("due to user rejecting a previous tool")
})
})
102 changes: 77 additions & 25 deletions src/core/assistant-message/presentAssistantMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { TelemetryService } from "@roo-code/telemetry"
import { defaultModeSlug, getModeBySlug } from "../../shared/modes"
import type { ToolParamName, ToolResponse, ToolUse, McpToolUse } from "../../shared/tools"
import { Package } from "../../shared/package"
import { t } from "../../i18n"

import { fetchInstructionsTool } from "../tools/FetchInstructionsTool"
import { listFilesTool } from "../tools/ListFilesTool"
Expand Down Expand Up @@ -333,7 +334,11 @@ export async function presentAssistantMessage(cline: Task) {
await cline.say("text", content, undefined, block.partial)
break
}
case "tool_use":
case "tool_use": {
// Fetch state early so it's available for toolDescription and validation
const state = await cline.providerRef.deref()?.getState()
const { mode, customModes, experiments: stateExperiments, apiConfiguration } = state ?? {}

const toolDescription = (): string => {
switch (block.name) {
case "execute_command":
Expand Down Expand Up @@ -675,30 +680,46 @@ export async function presentAssistantMessage(cline: Task) {
TelemetryService.instance.captureToolUsage(cline.taskId, block.name, toolProtocol)
}

// Validate tool use before execution.
const {
mode,
customModes,
experiments: stateExperiments,
apiConfiguration,
} = (await cline.providerRef.deref()?.getState()) ?? {}
const modelInfo = cline.api.getModel()
const includedTools = modelInfo?.info?.includedTools

try {
validateToolUse(
block.name as ToolName,
mode ?? defaultModeSlug,
customModes ?? [],
{ apply_diff: cline.diffEnabled },
block.params,
stateExperiments,
includedTools,
)
} catch (error) {
cline.consecutiveMistakeCount++
pushToolResult(formatResponse.toolError(error.message, toolProtocol))
break
// Validate tool use before execution - ONLY for complete (non-partial) blocks.
// Validating partial blocks would cause validation errors to be thrown repeatedly
// during streaming, pushing multiple tool_results for the same tool_use_id and
// potentially causing the stream to appear frozen.
if (!block.partial) {
const modelInfo = cline.api.getModel()
const includedTools = modelInfo?.info?.includedTools

try {
validateToolUse(
block.name as ToolName,
mode ?? defaultModeSlug,
customModes ?? [],
{ apply_diff: cline.diffEnabled },
block.params,
stateExperiments,
includedTools,
)
} catch (error) {
cline.consecutiveMistakeCount++
// For validation errors (unknown tool, tool not allowed for mode), we need to:
// 1. Send a tool_result with the error (required for native protocol)
// 2. NOT set didAlreadyUseTool = true (the tool was never executed, just failed validation)
// This prevents the stream from being interrupted with "Response interrupted by tool use result"
// which would cause the extension to appear to hang
const errorContent = formatResponse.toolError(error.message, toolProtocol)
if (toolProtocol === TOOL_PROTOCOL.NATIVE && toolCallId) {
// For native protocol, push tool_result directly without setting didAlreadyUseTool
cline.userMessageContent.push({
type: "tool_result",
tool_use_id: toolCallId,
content: typeof errorContent === "string" ? errorContent : "(validation error)",
is_error: true,
} as Anthropic.ToolResultBlockParam)
} else {
// For XML protocol, use the standard pushToolResult
pushToolResult(errorContent)
}
break
}
}

// Check for identical consecutive tool calls.
Expand Down Expand Up @@ -995,9 +1016,40 @@ export async function presentAssistantMessage(cline: Task) {
toolProtocol,
})
break
default: {
// Handle unknown/invalid tool names
// This is critical for native protocol where every tool_use MUST have a tool_result
// Note: This case should rarely be reached since validateToolUse now checks for unknown tools

// CRITICAL: Don't process partial blocks for unknown tools - just let them stream in.
// If we try to show errors for partial blocks, we'd show the error on every streaming chunk,
// creating a loop that appears to freeze the extension. Only handle complete blocks.
if (block.partial) {
break
}

const errorMessage = `Unknown tool "${block.name}". This tool does not exist. Please use one of the available tools.`
cline.consecutiveMistakeCount++
cline.recordToolError(block.name as ToolName, errorMessage)
await cline.say("error", t("tools:unknownToolError", { toolName: block.name }))
// Push tool_result directly for native protocol WITHOUT setting didAlreadyUseTool
// This prevents the stream from being interrupted with "Response interrupted by tool use result"
if (toolProtocol === TOOL_PROTOCOL.NATIVE && toolCallId) {
cline.userMessageContent.push({
type: "tool_result",
tool_use_id: toolCallId,
content: formatResponse.toolError(errorMessage, toolProtocol),
is_error: true,
} as Anthropic.ToolResultBlockParam)
} else {
pushToolResult(formatResponse.toolError(errorMessage, toolProtocol))
}
break
}
}

break
}
}

// Seeing out of bounds is fine, it means that the next too call is being
Expand Down
Loading
Loading