Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 65 additions & 12 deletions packages/opencode/src/mcp/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ export namespace MCP {
)

type MCPClient = Client
type ToolsCacheEntry = {
seq: number
tools: MCPToolDef[]
}

function cloneToolDefs(tools: MCPToolDef[]) {
return tools.map((tool) => ({ ...tool }))
}

function invalidateToolsCache(state: {
toolsSeq: Record<string, number>
toolsCache: Record<string, ToolsCacheEntry>
}, serverName: string) {
state.toolsSeq[serverName] = (state.toolsSeq[serverName] ?? 0) + 1
delete state.toolsCache[serverName]
}

export const Status = z
.discriminatedUnion("status", [
Expand Down Expand Up @@ -159,6 +175,13 @@ export namespace MCP {
const config = cfg.mcp ?? {}
const clients: Record<string, MCPClient> = {}
const status: Record<string, Status> = {}
const toolsCache: Record<string, ToolsCacheEntry> = {}
const toolsSeq: Record<string, number> = {}
const unsubscribeToolsChanged = Bus.subscribe(ToolsChanged, (event) => {
const serverName = event.properties.server
toolsSeq[serverName] = (toolsSeq[serverName] ?? 0) + 1
delete toolsCache[serverName]
})

await Promise.all(
Object.entries(config).map(async ([key, mcp]) => {
Expand Down Expand Up @@ -186,9 +209,13 @@ export namespace MCP {
return {
status,
clients,
toolsCache,
toolsSeq,
unsubscribeToolsChanged,
}
},
async (state) => {
state.unsubscribeToolsChanged?.()
await Promise.all(
Object.values(state.clients).map((client) =>
client.close().catch((error) => {
Expand All @@ -198,6 +225,12 @@ export namespace MCP {
}),
),
)
for (const key of Object.keys(state.toolsCache)) {
delete state.toolsCache[key]
}
for (const key of Object.keys(state.toolsSeq)) {
delete state.toolsSeq[key]
}
pendingOAuthTransports.clear()
},
)
Expand Down Expand Up @@ -256,18 +289,21 @@ export namespace MCP {
error: "unknown error",
}
s.status[name] = status
invalidateToolsCache(s, name)
return {
status,
}
}
if (!result.mcpClient) {
s.status[name] = result.status
invalidateToolsCache(s, name)
return {
status: s.status,
}
}
s.clients[name] = result.mcpClient
s.status[name] = result.status
invalidateToolsCache(s, name)

return {
status: s.status,
Expand Down Expand Up @@ -525,6 +561,7 @@ export namespace MCP {
if (result.mcpClient) {
s.clients[name] = result.mcpClient
}
invalidateToolsCache(s, name)
}

export async function disconnect(name: string) {
Expand All @@ -537,6 +574,7 @@ export namespace MCP {
delete s.clients[name]
}
s.status[name] = { status: "disabled" }
invalidateToolsCache(s, name)
}

export async function tools() {
Expand All @@ -550,20 +588,35 @@ export namespace MCP {
continue
}

const toolsResult = await client.listTools().catch((e) => {
log.error("failed to get tools", { clientName, error: e.message })
const failedStatus = {
status: "failed" as const,
error: e instanceof Error ? e.message : String(e),
const currentSeq = s.toolsSeq[clientName] ?? 0
const cached = s.toolsCache[clientName]
let toolDefs: MCPToolDef[] | undefined

if (cached && cached.seq === currentSeq) {
toolDefs = cloneToolDefs(cached.tools)
} else {
const toolsResult = await client.listTools().catch((e) => {
log.error("failed to get tools", { clientName, error: e.message })
const failedStatus = {
status: "failed" as const,
error: e instanceof Error ? e.message : String(e),
}
s.status[clientName] = failedStatus
delete s.clients[clientName]
invalidateToolsCache(s, clientName)
return undefined
})
if (!toolsResult) {
continue
}
toolDefs = toolsResult.tools
s.toolsCache[clientName] = {
seq: currentSeq,
tools: toolsResult.tools,
}
s.status[clientName] = failedStatus
delete s.clients[clientName]
return undefined
})
if (!toolsResult) {
continue
}
for (const mcpTool of toolsResult.tools) {

for (const mcpTool of toolDefs) {
const sanitizedClientName = clientName.replace(/[^a-zA-Z0-9_-]/g, "_")
const sanitizedToolName = mcpTool.name.replace(/[^a-zA-Z0-9_-]/g, "_")
result[sanitizedClientName + "_" + sanitizedToolName] = await convertMcpTool(mcpTool, client)
Expand Down
143 changes: 143 additions & 0 deletions packages/opencode/test/mcp/tools-cache.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import { test, expect, beforeEach, mock } from "bun:test"

let listToolsCalls = 0
let listToolsResponses: Array<
Array<{
name: string
description?: string
inputSchema: Record<string, unknown>
}>
> = []

class MockClient {
async connect() {
return
}

async close() {
return
}

setNotificationHandler() {
return
}

async listTools() {
const tools = listToolsResponses[listToolsCalls] ?? []
listToolsCalls += 1
return { tools }
}

async callTool() {
throw new Error("not implemented")
}
}

mock.module("@modelcontextprotocol/sdk/client/index.js", () => ({
Client: MockClient,
}))

mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
StreamableHTTPClientTransport: class MockStreamableHTTP {
constructor() {
return
}
},
}))

mock.module("@modelcontextprotocol/sdk/client/sse.js", () => ({
SSEClientTransport: class MockSSE {
constructor() {
return
}
},
}))

beforeEach(() => {
listToolsCalls = 0
listToolsResponses = []
})

const { MCP } = await import("../../src/mcp/index")
const { Bus } = await import("../../src/bus")
const { Instance } = await import("../../src/project/instance")
const { tmpdir } = await import("../fixture/fixture")

test("MCP.tools caches listTools and invalidates on ToolsChanged", async () => {
listToolsResponses = [
[
{
name: "alpha",
description: "alpha tool",
inputSchema: { type: "object", properties: {} },
},
],
[
{
name: "beta",
description: "beta tool",
inputSchema: { type: "object", properties: {} },
},
],
]

await using tmp = await tmpdir()

await Instance.provide({
directory: tmp.path,
fn: async () => {
await MCP.add("test", {
type: "remote",
url: "https://example.com/mcp",
})

// Ignore the listTools call in create()
listToolsCalls = 0

const tools1 = await MCP.tools()
expect(listToolsCalls).toBe(1)
expect(Object.keys(tools1)).toEqual(["test_alpha"])

const tools2 = await MCP.tools()
expect(listToolsCalls).toBe(1)
expect(Object.keys(tools2)).toEqual(["test_alpha"])

await Bus.publish(MCP.ToolsChanged, { server: "test" })

const tools3 = await MCP.tools()
expect(listToolsCalls).toBe(2)
expect(Object.keys(tools3)).toEqual(["test_beta"])

await Instance.dispose()
},
})
})

test("MCP.tools caches empty tools as valid", async () => {
listToolsResponses = [[]]

await using tmp = await tmpdir()

await Instance.provide({
directory: tmp.path,
fn: async () => {
await MCP.add("empty", {
type: "remote",
url: "https://example.com/mcp",
})

// Ignore the listTools call in create()
listToolsCalls = 0

const tools1 = await MCP.tools()
expect(listToolsCalls).toBe(1)
expect(Object.keys(tools1)).toEqual([])

const tools2 = await MCP.tools()
expect(listToolsCalls).toBe(1)
expect(Object.keys(tools2)).toEqual([])

await Instance.dispose()
},
})
})
Loading