diff --git a/packages/opencode/src/mcp/index.ts b/packages/opencode/src/mcp/index.ts index fb4a8d4cf81..98d75242a7b 100644 --- a/packages/opencode/src/mcp/index.ts +++ b/packages/opencode/src/mcp/index.ts @@ -54,6 +54,22 @@ export namespace MCP { ) type MCPClient = Client + type ToolsCacheEntry = { + seq: number + tools: MCPToolDef[] + } + + function cloneToolDefs(tools: MCPToolDef[]) { + return tools.map((tool) => ({ ...tool })) + } + + function invalidateToolsCache(state: { + toolsSeq: Record + toolsCache: Record + }, serverName: string) { + state.toolsSeq[serverName] = (state.toolsSeq[serverName] ?? 0) + 1 + delete state.toolsCache[serverName] + } export const Status = z .discriminatedUnion("status", [ @@ -159,6 +175,13 @@ export namespace MCP { const config = cfg.mcp ?? {} const clients: Record = {} const status: Record = {} + const toolsCache: Record = {} + const toolsSeq: Record = {} + const unsubscribeToolsChanged = Bus.subscribe(ToolsChanged, (event) => { + const serverName = event.properties.server + toolsSeq[serverName] = (toolsSeq[serverName] ?? 0) + 1 + delete toolsCache[serverName] + }) await Promise.all( Object.entries(config).map(async ([key, mcp]) => { @@ -186,9 +209,13 @@ export namespace MCP { return { status, clients, + toolsCache, + toolsSeq, + unsubscribeToolsChanged, } }, async (state) => { + state.unsubscribeToolsChanged?.() await Promise.all( Object.values(state.clients).map((client) => client.close().catch((error) => { @@ -198,6 +225,12 @@ export namespace MCP { }), ), ) + for (const key of Object.keys(state.toolsCache)) { + delete state.toolsCache[key] + } + for (const key of Object.keys(state.toolsSeq)) { + delete state.toolsSeq[key] + } pendingOAuthTransports.clear() }, ) @@ -256,18 +289,21 @@ export namespace MCP { error: "unknown error", } s.status[name] = status + invalidateToolsCache(s, name) return { status, } } if (!result.mcpClient) { s.status[name] = result.status + invalidateToolsCache(s, name) return { status: s.status, } } s.clients[name] = result.mcpClient s.status[name] = result.status + invalidateToolsCache(s, name) return { status: s.status, @@ -525,6 +561,7 @@ export namespace MCP { if (result.mcpClient) { s.clients[name] = result.mcpClient } + invalidateToolsCache(s, name) } export async function disconnect(name: string) { @@ -537,6 +574,7 @@ export namespace MCP { delete s.clients[name] } s.status[name] = { status: "disabled" } + invalidateToolsCache(s, name) } export async function tools() { @@ -550,20 +588,35 @@ export namespace MCP { continue } - const toolsResult = await client.listTools().catch((e) => { - log.error("failed to get tools", { clientName, error: e.message }) - const failedStatus = { - status: "failed" as const, - error: e instanceof Error ? e.message : String(e), + const currentSeq = s.toolsSeq[clientName] ?? 0 + const cached = s.toolsCache[clientName] + let toolDefs: MCPToolDef[] | undefined + + if (cached && cached.seq === currentSeq) { + toolDefs = cloneToolDefs(cached.tools) + } else { + const toolsResult = await client.listTools().catch((e) => { + log.error("failed to get tools", { clientName, error: e.message }) + const failedStatus = { + status: "failed" as const, + error: e instanceof Error ? e.message : String(e), + } + s.status[clientName] = failedStatus + delete s.clients[clientName] + invalidateToolsCache(s, clientName) + return undefined + }) + if (!toolsResult) { + continue + } + toolDefs = toolsResult.tools + s.toolsCache[clientName] = { + seq: currentSeq, + tools: toolsResult.tools, } - s.status[clientName] = failedStatus - delete s.clients[clientName] - return undefined - }) - if (!toolsResult) { - continue } - for (const mcpTool of toolsResult.tools) { + + for (const mcpTool of toolDefs) { const sanitizedClientName = clientName.replace(/[^a-zA-Z0-9_-]/g, "_") const sanitizedToolName = mcpTool.name.replace(/[^a-zA-Z0-9_-]/g, "_") result[sanitizedClientName + "_" + sanitizedToolName] = await convertMcpTool(mcpTool, client) diff --git a/packages/opencode/test/mcp/tools-cache.test.ts b/packages/opencode/test/mcp/tools-cache.test.ts new file mode 100644 index 00000000000..35e97d80349 --- /dev/null +++ b/packages/opencode/test/mcp/tools-cache.test.ts @@ -0,0 +1,143 @@ +import { test, expect, beforeEach, mock } from "bun:test" + +let listToolsCalls = 0 +let listToolsResponses: Array< + Array<{ + name: string + description?: string + inputSchema: Record + }> +> = [] + +class MockClient { + async connect() { + return + } + + async close() { + return + } + + setNotificationHandler() { + return + } + + async listTools() { + const tools = listToolsResponses[listToolsCalls] ?? [] + listToolsCalls += 1 + return { tools } + } + + async callTool() { + throw new Error("not implemented") + } +} + +mock.module("@modelcontextprotocol/sdk/client/index.js", () => ({ + Client: MockClient, +})) + +mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ + StreamableHTTPClientTransport: class MockStreamableHTTP { + constructor() { + return + } + }, +})) + +mock.module("@modelcontextprotocol/sdk/client/sse.js", () => ({ + SSEClientTransport: class MockSSE { + constructor() { + return + } + }, +})) + +beforeEach(() => { + listToolsCalls = 0 + listToolsResponses = [] +}) + +const { MCP } = await import("../../src/mcp/index") +const { Bus } = await import("../../src/bus") +const { Instance } = await import("../../src/project/instance") +const { tmpdir } = await import("../fixture/fixture") + +test("MCP.tools caches listTools and invalidates on ToolsChanged", async () => { + listToolsResponses = [ + [ + { + name: "alpha", + description: "alpha tool", + inputSchema: { type: "object", properties: {} }, + }, + ], + [ + { + name: "beta", + description: "beta tool", + inputSchema: { type: "object", properties: {} }, + }, + ], + ] + + await using tmp = await tmpdir() + + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await MCP.add("test", { + type: "remote", + url: "https://example.com/mcp", + }) + + // Ignore the listTools call in create() + listToolsCalls = 0 + + const tools1 = await MCP.tools() + expect(listToolsCalls).toBe(1) + expect(Object.keys(tools1)).toEqual(["test_alpha"]) + + const tools2 = await MCP.tools() + expect(listToolsCalls).toBe(1) + expect(Object.keys(tools2)).toEqual(["test_alpha"]) + + await Bus.publish(MCP.ToolsChanged, { server: "test" }) + + const tools3 = await MCP.tools() + expect(listToolsCalls).toBe(2) + expect(Object.keys(tools3)).toEqual(["test_beta"]) + + await Instance.dispose() + }, + }) +}) + +test("MCP.tools caches empty tools as valid", async () => { + listToolsResponses = [[]] + + await using tmp = await tmpdir() + + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await MCP.add("empty", { + type: "remote", + url: "https://example.com/mcp", + }) + + // Ignore the listTools call in create() + listToolsCalls = 0 + + const tools1 = await MCP.tools() + expect(listToolsCalls).toBe(1) + expect(Object.keys(tools1)).toEqual([]) + + const tools2 = await MCP.tools() + expect(listToolsCalls).toBe(1) + expect(Object.keys(tools2)).toEqual([]) + + await Instance.dispose() + }, + }) +})