diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 780d40df891..016fecc97ff 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -880,6 +880,13 @@ export const webviewMessageHandler = async ( case "mcpEnabled": const mcpEnabled = message.bool ?? true await updateGlobalState("mcpEnabled", mcpEnabled) + + // Delegate MCP enable/disable logic to McpHub + const mcpHubInstance = provider.getMcpHub() + if (mcpHubInstance) { + await mcpHubInstance.handleMcpEnabledChange(mcpEnabled) + } + await provider.postStateToWebview() break case "enableMcpServerCreation": diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 10a74712ef0..6d512b3f284 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -33,12 +33,29 @@ import { fileExistsAtPath } from "../../utils/fs" import { arePathsEqual } from "../../utils/path" import { injectVariables } from "../../utils/config" -export type McpConnection = { +// Discriminated union for connection states +export type ConnectedMcpConnection = { + type: "connected" server: McpServer client: Client transport: StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport } +export type DisconnectedMcpConnection = { + type: "disconnected" + server: McpServer + client: null + transport: null +} + +export type McpConnection = ConnectedMcpConnection | DisconnectedMcpConnection + +// Enum for disable reasons +export enum DisableReason { + MCP_DISABLED = "mcpDisabled", + SERVER_DISABLED = "serverDisabled", +} + // Base configuration schema for common settings const BaseConfigSchema = z.object({ disabled: z.boolean().optional(), @@ -497,6 +514,7 @@ export class McpHub { const result = McpSettingsSchema.safeParse(config) if (result.success) { + // Pass all servers including disabled ones - they'll be handled in updateServerConnections await this.updateServerConnections(result.data.mcpServers || {}, source, false) } else { const errorMessages = result.error.errors @@ -552,6 +570,49 @@ export class McpHub { await this.initializeMcpServers("project") } + /** + * Creates a placeholder connection for disabled servers or when MCP is globally disabled + * @param name The server name + * @param config The server configuration + * @param source The source of the server (global or project) + * @param reason The reason for creating a placeholder (mcpDisabled or serverDisabled) + * @returns A placeholder DisconnectedMcpConnection object + */ + private createPlaceholderConnection( + name: string, + config: z.infer, + source: "global" | "project", + reason: DisableReason, + ): DisconnectedMcpConnection { + return { + type: "disconnected", + server: { + name, + config: JSON.stringify(config), + status: "disconnected", + disabled: reason === DisableReason.SERVER_DISABLED ? true : config.disabled, + source, + projectPath: source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined, + errorHistory: [], + }, + client: null, + transport: null, + } + } + + /** + * Checks if MCP is globally enabled + * @returns Promise indicating if MCP is enabled + */ + private async isMcpEnabled(): Promise { + const provider = this.providerRef.deref() + if (!provider) { + return true // Default to enabled if provider is not available + } + const state = await provider.getState() + return state.mcpEnabled ?? true + } + private async connectToServer( name: string, config: z.infer, @@ -560,6 +621,26 @@ export class McpHub { // Remove existing connection if it exists with the same source await this.deleteConnection(name, source) + // Check if MCP is globally enabled + const mcpEnabled = await this.isMcpEnabled() + if (!mcpEnabled) { + // Still create a connection object to track the server, but don't actually connect + const connection = this.createPlaceholderConnection(name, config, source, DisableReason.MCP_DISABLED) + this.connections.push(connection) + return + } + + // Skip connecting to disabled servers + if (config.disabled) { + // Still create a connection object to track the server, but don't actually connect + const connection = this.createPlaceholderConnection(name, config, source, DisableReason.SERVER_DISABLED) + this.connections.push(connection) + return + } + + // Set up file watchers for enabled servers + this.setupFileWatcher(name, config, source) + try { const client = new Client( { @@ -733,7 +814,9 @@ export class McpHub { transport.start = async () => {} } - const connection: McpConnection = { + // Create a connected connection + const connection: ConnectedMcpConnection = { + type: "connected", server: { name, config: JSON.stringify(configInjected), @@ -826,8 +909,8 @@ export class McpHub { // Use the helper method to find the connection const connection = this.findConnection(serverName, source) - if (!connection) { - throw new Error(`Server ${serverName} not found`) + if (!connection || connection.type !== "connected") { + return [] } const response = await connection.client.request({ method: "tools/list" }, ListToolsResultSchema) @@ -881,7 +964,7 @@ export class McpHub { private async fetchResourcesList(serverName: string, source?: "global" | "project"): Promise { try { const connection = this.findConnection(serverName, source) - if (!connection) { + if (!connection || connection.type !== "connected") { return [] } const response = await connection.client.request({ method: "resources/list" }, ListResourcesResultSchema) @@ -898,7 +981,7 @@ export class McpHub { ): Promise { try { const connection = this.findConnection(serverName, source) - if (!connection) { + if (!connection || connection.type !== "connected") { return [] } const response = await connection.client.request( @@ -913,6 +996,9 @@ export class McpHub { } async deleteConnection(name: string, source?: "global" | "project"): Promise { + // Clean up file watchers for this server + this.removeFileWatchersForServer(name) + // If source is provided, only delete connections from that source const connections = source ? this.connections.filter((conn) => conn.server.name === name && conn.server.source === source) @@ -920,8 +1006,10 @@ export class McpHub { for (const connection of connections) { try { - await connection.transport.close() - await connection.client.close() + if (connection.type === "connected") { + await connection.transport.close() + await connection.client.close() + } } catch (error) { console.error(`Failed to close transport for ${name}:`, error) } @@ -975,7 +1063,10 @@ export class McpHub { if (!currentConnection) { // New server try { - this.setupFileWatcher(name, validatedConfig, source) + // Only setup file watcher for enabled servers + if (!validatedConfig.disabled) { + this.setupFileWatcher(name, validatedConfig, source) + } await this.connectToServer(name, validatedConfig, source) } catch (error) { this.showErrorMessage(`Failed to connect to new MCP server ${name}`, error) @@ -983,7 +1074,10 @@ export class McpHub { } else if (!deepEqual(JSON.parse(currentConnection.server.config), config)) { // Existing server with changed config try { - this.setupFileWatcher(name, validatedConfig, source) + // Only setup file watcher for enabled servers + if (!validatedConfig.disabled) { + this.setupFileWatcher(name, validatedConfig, source) + } await this.deleteConnection(name, source) await this.connectToServer(name, validatedConfig, source) } catch (error) { @@ -1066,10 +1160,21 @@ export class McpHub { this.fileWatchers.clear() } + private removeFileWatchersForServer(serverName: string) { + const watchers = this.fileWatchers.get(serverName) + if (watchers) { + watchers.forEach((watcher) => watcher.close()) + this.fileWatchers.delete(serverName) + } + } + async restartConnection(serverName: string, source?: "global" | "project"): Promise { this.isConnecting = true - const provider = this.providerRef.deref() - if (!provider) { + + // Check if MCP is globally enabled + const mcpEnabled = await this.isMcpEnabled() + if (!mcpEnabled) { + this.isConnecting = false return } @@ -1111,6 +1216,23 @@ export class McpHub { return } + // Check if MCP is globally enabled + const mcpEnabled = await this.isMcpEnabled() + if (!mcpEnabled) { + // Clear all existing connections + const existingConnections = [...this.connections] + for (const conn of existingConnections) { + await this.deleteConnection(conn.server.name, conn.server.source) + } + + // Still initialize servers to track them, but they won't connect + await this.initializeMcpServers("global") + await this.initializeMcpServers("project") + + await this.notifyWebviewOfServerChanges() + return + } + this.isConnecting = true vscode.window.showInformationMessage(t("mcp:info.refreshing_all")) @@ -1257,8 +1379,21 @@ export class McpHub { try { connection.server.disabled = disabled - // Only refresh capabilities if connected - if (connection.server.status === "connected") { + // If disabling a connected server, disconnect it + if (disabled && connection.server.status === "connected") { + // Clean up file watchers when disabling + this.removeFileWatchersForServer(serverName) + await this.deleteConnection(serverName, serverSource) + // Re-add as a disabled connection + await this.connectToServer(serverName, JSON.parse(connection.server.config), serverSource) + } else if (!disabled && connection.server.status === "disconnected") { + // If enabling a disabled server, connect it + const config = JSON.parse(connection.server.config) + await this.deleteConnection(serverName, serverSource) + // When re-enabling, file watchers will be set up in connectToServer + await this.connectToServer(serverName, config, serverSource) + } else if (connection.server.status === "connected") { + // Only refresh capabilities if connected connection.server.tools = await this.fetchToolsList(serverName, serverSource) connection.server.resources = await this.fetchResourcesList(serverName, serverSource) connection.server.resourceTemplates = await this.fetchResourceTemplatesList( @@ -1439,7 +1574,7 @@ export class McpHub { async readResource(serverName: string, uri: string, source?: "global" | "project"): Promise { const connection = this.findConnection(serverName, source) - if (!connection) { + if (!connection || connection.type !== "connected") { throw new Error(`No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}`) } if (connection.server.disabled) { @@ -1463,7 +1598,7 @@ export class McpHub { source?: "global" | "project", ): Promise { const connection = this.findConnection(serverName, source) - if (!connection) { + if (!connection || connection.type !== "connected") { throw new Error( `No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`, ) @@ -1609,6 +1744,64 @@ export class McpHub { } } + /** + * Handles enabling/disabling MCP globally + * @param enabled Whether MCP should be enabled or disabled + * @returns Promise + */ + async handleMcpEnabledChange(enabled: boolean): Promise { + if (!enabled) { + // If MCP is being disabled, disconnect all servers with error handling + const existingConnections = [...this.connections] + const disconnectionErrors: Array<{ serverName: string; error: string }> = [] + + for (const conn of existingConnections) { + try { + await this.deleteConnection(conn.server.name, conn.server.source) + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + disconnectionErrors.push({ + serverName: conn.server.name, + error: errorMessage, + }) + console.error(`Failed to disconnect MCP server ${conn.server.name}: ${errorMessage}`) + } + } + + // If there were errors, notify the user + if (disconnectionErrors.length > 0) { + const errorSummary = disconnectionErrors.map((e) => `${e.serverName}: ${e.error}`).join("\n") + vscode.window.showWarningMessage( + t("mcp:errors.disconnect_servers_partial", { + count: disconnectionErrors.length, + errors: errorSummary, + }) || + `Failed to disconnect ${disconnectionErrors.length} MCP server(s). Check the output for details.`, + ) + } + + // Re-initialize servers to track them in disconnected state + try { + await this.refreshAllConnections() + } catch (error) { + console.error(`Failed to refresh MCP connections after disabling: ${error}`) + vscode.window.showErrorMessage( + t("mcp:errors.refresh_after_disable") || "Failed to refresh MCP connections after disabling", + ) + } + } else { + // If MCP is being enabled, reconnect all servers + try { + await this.refreshAllConnections() + } catch (error) { + console.error(`Failed to refresh MCP connections after enabling: ${error}`) + vscode.window.showErrorMessage( + t("mcp:errors.refresh_after_enable") || "Failed to refresh MCP connections after enabling", + ) + } + } + } + async dispose(): Promise { // Prevent multiple disposals if (this.isDisposed) { diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index 7dc7f00c045..ebce2d5b2a0 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -1,7 +1,7 @@ -import type { McpHub as McpHubType, McpConnection } from "../McpHub" +import type { McpHub as McpHubType, McpConnection, ConnectedMcpConnection, DisconnectedMcpConnection } from "../McpHub" import type { ClineProvider } from "../../../core/webview/ClineProvider" import type { ExtensionContext, Uri } from "vscode" -import { ServerConfigSchema, McpHub } from "../McpHub" +import { ServerConfigSchema, McpHub, DisableReason } from "../McpHub" import fs from "fs/promises" import { vi, Mock } from "vitest" @@ -33,11 +33,15 @@ vi.mock("fs/promises", () => ({ mkdir: vi.fn().mockResolvedValue(undefined), })) +// Import safeWriteJson to use in mocks +import { safeWriteJson } from "../../../utils/safeWriteJson" + // Mock safeWriteJson vi.mock("../../../utils/safeWriteJson", () => ({ safeWriteJson: vi.fn(async (filePath, data) => { // Instead of trying to write to the file system, just call fs.writeFile mock // This avoids the complex file locking and temp file operations + const fs = await import("fs/promises") return fs.writeFile(filePath, JSON.stringify(data), "utf8") }), })) @@ -79,6 +83,16 @@ vi.mock("@modelcontextprotocol/sdk/client/index.js", () => ({ Client: vi.fn(), })) +// Mock chokidar +vi.mock("chokidar", () => ({ + default: { + watch: vi.fn().mockReturnValue({ + on: vi.fn().mockReturnThis(), + close: vi.fn(), + }), + }, +})) + describe("McpHub", () => { let mcpHub: McpHubType let mockProvider: Partial @@ -108,6 +122,7 @@ describe("McpHub", () => { ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: true }), context: { subscriptions: [], workspaceState: {} as any, @@ -140,31 +155,612 @@ describe("McpHub", () => { } as ExtensionContext, } - // Mock fs.readFile for initial settings - vi.mocked(fs.readFile).mockResolvedValue( - JSON.stringify({ - mcpServers: { - "test-server": { - type: "stdio", - command: "node", - args: ["test.js"], - alwaysAllow: ["allowed-tool"], - disabledTools: ["disabled-tool"], - }, + // Mock fs.readFile for initial settings + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + alwaysAllow: ["allowed-tool"], + disabledTools: ["disabled-tool"], + }, + }, + }), + ) + + mcpHub = new McpHub(mockProvider as ClineProvider) + }) + + afterEach(() => { + // Restore original console methods + console.error = originalConsoleError + // Restore original platform + if (originalPlatform) { + Object.defineProperty(process, "platform", originalPlatform) + } + }) + + describe("Discriminated union type handling", () => { + it("should create connected connections with proper type", async () => { + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "union-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create McpHub and let it initialize + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "union-test-server") + expect(connection).toBeDefined() + + // Type guard check - connected connections should have client and transport + if (connection && connection.type === "connected") { + expect(connection.client).toBeDefined() + expect(connection.transport).toBeDefined() + expect(connection.server.status).toBe("connected") + } else { + throw new Error("Connection should be of type 'connected'") + } + }) + + it("should create disconnected connections for disabled servers", async () => { + // Mock the config file read with a disabled server + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-union-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) + + // Create McpHub and let it initialize + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "disabled-union-server") + expect(connection).toBeDefined() + + // Type guard check - disconnected connections should have null client and transport + if (connection && connection.type === "disconnected") { + expect(connection.client).toBeNull() + expect(connection.transport).toBeNull() + expect(connection.server.status).toBe("disconnected") + expect(connection.server.disabled).toBe(true) + } else { + throw new Error("Connection should be of type 'disconnected'") + } + }) + + it("should handle type narrowing correctly in callTool", async () => { + // Mock fs.readFile to return empty config so no servers are initialized + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: {}, + }), + ) + + // Create a mock McpHub instance + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Clear any connections that might have been created + mcpHub.connections = [] + + // Directly set up a connected connection + const connectedConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ command: "node", args: ["test.js"] }), + status: "connected", + source: "global", + errorHistory: [], + } as any, + client: { + request: vi.fn().mockResolvedValue({ result: "success" }), + } as any, + transport: {} as any, + } + + // Add the connected connection + mcpHub.connections = [connectedConnection] + + // Call tool should work with connected server + const result = await mcpHub.callTool("test-server", "test-tool", {}) + expect(result).toEqual({ result: "success" }) + expect(connectedConnection.client.request).toHaveBeenCalled() + + // Now test with a disconnected connection + const disconnectedConnection: DisconnectedMcpConnection = { + type: "disconnected", + server: { + name: "disabled-server", + config: JSON.stringify({ command: "node", args: ["test.js"], disabled: true }), + status: "disconnected", + disabled: true, + source: "global", + errorHistory: [], + } as any, + client: null, + transport: null, + } + + // Replace connections with disconnected one + mcpHub.connections = [disconnectedConnection] + + // Call tool should fail with disconnected server + await expect(mcpHub.callTool("disabled-server", "test-tool", {})).rejects.toThrow( + "No connection found for server: disabled-server", + ) + }) + }) + + describe("File watcher cleanup", () => { + it("should clean up file watchers when server is disabled", async () => { + // Get the mocked chokidar + const chokidar = (await import("chokidar")).default + const mockWatcher = { + on: vi.fn().mockReturnThis(), + close: vi.fn(), + } + vi.mocked(chokidar.watch).mockReturnValue(mockWatcher as any) + + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + // Create server with watchPaths + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "watcher-test-server": { + command: "node", + args: ["test.js"], + watchPaths: ["/path/to/watch"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify watcher was created + expect(chokidar.watch).toHaveBeenCalledWith(["/path/to/watch"], expect.any(Object)) + + // Now disable the server + await mcpHub.toggleServerDisabled("watcher-test-server", true) + + // Verify watcher was closed + expect(mockWatcher.close).toHaveBeenCalled() + }) + + it("should clean up all file watchers when server is deleted", async () => { + // Get the mocked chokidar + const chokidar = (await import("chokidar")).default + const mockWatcher1 = { + on: vi.fn().mockReturnThis(), + close: vi.fn(), + } + const mockWatcher2 = { + on: vi.fn().mockReturnThis(), + close: vi.fn(), + } + + // Return different watchers for different paths + let watcherIndex = 0 + vi.mocked(chokidar.watch).mockImplementation(() => { + return (watcherIndex++ === 0 ? mockWatcher1 : mockWatcher2) as any + }) + + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + // Create server with multiple watchPaths + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "multi-watcher-server": { + command: "node", + args: ["test.js", "build/index.js"], // This will create a watcher for build/index.js + watchPaths: ["/path/to/watch1", "/path/to/watch2"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify watchers were created + expect(chokidar.watch).toHaveBeenCalled() + + // Delete the connection (this should clean up all watchers) + await mcpHub.deleteConnection("multi-watcher-server") + + // Verify all watchers were closed + expect(mockWatcher1.close).toHaveBeenCalled() + expect(mockWatcher2.close).toHaveBeenCalled() + }) + + it("should not create file watchers for disabled servers on initialization", async () => { + // Get the mocked chokidar + const chokidar = (await import("chokidar")).default + + // Create disabled server with watchPaths + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-watcher-server": { + command: "node", + args: ["test.js"], + watchPaths: ["/path/to/watch"], + disabled: true, + }, + }, + }), + ) + + vi.mocked(chokidar.watch).mockClear() + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify no watcher was created for disabled server + expect(chokidar.watch).not.toHaveBeenCalled() + }) + }) + + describe("DisableReason enum usage", () => { + it("should use MCP_DISABLED reason when MCP is globally disabled", async () => { + // Mock provider with mcpEnabled: false + mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: false }) + + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "mcp-disabled-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "mcp-disabled-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + expect(connection?.server.status).toBe("disconnected") + + // The server should not be marked as disabled individually + expect(connection?.server.disabled).toBeUndefined() + }) + + it("should use SERVER_DISABLED reason when server is individually disabled", async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "server-disabled-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "server-disabled-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + expect(connection?.server.status).toBe("disconnected") + expect(connection?.server.disabled).toBe(true) + }) + + it("should handle both disable reasons correctly", async () => { + // First test with MCP globally disabled + mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: false }) + + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "both-reasons-server": { + command: "node", + args: ["test.js"], + disabled: true, // Server is also individually disabled + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "both-reasons-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + + // When MCP is globally disabled, it takes precedence + // The server's individual disabled state should be preserved + expect(connection?.server.disabled).toBe(true) + }) + }) + + describe("Null safety improvements", () => { + it("should handle null client safely in disconnected connections", async () => { + // Mock fs.readFile to return a disabled server config + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "null-safety-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // The server should be created as a disconnected connection with null client/transport + const connection = mcpHub.connections.find((conn) => conn.server.name === "null-safety-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + + // Type guard to ensure it's a disconnected connection + if (connection?.type === "disconnected") { + expect(connection.client).toBeNull() + expect(connection.transport).toBeNull() + } + + // Try to call tool on disconnected server + await expect(mcpHub.callTool("null-safety-server", "test-tool", {})).rejects.toThrow( + "No connection found for server: null-safety-server", + ) + + // Try to read resource on disconnected server + await expect(mcpHub.readResource("null-safety-server", "test-uri")).rejects.toThrow( + "No connection found for server: null-safety-server", + ) + }) + + it("should handle connection type checks safely", async () => { + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "type-check-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Get the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "type-check-server") + expect(connection).toBeDefined() + + // Safe type checking + if (connection?.type === "connected") { + expect(connection.client).toBeDefined() + expect(connection.transport).toBeDefined() + } else if (connection?.type === "disconnected") { + expect(connection.client).toBeNull() + expect(connection.transport).toBeNull() + } + }) + + it("should handle missing connections safely", async () => { + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Try operations on non-existent server + await expect(mcpHub.callTool("non-existent-server", "test-tool", {})).rejects.toThrow( + "No connection found for server: non-existent-server", + ) + + await expect(mcpHub.readResource("non-existent-server", "test-uri")).rejects.toThrow( + "No connection found for server: non-existent-server", + ) + }) + + it("should handle connection deletion safely", async () => { + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), }, - }), - ) + onerror: null, + onclose: null, + } - mcpHub = new McpHub(mockProvider as ClineProvider) - }) + StdioClientTransport.mockImplementation(() => mockTransport) - afterEach(() => { - // Restore original console methods - console.error = originalConsoleError - // Restore original platform - if (originalPlatform) { - Object.defineProperty(process, "platform", originalPlatform) - } + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "delete-safety-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Delete the connection + await mcpHub.deleteConnection("delete-safety-server") + + // Verify connection is removed + const connection = mcpHub.connections.find((conn) => conn.server.name === "delete-safety-server") + expect(connection).toBeUndefined() + + // Verify transport and client were closed + expect(mockTransport.close).toHaveBeenCalled() + expect(mockClient.close).toHaveBeenCalled() + }) }) describe("toggleToolAlwaysAllow", () => { @@ -184,7 +780,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection without alwaysAllow - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -232,7 +829,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -280,7 +878,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -325,7 +924,8 @@ describe("McpHub", () => { } // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: "test-server-config", @@ -372,7 +972,8 @@ describe("McpHub", () => { } // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: "test-server-config", @@ -418,7 +1019,8 @@ describe("McpHub", () => { } // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: "test-server-config", @@ -468,7 +1070,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -500,6 +1103,7 @@ describe("McpHub", () => { it("should filter out disabled servers from getServers", () => { const mockConnections: McpConnection[] = [ { + type: "connected", server: { name: "enabled-server", config: "{}", @@ -508,17 +1112,18 @@ describe("McpHub", () => { }, client: {} as any, transport: {} as any, - }, + } as ConnectedMcpConnection, { + type: "disconnected", server: { name: "disabled-server", config: "{}", - status: "connected", + status: "disconnected", disabled: true, }, - client: {} as any, - transport: {} as any, - }, + client: null, + transport: null, + } as DisconnectedMcpConnection, ] mcpHub.connections = mockConnections @@ -529,44 +1134,64 @@ describe("McpHub", () => { }) it("should prevent calling tools on disabled servers", async () => { - const mockConnection: McpConnection = { - server: { - name: "disabled-server", - config: "{}", - status: "connected", - disabled: true, - }, - client: { - request: vi.fn().mockResolvedValue({ result: "success" }), - } as any, - transport: {} as any, - } + // Mock fs.readFile to return a disabled server config + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) - mcpHub.connections = [mockConnection] + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + // The server should be created as a disconnected connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "disabled-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + expect(connection?.server.disabled).toBe(true) + + // Try to call tool on disabled server await expect(mcpHub.callTool("disabled-server", "some-tool", {})).rejects.toThrow( - 'Server "disabled-server" is disabled and cannot be used', + "No connection found for server: disabled-server", ) }) it("should prevent reading resources from disabled servers", async () => { - const mockConnection: McpConnection = { - server: { - name: "disabled-server", - config: "{}", - status: "connected", - disabled: true, - }, - client: { - request: vi.fn(), - } as any, - transport: {} as any, - } + // Mock fs.readFile to return a disabled server config + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) - mcpHub.connections = [mockConnection] + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + // The server should be created as a disconnected connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "disabled-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + expect(connection?.server.disabled).toBe(true) + + // Try to read resource from disabled server await expect(mcpHub.readResource("disabled-server", "some/uri")).rejects.toThrow( - 'Server "disabled-server" is disabled', + "No connection found for server: disabled-server", ) }) }) @@ -574,7 +1199,8 @@ describe("McpHub", () => { describe("callTool", () => { it("should execute tool successfully", async () => { // Mock the connection with a minimal client implementation - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: JSON.stringify({}), @@ -595,7 +1221,7 @@ describe("McpHub", () => { await mcpHub.callTool("test-server", "some-tool", {}) // Verify the request was made with correct parameters - expect(mockConnection.client.request).toHaveBeenCalledWith( + expect(mockConnection.client!.request).toHaveBeenCalledWith( { method: "tools/call", params: { @@ -637,7 +1263,8 @@ describe("McpHub", () => { }) it("should use default timeout of 60 seconds if not specified", async () => { - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: JSON.stringify({ type: "stdio", command: "test" }), // No timeout specified @@ -652,7 +1279,7 @@ describe("McpHub", () => { mcpHub.connections = [mockConnection] await mcpHub.callTool("test-server", "test-tool") - expect(mockConnection.client.request).toHaveBeenCalledWith( + expect(mockConnection.client!.request).toHaveBeenCalledWith( expect.anything(), expect.anything(), expect.objectContaining({ timeout: 60000 }), // 60 seconds in milliseconds @@ -660,7 +1287,8 @@ describe("McpHub", () => { }) it("should apply configured timeout to tool calls", async () => { - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: JSON.stringify({ type: "stdio", command: "test", timeout: 120 }), // 2 minutes @@ -675,7 +1303,7 @@ describe("McpHub", () => { mcpHub.connections = [mockConnection] await mcpHub.callTool("test-server", "test-tool") - expect(mockConnection.client.request).toHaveBeenCalledWith( + expect(mockConnection.client!.request).toHaveBeenCalledWith( expect.anything(), expect.anything(), expect.objectContaining({ timeout: 120000 }), // 120 seconds in milliseconds @@ -700,7 +1328,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -745,7 +1374,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection before updating - const mockConnectionInitial: McpConnection = { + const mockConnectionInitial: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -768,7 +1398,8 @@ describe("McpHub", () => { expect(fs.writeFile).toHaveBeenCalled() // Setup connection with invalid timeout - const mockConnectionInvalid: McpConnection = { + const mockConnectionInvalid: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: JSON.stringify({ @@ -791,7 +1422,7 @@ describe("McpHub", () => { await mcpHub.callTool("test-server", "test-tool") // Verify default timeout was used - expect(mockConnectionInvalid.client.request).toHaveBeenCalledWith( + expect(mockConnectionInvalid.client!.request).toHaveBeenCalledWith( expect.anything(), expect.anything(), expect.objectContaining({ timeout: 60000 }), // Default 60 seconds @@ -813,7 +1444,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -852,7 +1484,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -877,6 +1510,291 @@ describe("McpHub", () => { }) }) + describe("MCP global enable/disable", () => { + beforeEach(() => { + // Clear all mocks before each test + vi.clearAllMocks() + }) + + it("should disconnect all servers when MCP is toggled from enabled to disabled", async () => { + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + // Start with MCP enabled + mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: true }) + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "toggle-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create McpHub and let it initialize with MCP enabled + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify server is connected + const connectedServer = mcpHub.connections.find((conn) => conn.server.name === "toggle-test-server") + expect(connectedServer).toBeDefined() + expect(connectedServer!.server.status).toBe("connected") + expect(connectedServer!.client).toBeDefined() + expect(connectedServer!.transport).toBeDefined() + + // Now simulate toggling MCP to disabled + mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: false }) + + // Manually trigger what would happen when MCP is disabled + // (normally this would be triggered by the webview message handler) + const existingConnections = [...mcpHub.connections] + for (const conn of existingConnections) { + await mcpHub.deleteConnection(conn.server.name, conn.server.source) + } + await mcpHub.refreshAllConnections() + + // Verify server is now tracked but disconnected + const disconnectedServer = mcpHub.connections.find((conn) => conn.server.name === "toggle-test-server") + expect(disconnectedServer).toBeDefined() + expect(disconnectedServer!.server.status).toBe("disconnected") + expect(disconnectedServer!.client).toBeNull() + expect(disconnectedServer!.transport).toBeNull() + + // Verify close was called on the original client and transport + expect(mockClient.close).toHaveBeenCalled() + expect(mockTransport.close).toHaveBeenCalled() + }) + + it("should not connect to servers when MCP is globally disabled", async () => { + // Mock provider with mcpEnabled: false + const disabledMockProvider = { + ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: false }), + context: mockProvider.context, + } + + // Mock the config file read with a different server name to avoid conflicts + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create a new McpHub instance with disabled MCP + const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the disabled-test-server + const disabledServer = mcpHub.connections.find((conn) => conn.server.name === "disabled-test-server") + + // Verify that the server is tracked but not connected + expect(disabledServer).toBeDefined() + expect(disabledServer!.server.status).toBe("disconnected") + expect(disabledServer!.client).toBeNull() + expect(disabledServer!.transport).toBeNull() + }) + + it("should connect to servers when MCP is globally enabled", async () => { + // Clear all mocks + vi.clearAllMocks() + + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + Client.mockImplementation(() => ({ + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + })) + + // Mock provider with mcpEnabled: true + const enabledMockProvider = { + ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: true }), + context: mockProvider.context, + } + + // Mock the config file read with a different server name + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "enabled-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create a new McpHub instance with enabled MCP + const mcpHub = new McpHub(enabledMockProvider as unknown as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the enabled-test-server + const enabledServer = mcpHub.connections.find((conn) => conn.server.name === "enabled-test-server") + + // Verify that the server is connected + expect(enabledServer).toBeDefined() + expect(enabledServer!.server.status).toBe("connected") + expect(enabledServer!.client).toBeDefined() + expect(enabledServer!.transport).toBeDefined() + + // Verify StdioClientTransport was called + expect(StdioClientTransport).toHaveBeenCalled() + }) + + it("should handle refreshAllConnections when MCP is disabled", async () => { + // Mock provider with mcpEnabled: false + const disabledMockProvider = { + ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: false }), + context: mockProvider.context, + } + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "refresh-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create McpHub with disabled MCP + const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Clear previous calls + vi.clearAllMocks() + + // Call refreshAllConnections + await mcpHub.refreshAllConnections() + + // Verify that servers are tracked but not connected + const server = mcpHub.connections.find((conn) => conn.server.name === "refresh-test-server") + expect(server).toBeDefined() + expect(server!.server.status).toBe("disconnected") + expect(server!.client).toBeNull() + expect(server!.transport).toBeNull() + + // Verify postMessageToWebview was called to update the UI + expect(disabledMockProvider.postMessageToWebview).toHaveBeenCalledWith( + expect.objectContaining({ + type: "mcpServers", + }), + ) + }) + + it("should skip restarting connection when MCP is disabled", async () => { + // Mock provider with mcpEnabled: false + const disabledMockProvider = { + ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: false }), + context: mockProvider.context, + } + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "restart-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create McpHub with disabled MCP + const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Set isConnecting to false to ensure it's properly reset + mcpHub.isConnecting = false + + // Try to restart a connection + await mcpHub.restartConnection("restart-test-server") + + // Verify that isConnecting was reset to false + expect(mcpHub.isConnecting).toBe(false) + + // Verify that the server remains disconnected + const server = mcpHub.connections.find((conn) => conn.server.name === "restart-test-server") + expect(server).toBeDefined() + expect(server!.server.status).toBe("disconnected") + expect(server!.client).toBeNull() + expect(server!.transport).toBeNull() + }) + }) + describe("Windows command wrapping", () => { let StdioClientTransport: ReturnType let Client: ReturnType diff --git a/webview-ui/src/components/mcp/McpView.tsx b/webview-ui/src/components/mcp/McpView.tsx index 0873bde1957..21ad1c26525 100644 --- a/webview-ui/src/components/mcp/McpView.tsx +++ b/webview-ui/src/components/mcp/McpView.tsx @@ -206,6 +206,9 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM return configTimeout ?? 60 // Default 1 minute (60 seconds) }) + // Computed property to check if server is expandable + const isExpandable = server.status === "connected" && !server.disabled + const timeoutOptions = [ { value: 15, label: t("mcp:networkTimeout.options.15seconds") }, { value: 30, label: t("mcp:networkTimeout.options.30seconds") }, @@ -218,6 +221,11 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM ] const getStatusColor = () => { + // Disabled servers should always show grey regardless of connection status + if (server.disabled) { + return "var(--vscode-descriptionForeground)" + } + switch (server.status) { case "connected": return "var(--vscode-testing-iconPassed)" @@ -229,7 +237,8 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM } const handleRowClick = () => { - if (server.status === "connected") { + // Only allow expansion for connected and enabled servers + if (isExpandable) { setIsExpanded(!isExpanded) } } @@ -270,12 +279,12 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM alignItems: "center", padding: "8px", background: "var(--vscode-textCodeBlock-background)", - cursor: server.status === "connected" ? "pointer" : "default", - borderRadius: isExpanded || server.status === "connected" ? "4px" : "4px 4px 0 0", + cursor: isExpandable ? "pointer" : "default", + borderRadius: isExpanded || isExpandable ? "4px" : "4px 4px 0 0", opacity: server.disabled ? 0.6 : 1, }} onClick={handleRowClick}> - {server.status === "connected" && ( + {isExpandable && ( - {server.status === "connected" ? ( - isExpanded && ( -
- - - {t("mcp:tabs.tools")} ({server.tools?.length || 0}) - - - {t("mcp:tabs.resources")} ( - {[...(server.resourceTemplates || []), ...(server.resources || [])].length || 0}) - - {server.instructions && ( - {t("mcp:instructions")} - )} - - {t("mcp:tabs.errors")} ({server.errorHistory?.length || 0}) - - - - {server.tools && server.tools.length > 0 ? ( -
- {server.tools.map((tool) => ( - - ))} -
- ) : ( -
- {t("mcp:emptyState.noTools")} -
+ {isExpandable + ? isExpanded && ( +
+ + + {t("mcp:tabs.tools")} ({server.tools?.length || 0}) + + + {t("mcp:tabs.resources")} ( + {[...(server.resourceTemplates || []), ...(server.resources || [])].length || 0}) + + {server.instructions && ( + {t("mcp:instructions")} )} - + + {t("mcp:tabs.errors")} ({server.errorHistory?.length || 0}) + - - {(server.resources && server.resources.length > 0) || - (server.resourceTemplates && server.resourceTemplates.length > 0) ? ( -
- {[...(server.resourceTemplates || []), ...(server.resources || [])].map( - (item) => ( - + {server.tools && server.tools.length > 0 ? ( +
+ {server.tools.map((tool) => ( + - ), - )} -
- ) : ( -
- {t("mcp:emptyState.noResources")} -
- )} - + ))} +
+ ) : ( +
+ {t("mcp:emptyState.noTools")} +
+ )} +
- {server.instructions && ( - -
-
- {server.instructions} + + {(server.resources && server.resources.length > 0) || + (server.resourceTemplates && server.resourceTemplates.length > 0) ? ( +
+ {[...(server.resourceTemplates || []), ...(server.resources || [])].map( + (item) => ( + + ), + )} +
+ ) : ( +
+ {t("mcp:emptyState.noResources")}
-
+ )} - )} - - {server.errorHistory && server.errorHistory.length > 0 ? ( -
- {[...server.errorHistory] - .sort((a, b) => b.timestamp - a.timestamp) - .map((error, index) => ( - - ))} -
- ) : ( -
- {t("mcp:emptyState.noErrors")} -
+ {server.instructions && ( + +
+
+ {server.instructions} +
+
+
)} -
- - {/* Network Timeout */} -
+ + {server.errorHistory && server.errorHistory.length > 0 ? ( +
+ {[...server.errorHistory] + .sort((a, b) => b.timestamp - a.timestamp) + .map((error, index) => ( + + ))} +
+ ) : ( +
+ {t("mcp:emptyState.noErrors")} +
+ )} +
+ + + {/* Network Timeout */} +
+
+ {t("mcp:networkTimeout.label")} + +
+ + {t("mcp:networkTimeout.description")} + +
+
+ ) + : // Only show error UI for non-disabled servers + !server.disabled && ( +
- {t("mcp:networkTimeout.label")} -
- - {t("mcp:networkTimeout.description")} - + + {server.status === "connecting" + ? t("mcp:serverStatus.retrying") + : t("mcp:serverStatus.retryConnection")} +
-
- ) - ) : ( -
-
- {server.error && - server.error.split("\n").map((item, index) => ( - - {index > 0 &&
} - {item} -
- ))} -
- - {server.status === "connecting" - ? t("mcp:serverStatus.retrying") - : t("mcp:serverStatus.retryConnection")} - -
- )} + )} {/* Delete Confirmation Dialog */}