Skip to content

Commit a48a1a9

Browse files
wukathcopybara-github
authored andcommitted
fix: Improve logic for checking if a MCP session is disconnected
Currently logic to check for a disconnected session only checks for certain headers but doesn't detect all cases, leading to situations where it tries to connect to a session that is down. This adds logic so that we ping the server to check if it is disconnected. Fixes #3321. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 833487825
1 parent a5ac1d5 commit a48a1a9

File tree

3 files changed

+12
-72
lines changed

3 files changed

+12
-72
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from mcp.client.sse import sse_client
3838
from mcp.client.stdio import stdio_client
3939
from mcp.client.streamable_http import streamablehttp_client
40-
from mcp.types import EmptyResult
4140
except ImportError as e:
4241

4342
if sys.version_info < (3, 10):
@@ -242,7 +241,7 @@ def _merge_headers(
242241

243242
return base_headers
244243

245-
async def _is_session_disconnected(self, session: ClientSession) -> bool:
244+
def _is_session_disconnected(self, session: ClientSession) -> bool:
246245
"""Checks if a session is disconnected or closed.
247246
248247
Args:
@@ -251,24 +250,7 @@ async def _is_session_disconnected(self, session: ClientSession) -> bool:
251250
Returns:
252251
True if the session is disconnected, False otherwise.
253252
"""
254-
if session._read_stream._closed or session._write_stream._closed:
255-
return True
256-
257-
try:
258-
response = await asyncio.wait_for(session.send_ping(), timeout=5.0)
259-
if not isinstance(response, EmptyResult):
260-
logger.info(
261-
'Session ping returns illegal response %s, treating as'
262-
' disconnected',
263-
response,
264-
)
265-
return True
266-
return False
267-
except Exception as e:
268-
logger.info(
269-
'Session ping failed with error %s, treating as disconnected', e
270-
)
271-
return True
253+
return session._read_stream._closed or session._write_stream._closed
272254

273255
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
274256
"""Creates an MCP client based on the connection parameters.
@@ -343,7 +325,7 @@ async def create_session(
343325
session, exit_stack = self._sessions[session_key]
344326

345327
# Check if the existing session is still connected
346-
if not await self._is_session_disconnected(session):
328+
if not self._is_session_disconnected(session):
347329
# Session is still good, return it
348330
return session
349331
else:

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,7 @@ async def get_tools(
175175
else None
176176
)
177177
# Get session from session manager
178-
try:
179-
session = await self._mcp_session_manager.create_session(headers=headers)
180-
except Exception as e:
181-
raise ConnectionError(f"Failed to create MCP session") from e
178+
session = await self._mcp_session_manager.create_session(headers=headers)
182179

183180
# Fetch available tools from the MCP server
184181
timeout_in_seconds = (

tests/unittests/tools/mcp_tool/test_mcp_session_manager.py

Lines changed: 8 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ class DummyClass:
5454
# Import real MCP classes
5555
try:
5656
from mcp import StdioServerParameters
57-
from mcp.types import EmptyResult
5857
except ImportError:
5958
# Create a mock if MCP is not available
6059
class StdioServerParameters:
@@ -63,9 +62,6 @@ def __init__(self, command="test_command", args=None):
6362
self.command = command
6463
self.args = args or []
6564

66-
class EmptyResult:
67-
pass
68-
6965

7066
class MockClientSession:
7167
"""Mock ClientSession for testing."""
@@ -76,7 +72,6 @@ def __init__(self):
7672
self._read_stream._closed = False
7773
self._write_stream._closed = False
7874
self.initialize = AsyncMock()
79-
self.send_ping = AsyncMock()
8075

8176

8277
class MockAsyncExitStack:
@@ -211,52 +206,19 @@ def test_merge_headers_sse(self):
211206
}
212207
assert merged == expected
213208

214-
@pytest.mark.asyncio
215-
async def test_is_session_disconnected_when_connected(self):
216-
"""Test session disconnection detection when session is connected."""
209+
def test_is_session_disconnected(self):
210+
"""Test session disconnection detection."""
217211
manager = MCPSessionManager(self.mock_stdio_connection_params)
218-
session = MockClientSession()
219-
session.send_ping.return_value = EmptyResult()
220-
assert not await manager._is_session_disconnected(session)
221-
session.send_ping.assert_called_once()
222-
223-
@pytest.mark.asyncio
224-
async def test_is_session_disconnected_read_stream_closed(self):
225-
"""Test session disconnection detection when read stream is closed."""
226-
manager = MCPSessionManager(self.mock_stdio_connection_params)
227-
session = MockClientSession()
228-
session.send_ping.return_value = EmptyResult()
229-
session._read_stream._closed = True
230-
assert await manager._is_session_disconnected(session)
231-
session.send_ping.assert_not_called()
232212

233-
@pytest.mark.asyncio
234-
async def test_is_session_disconnected_write_stream_closed(self):
235-
"""Test session disconnection detection when write stream is closed."""
236-
manager = MCPSessionManager(self.mock_stdio_connection_params)
213+
# Create mock session
237214
session = MockClientSession()
238-
session.send_ping.return_value = EmptyResult()
239-
session._write_stream._closed = True
240-
assert await manager._is_session_disconnected(session)
241-
session.send_ping.assert_not_called()
242215

243-
@pytest.mark.asyncio
244-
async def test_is_session_disconnected_ping_fails(self):
245-
"""Test session disconnection detection when ping fails."""
246-
manager = MCPSessionManager(self.mock_stdio_connection_params)
247-
session = MockClientSession()
248-
session.send_ping.side_effect = Exception("Ping failed")
249-
assert await manager._is_session_disconnected(session)
250-
session.send_ping.assert_called_once()
216+
# Not disconnected
217+
assert not manager._is_session_disconnected(session)
251218

252-
@pytest.mark.asyncio
253-
async def test_is_session_disconnected_ping_returns_wrong_result(self):
254-
"""Test session disconnection detection when ping returns wrong result."""
255-
manager = MCPSessionManager(self.mock_stdio_connection_params)
256-
session = MockClientSession()
257-
session.send_ping.return_value = "Wrong result"
258-
assert await manager._is_session_disconnected(session)
259-
session.send_ping.assert_called_once()
219+
# Disconnected - read stream closed
220+
session._read_stream._closed = True
221+
assert manager._is_session_disconnected(session)
260222

261223
@pytest.mark.asyncio
262224
async def test_create_session_stdio_new(self):
@@ -309,7 +271,6 @@ async def test_create_session_reuse_existing(self):
309271
# Session is connected
310272
existing_session._read_stream._closed = False
311273
existing_session._write_stream._closed = False
312-
existing_session.send_ping.return_value = EmptyResult()
313274

314275
session = await manager.create_session()
315276

0 commit comments

Comments
 (0)