Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions packages/core/src/tools/mcp-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -901,9 +901,9 @@ describe('mcp-client', () => {
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue({
close: vi.fn(),
} as unknown as SdkClientStdioLib.StdioClientTransport);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedToolRegistry = {
registerTool: vi.fn(),
unregisterTool: vi.fn(),
Expand Down Expand Up @@ -1971,7 +1971,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG,
);

expect(client.client).toBe(mockedClient);
expect(client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();

Expand Down Expand Up @@ -2017,7 +2017,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG,
);

expect(client.client).toBe(mockedClient);
expect(client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
Expand Down Expand Up @@ -2112,7 +2112,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG,
);

expect(client.client).toBe(mockedClient);
expect(client).toBe(mockedClient);
// First HTTP attempt fails, second SSE attempt succeeds
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
});
Expand Down Expand Up @@ -2153,7 +2153,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG,
);

expect(client.client).toBe(mockedClient);
expect(client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
});
});
Expand Down Expand Up @@ -2238,7 +2238,7 @@ describe('connectToMcpServer - OAuth with transport fallback', () => {
EMPTY_CONFIG,
);

expect(client.client).toBe(mockedClient);
expect(client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(3);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
});
Expand Down
68 changes: 17 additions & 51 deletions packages/core/src/tools/mcp-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,19 @@ export class McpClient {
}
this.updateStatus(MCPServerStatus.CONNECTING);
try {
const { client, transport } = await connectToMcpServer(
this.client = await connectToMcpServer(
this.clientVersion,
this.serverName,
this.serverConfig,
this.debugMode,
this.workspaceContext,
this.cliConfig.sanitizationConfig,
);
Comment on lines +145 to 152
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the this.transport = transport assignment leaves the transport property of the McpClient class (defined at line 112) permanently uninitialized. This breaks the cleanup logic in the disconnect() method (lines 213-215), which still attempts to close this.transport. Since connectToMcpServer now only returns the Client instance, the direct reference to the transport is lost. To maintain correctness, the transport property should be removed from the class, and the disconnect() method should be updated to rely solely on this.client.close() (which internally closes the transport).

this.client = client;
this.transport = transport;

this.registerNotificationHandlers();

const originalOnError = this.client.onerror;
this.client.onerror = async (error) => {
this.client.onerror = (error) => {
if (this.status !== MCPServerStatus.CONNECTED) {
return;
}
Expand All @@ -167,14 +165,6 @@ export class McpClient {
error,
);
this.updateStatus(MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (this.transport) {
try {
await this.transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
};
this.updateStatus(MCPServerStatus.CONNECTED);
} catch (error) {
Expand Down Expand Up @@ -923,30 +913,19 @@ export async function connectAndDiscover(
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);

let mcpClient: Client | undefined;
let transport: Transport | undefined;
try {
const result = await connectToMcpServer(
mcpClient = await connectToMcpServer(
clientVersion,
mcpServerName,
mcpServerConfig,
debugMode,
workspaceContext,
cliConfig.sanitizationConfig,
);
mcpClient = result.client;
transport = result.transport;

mcpClient.onerror = async (error) => {
mcpClient.onerror = (error) => {
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (transport) {
try {
await transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
};
Comment on lines +926 to 929
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the McpClient class, the onerror handler in connectAndDiscover no longer performs cleanup. While the catch block at the end of the function handles errors during the discovery process, an asynchronous error triggered via onerror should also ensure the client is closed to prevent leaking resources like subprocesses or network connections.

Suggested change
mcpClient.onerror = (error) => {
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (transport) {
try {
await transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
};
mcpClient.onerror = (error) => {
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
mcpClient?.close().catch(() => {});
};


// Attempt to discover both prompts and tools
Expand Down Expand Up @@ -1344,18 +1323,16 @@ function createSSETransportWithAuth(
* @param client The MCP client to connect
* @param config The MCP server configuration
* @param accessToken Optional OAuth access token for authentication
* @returns The transport used for connection
*/
async function connectWithSSETransport(
client: Client,
config: MCPServerConfig,
accessToken?: string | null,
): Promise<Transport> {
): Promise<void> {
const transport = createSSETransportWithAuth(config, accessToken);
await client.connect(transport, {
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
return transport;
}

/**
Expand Down Expand Up @@ -1385,29 +1362,24 @@ async function showAuthRequiredMessage(serverName: string): Promise<never> {
* @param config The MCP server configuration
* @param accessToken The OAuth access token to use
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
* @returns The transport used for connection
*/
async function retryWithOAuth(
client: Client,
serverName: string,
config: MCPServerConfig,
accessToken: string,
httpReturned404: boolean,
): Promise<Transport> {
): Promise<void> {
if (httpReturned404) {
// HTTP returned 404, only try SSE
debugLogger.log(
`Retrying SSE connection to '${serverName}' with OAuth token...`,
);
const transport = await connectWithSSETransport(
client,
config,
accessToken,
);
await connectWithSSETransport(client, config, accessToken);
debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`,
);
return transport;
return;
}

// HTTP returned 401, try HTTP with OAuth first
Expand All @@ -1431,7 +1403,6 @@ async function retryWithOAuth(
debugLogger.log(
`Successfully connected to '${serverName}' using HTTP with OAuth.`,
);
return httpTransport;
} catch (httpError) {
await httpTransport.close();

Expand All @@ -1443,15 +1414,10 @@ async function retryWithOAuth(
!config.httpUrl
) {
debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`);
const sseTransport = await connectWithSSETransport(
client,
config,
accessToken,
);
await connectWithSSETransport(client, config, accessToken);
debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`,
);
return sseTransport;
} else {
throw httpError;
}
Expand All @@ -1465,7 +1431,7 @@ async function retryWithOAuth(
*
* @param mcpServerName The name of the MCP server, used for logging and identification.
* @param mcpServerConfig The configuration specifying how to connect to the server.
* @returns A promise that resolves to a connected MCP `Client` instance and its transport.
* @returns A promise that resolves to a connected MCP `Client` instance.
* @throws An error if the connection fails or the configuration is invalid.
*/
export async function connectToMcpServer(
Expand All @@ -1475,7 +1441,7 @@ export async function connectToMcpServer(
debugMode: boolean,
workspaceContext: WorkspaceContext,
sanitizationConfig: EnvironmentSanitizationConfig,
): Promise<{ client: Client; transport: Transport }> {
): Promise<Client> {
const mcpClient = new Client(
{
name: 'gemini-cli-mcp-client',
Expand Down Expand Up @@ -1547,7 +1513,7 @@ export async function connectToMcpServer(
await mcpClient.connect(transport, {
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
return { client: mcpClient, transport };
return mcpClient;
} catch (error) {
await transport.close();
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
Expand Down Expand Up @@ -1579,7 +1545,7 @@ export async function connectToMcpServer(
try {
// Try SSE with stored OAuth token if available
// This ensures that SSE fallback works for authenticated servers
const sseTransport = await connectWithSSETransport(
await connectWithSSETransport(
mcpClient,
mcpServerConfig,
await getStoredOAuthToken(mcpServerName),
Expand All @@ -1588,7 +1554,7 @@ export async function connectToMcpServer(
debugLogger.log(
`MCP server '${mcpServerName}': Successfully connected using SSE transport.`,
);
return { client: mcpClient, transport: sseTransport };
return mcpClient;
} catch (sseFallbackError) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
sseError = sseFallbackError as Error;
Expand Down Expand Up @@ -1696,14 +1662,14 @@ export async function connectToMcpServer(
);
}

const oauthTransport = await retryWithOAuth(
await retryWithOAuth(
mcpClient,
mcpServerName,
mcpServerConfig,
accessToken,
httpReturned404,
);
return { client: mcpClient, transport: oauthTransport };
return mcpClient;
} else {
throw new Error(
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
Expand Down Expand Up @@ -1784,7 +1750,7 @@ export async function connectToMcpServer(
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
// Connection successful with OAuth
return { client: mcpClient, transport: oauthTransport };
return mcpClient;
} else {
throw new Error(
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,
Expand Down
Loading