Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions packages/core/src/tools/mcp-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -749,9 +749,9 @@ describe('mcp-client', () => {
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue({
close: vi.fn(),
} as unknown as SdkClientStdioLib.StdioClientTransport);
const mockedToolRegistry = {
registerTool: vi.fn(),
unregisterTool: vi.fn(),
Expand Down Expand Up @@ -1888,7 +1888,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG,
);

expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();

Expand Down Expand Up @@ -1934,7 +1934,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG,
);

expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
Expand Down Expand Up @@ -2029,7 +2029,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG,
);

expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
// First HTTP attempt fails, second SSE attempt succeeds
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
});
Expand Down Expand Up @@ -2070,7 +2070,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG,
);

expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
});
});
Expand Down Expand Up @@ -2155,7 +2155,7 @@ describe('connectToMcpServer - OAuth with transport fallback', () => {
EMPTY_CONFIG,
);

expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(3);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
});
Expand Down
68 changes: 51 additions & 17 deletions packages/core/src/tools/mcp-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,21 @@ export class McpClient {
}
this.updateStatus(MCPServerStatus.CONNECTING);
try {
this.client = await connectToMcpServer(
const { client, transport } = await connectToMcpServer(
this.clientVersion,
this.serverName,
this.serverConfig,
this.debugMode,
this.workspaceContext,
this.cliConfig.sanitizationConfig,
);
this.client = client;
this.transport = transport;

this.registerNotificationHandlers();

const originalOnError = this.client.onerror;
this.client.onerror = (error) => {
this.client.onerror = async (error) => {
if (this.status !== MCPServerStatus.CONNECTED) {
return;
}
Expand All @@ -167,6 +169,14 @@ export class McpClient {
error,
);
this.updateStatus(MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (this.transport) {
try {
await this.transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
};
this.updateStatus(MCPServerStatus.CONNECTED);
} catch (error) {
Expand Down Expand Up @@ -909,19 +919,30 @@ export async function connectAndDiscover(
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);

let mcpClient: Client | undefined;
let transport: Transport | undefined;
try {
mcpClient = await connectToMcpServer(
const result = await connectToMcpServer(
clientVersion,
mcpServerName,
mcpServerConfig,
debugMode,
workspaceContext,
cliConfig.sanitizationConfig,
);
mcpClient = result.client;
transport = result.transport;

mcpClient.onerror = (error) => {
mcpClient.onerror = async (error) => {
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (transport) {
try {
await transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
};

// Attempt to discover both prompts and tools
Expand Down Expand Up @@ -1302,16 +1323,18 @@ function createSSETransportWithAuth(
* @param client The MCP client to connect
* @param config The MCP server configuration
* @param accessToken Optional OAuth access token for authentication
* @returns The transport used for connection
*/
async function connectWithSSETransport(
client: Client,
config: MCPServerConfig,
accessToken?: string | null,
): Promise<void> {
): Promise<Transport> {
const transport = createSSETransportWithAuth(config, accessToken);
await client.connect(transport, {
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
return transport;
}

/**
Expand Down Expand Up @@ -1341,24 +1364,29 @@ async function showAuthRequiredMessage(serverName: string): Promise<never> {
* @param config The MCP server configuration
* @param accessToken The OAuth access token to use
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
* @returns The transport used for connection
*/
async function retryWithOAuth(
client: Client,
serverName: string,
config: MCPServerConfig,
accessToken: string,
httpReturned404: boolean,
): Promise<void> {
): Promise<Transport> {
if (httpReturned404) {
// HTTP returned 404, only try SSE
debugLogger.log(
`Retrying SSE connection to '${serverName}' with OAuth token...`,
);
await connectWithSSETransport(client, config, accessToken);
const transport = await connectWithSSETransport(
client,
config,
accessToken,
);
debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`,
);
return;
return transport;
}

// HTTP returned 401, try HTTP with OAuth first
Expand All @@ -1382,6 +1410,7 @@ async function retryWithOAuth(
debugLogger.log(
`Successfully connected to '${serverName}' using HTTP with OAuth.`,
);
return httpTransport;
} catch (httpError) {
await httpTransport.close();

Expand All @@ -1393,10 +1422,15 @@ async function retryWithOAuth(
!config.httpUrl
) {
debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`);
await connectWithSSETransport(client, config, accessToken);
const sseTransport = await connectWithSSETransport(
client,
config,
accessToken,
);
debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`,
);
return sseTransport;
} else {
throw httpError;
}
Expand All @@ -1410,7 +1444,7 @@ async function retryWithOAuth(
*
* @param mcpServerName The name of the MCP server, used for logging and identification.
* @param mcpServerConfig The configuration specifying how to connect to the server.
* @returns A promise that resolves to a connected MCP `Client` instance.
* @returns A promise that resolves to a connected MCP `Client` instance and its transport.
* @throws An error if the connection fails or the configuration is invalid.
*/
export async function connectToMcpServer(
Expand All @@ -1420,7 +1454,7 @@ export async function connectToMcpServer(
debugMode: boolean,
workspaceContext: WorkspaceContext,
sanitizationConfig: EnvironmentSanitizationConfig,
): Promise<Client> {
): Promise<{ client: Client; transport: Transport }> {
const mcpClient = new Client(
{
name: 'gemini-cli-mcp-client',
Expand Down Expand Up @@ -1492,7 +1526,7 @@ export async function connectToMcpServer(
await mcpClient.connect(transport, {
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
return mcpClient;
return { client: mcpClient, transport };
} catch (error) {
await transport.close();
firstAttemptError = error as Error;
Expand Down Expand Up @@ -1523,7 +1557,7 @@ export async function connectToMcpServer(
try {
// Try SSE with stored OAuth token if available
// This ensures that SSE fallback works for authenticated servers
await connectWithSSETransport(
const sseTransport = await connectWithSSETransport(
mcpClient,
mcpServerConfig,
await getStoredOAuthToken(mcpServerName),
Expand All @@ -1532,7 +1566,7 @@ export async function connectToMcpServer(
debugLogger.log(
`MCP server '${mcpServerName}': Successfully connected using SSE transport.`,
);
return mcpClient;
return { client: mcpClient, transport: sseTransport };
} catch (sseFallbackError) {
sseError = sseFallbackError as Error;

Expand Down Expand Up @@ -1639,14 +1673,14 @@ export async function connectToMcpServer(
);
}

await retryWithOAuth(
const oauthTransport = await retryWithOAuth(
mcpClient,
mcpServerName,
mcpServerConfig,
accessToken,
httpReturned404,
);
return mcpClient;
return { client: mcpClient, transport: oauthTransport };
} else {
throw new Error(
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
Expand Down Expand Up @@ -1727,7 +1761,7 @@ export async function connectToMcpServer(
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
// Connection successful with OAuth
return mcpClient;
return { client: mcpClient, transport: oauthTransport };
} else {
throw new Error(
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,
Expand Down
Loading