Skip to content

Commit a754c96

Browse files
wukathcopybara-github
authored andcommitted
fix: Improve logic for checking if a MCP session is disconnected
Currently logic to check for a disconnected session only checks for certain headers but doesn't detect all cases, leading to situations where it tries to connect to a session that is down. This adds logic so that we ping the server to check if it is disconnected. Fixes #3321. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 832460068
1 parent 29fea7e commit a754c96

File tree

3 files changed

+72
-12
lines changed

3 files changed

+72
-12
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from mcp.client.sse import sse_client
3838
from mcp.client.stdio import stdio_client
3939
from mcp.client.streamable_http import streamablehttp_client
40+
from mcp.types import EmptyResult
4041
except ImportError as e:
4142

4243
if sys.version_info < (3, 10):
@@ -241,7 +242,7 @@ def _merge_headers(
241242

242243
return base_headers
243244

244-
def _is_session_disconnected(self, session: ClientSession) -> bool:
245+
async def _is_session_disconnected(self, session: ClientSession) -> bool:
245246
"""Checks if a session is disconnected or closed.
246247
247248
Args:
@@ -250,7 +251,24 @@ def _is_session_disconnected(self, session: ClientSession) -> bool:
250251
Returns:
251252
True if the session is disconnected, False otherwise.
252253
"""
253-
return session._read_stream._closed or session._write_stream._closed
254+
if session._read_stream._closed or session._write_stream._closed:
255+
return True
256+
257+
try:
258+
response = await asyncio.wait_for(session.send_ping(), timeout=5.0)
259+
if not isinstance(response, EmptyResult):
260+
logger.info(
261+
'Session ping returns illegal response %s, treating as'
262+
' disconnected',
263+
response,
264+
)
265+
return True
266+
return False
267+
except Exception as e:
268+
logger.info(
269+
'Session ping failed with error %s, treating as disconnected', e
270+
)
271+
return True
254272

255273
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
256274
"""Creates an MCP client based on the connection parameters.
@@ -325,7 +343,7 @@ async def create_session(
325343
session, exit_stack = self._sessions[session_key]
326344

327345
# Check if the existing session is still connected
328-
if not self._is_session_disconnected(session):
346+
if not await self._is_session_disconnected(session):
329347
# Session is still good, return it
330348
return session
331349
else:

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ async def get_tools(
175175
else None
176176
)
177177
# Get session from session manager
178-
session = await self._mcp_session_manager.create_session(headers=headers)
178+
try:
179+
session = await self._mcp_session_manager.create_session(headers=headers)
180+
except Exception as e:
181+
raise ConnectionError(f"Failed to create MCP session") from e
179182

180183
# Fetch available tools from the MCP server
181184
timeout_in_seconds = (

tests/unittests/tools/mcp_tool/test_mcp_session_manager.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class DummyClass:
5454
# Import real MCP classes
5555
try:
5656
from mcp import StdioServerParameters
57+
from mcp.types import EmptyResult
5758
except ImportError:
5859
# Create a mock if MCP is not available
5960
class StdioServerParameters:
@@ -62,6 +63,9 @@ def __init__(self, command="test_command", args=None):
6263
self.command = command
6364
self.args = args or []
6465

66+
class EmptyResult:
67+
pass
68+
6569

6670
class MockClientSession:
6771
"""Mock ClientSession for testing."""
@@ -72,6 +76,7 @@ def __init__(self):
7276
self._read_stream._closed = False
7377
self._write_stream._closed = False
7478
self.initialize = AsyncMock()
79+
self.send_ping = AsyncMock()
7580

7681

7782
class MockAsyncExitStack:
@@ -206,19 +211,52 @@ def test_merge_headers_sse(self):
206211
}
207212
assert merged == expected
208213

209-
def test_is_session_disconnected(self):
210-
"""Test session disconnection detection."""
214+
@pytest.mark.asyncio
215+
async def test_is_session_disconnected_when_connected(self):
216+
"""Test session disconnection detection when session is connected."""
211217
manager = MCPSessionManager(self.mock_stdio_connection_params)
218+
session = MockClientSession()
219+
session.send_ping.return_value = EmptyResult()
220+
assert not await manager._is_session_disconnected(session)
221+
session.send_ping.assert_called_once()
222+
223+
@pytest.mark.asyncio
224+
async def test_is_session_disconnected_read_stream_closed(self):
225+
"""Test session disconnection detection when read stream is closed."""
226+
manager = MCPSessionManager(self.mock_stdio_connection_params)
227+
session = MockClientSession()
228+
session.send_ping.return_value = EmptyResult()
229+
session._read_stream._closed = True
230+
assert await manager._is_session_disconnected(session)
231+
session.send_ping.assert_not_called()
212232

213-
# Create mock session
233+
@pytest.mark.asyncio
234+
async def test_is_session_disconnected_write_stream_closed(self):
235+
"""Test session disconnection detection when write stream is closed."""
236+
manager = MCPSessionManager(self.mock_stdio_connection_params)
214237
session = MockClientSession()
238+
session.send_ping.return_value = EmptyResult()
239+
session._write_stream._closed = True
240+
assert await manager._is_session_disconnected(session)
241+
session.send_ping.assert_not_called()
215242

216-
# Not disconnected
217-
assert not manager._is_session_disconnected(session)
243+
@pytest.mark.asyncio
244+
async def test_is_session_disconnected_ping_fails(self):
245+
"""Test session disconnection detection when ping fails."""
246+
manager = MCPSessionManager(self.mock_stdio_connection_params)
247+
session = MockClientSession()
248+
session.send_ping.side_effect = Exception("Ping failed")
249+
assert await manager._is_session_disconnected(session)
250+
session.send_ping.assert_called_once()
218251

219-
# Disconnected - read stream closed
220-
session._read_stream._closed = True
221-
assert manager._is_session_disconnected(session)
252+
@pytest.mark.asyncio
253+
async def test_is_session_disconnected_ping_returns_wrong_result(self):
254+
"""Test session disconnection detection when ping returns wrong result."""
255+
manager = MCPSessionManager(self.mock_stdio_connection_params)
256+
session = MockClientSession()
257+
session.send_ping.return_value = "Wrong result"
258+
assert await manager._is_session_disconnected(session)
259+
session.send_ping.assert_called_once()
222260

223261
@pytest.mark.asyncio
224262
async def test_create_session_stdio_new(self):
@@ -271,6 +309,7 @@ async def test_create_session_reuse_existing(self):
271309
# Session is connected
272310
existing_session._read_stream._closed = False
273311
existing_session._write_stream._closed = False
312+
existing_session.send_ping.return_value = EmptyResult()
274313

275314
session = await manager.create_session()
276315

0 commit comments

Comments
 (0)