diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 3d02ce25f19..63efd52ba4d 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -69,6 +69,7 @@ const BaseConfigSchema = z.object({ timeout: z.number().min(1).max(3600).optional().default(60), alwaysAllow: z.array(z.string()).default([]), watchPaths: z.array(z.string()).optional(), // paths to watch for changes and restart server + allowedTools: z.array(z.string()).optional(), // whitelist: if specified, only these tools are enabled disabledTools: z.array(z.string()).default([]), }) @@ -981,6 +982,7 @@ export class McpHub { const actualSource = connection.server.source || "global" let configPath: string let alwaysAllowConfig: string[] = [] + let allowedToolsList: string[] | undefined = undefined let disabledToolsList: string[] = [] // Read from the appropriate config file based on the actual source @@ -1002,6 +1004,7 @@ export class McpHub { } if (serverConfigData) { alwaysAllowConfig = serverConfigData.mcpServers?.[serverName]?.alwaysAllow || [] + allowedToolsList = serverConfigData.mcpServers?.[serverName]?.allowedTools disabledToolsList = serverConfigData.mcpServers?.[serverName]?.disabledTools || [] } } catch (error) { @@ -1013,11 +1016,17 @@ export class McpHub { const hasWildcard = alwaysAllowConfig.includes("*") // Mark tools as always allowed and enabled for prompt based on settings - const tools = (response?.tools || []).map((tool) => ({ - ...tool, - alwaysAllow: hasWildcard || alwaysAllowConfig.includes(tool.name), - enabledForPrompt: !disabledToolsList.includes(tool.name), - })) + // If allowedTools whitelist is specified, only those tools are enabled. + // Then disabledTools blacklist further filters from the allowed set. + const tools = (response?.tools || []).map((tool) => { + const isWhitelisted = allowedToolsList === undefined || allowedToolsList.includes(tool.name) + const isBlacklisted = disabledToolsList.includes(tool.name) + return { + ...tool, + alwaysAllow: hasWildcard || alwaysAllowConfig.includes(tool.name), + enabledForPrompt: isWhitelisted && !isBlacklisted, + } + }) return tools } catch (error) { diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index 3f06627cc17..93ca576e027 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -1053,6 +1053,183 @@ describe("McpHub", () => { }) }) + describe("allowedTools whitelist", () => { + it("should enable only whitelisted tools when allowedTools is specified", async () => { + const mockConfig = { + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + allowedTools: ["tool1", "tool3"], + }, + }, + } + + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(mockConfig)) + + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + type: "stdio", + command: "node", + args: ["test.js"], + source: "global", + } as any, + client: { + request: vi.fn().mockResolvedValue({ + tools: [ + { name: "tool1", description: "Tool 1" }, + { name: "tool2", description: "Tool 2" }, + { name: "tool3", description: "Tool 3" }, + ], + }), + } as any, + transport: {} as any, + } + mcpHub.connections = [mockConnection] + + const tools = await mcpHub["fetchToolsList"]("test-server", "global") + + expect(tools.length).toBe(3) + expect(tools[0].enabledForPrompt).toBe(true) // tool1 is whitelisted + expect(tools[1].enabledForPrompt).toBe(false) // tool2 is NOT whitelisted + expect(tools[2].enabledForPrompt).toBe(true) // tool3 is whitelisted + }) + + it("should allow all tools when allowedTools is not specified (backward compatibility)", async () => { + const mockConfig = { + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + // no allowedTools field + }, + }, + } + + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(mockConfig)) + + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + type: "stdio", + command: "node", + args: ["test.js"], + source: "global", + } as any, + client: { + request: vi.fn().mockResolvedValue({ + tools: [ + { name: "tool1", description: "Tool 1" }, + { name: "tool2", description: "Tool 2" }, + ], + }), + } as any, + transport: {} as any, + } + mcpHub.connections = [mockConnection] + + const tools = await mcpHub["fetchToolsList"]("test-server", "global") + + expect(tools.length).toBe(2) + expect(tools[0].enabledForPrompt).toBe(true) // all tools allowed + expect(tools[1].enabledForPrompt).toBe(true) // all tools allowed + }) + + it("should apply disabledTools on top of allowedTools whitelist", async () => { + const mockConfig = { + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + allowedTools: ["tool1", "tool2", "tool3"], + disabledTools: ["tool2"], + }, + }, + } + + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(mockConfig)) + + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + type: "stdio", + command: "node", + args: ["test.js"], + source: "global", + } as any, + client: { + request: vi.fn().mockResolvedValue({ + tools: [ + { name: "tool1", description: "Tool 1" }, + { name: "tool2", description: "Tool 2" }, + { name: "tool3", description: "Tool 3" }, + { name: "tool4", description: "Tool 4" }, + ], + }), + } as any, + transport: {} as any, + } + mcpHub.connections = [mockConnection] + + const tools = await mcpHub["fetchToolsList"]("test-server", "global") + + expect(tools.length).toBe(4) + expect(tools[0].enabledForPrompt).toBe(true) // tool1: whitelisted, not blacklisted + expect(tools[1].enabledForPrompt).toBe(false) // tool2: whitelisted but blacklisted + expect(tools[2].enabledForPrompt).toBe(true) // tool3: whitelisted, not blacklisted + expect(tools[3].enabledForPrompt).toBe(false) // tool4: not whitelisted + }) + + it("should disable all tools when allowedTools is an empty array", async () => { + const mockConfig = { + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + allowedTools: [], + }, + }, + } + + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(mockConfig)) + + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + type: "stdio", + command: "node", + args: ["test.js"], + source: "global", + } as any, + client: { + request: vi.fn().mockResolvedValue({ + tools: [ + { name: "tool1", description: "Tool 1" }, + { name: "tool2", description: "Tool 2" }, + ], + }), + } as any, + transport: {} as any, + } + mcpHub.connections = [mockConnection] + + const tools = await mcpHub["fetchToolsList"]("test-server", "global") + + expect(tools.length).toBe(2) + expect(tools[0].enabledForPrompt).toBe(false) // empty whitelist means nothing allowed + expect(tools[1].enabledForPrompt).toBe(false) // empty whitelist means nothing allowed + }) + }) + describe("toggleToolEnabledForPrompt", () => { it("should add tool to disabledTools list when enabling", async () => { const mockConfig = {