diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 9f6348c6c9..5f22d73294 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1792,9 +1792,11 @@ async def __aenter__(self) -> Self: """ async with self._enter_lock: if self._entered_count == 0: - self._exit_stack = AsyncExitStack() - toolset = self._get_toolset() - await self._exit_stack.enter_async_context(toolset) + async with AsyncExitStack() as exit_stack: + toolset = self._get_toolset() + await exit_stack.enter_async_context(toolset) + + self._exit_stack = exit_stack.pop_all() self._entered_count += 1 return self diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index c84f4b10bc..77d53f0800 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -201,25 +201,24 @@ async def __aenter__(self) -> Self: """ async with self._enter_lock: if self._running_count == 0: - self._exit_stack = AsyncExitStack() - - self._read_stream, self._write_stream = await self._exit_stack.enter_async_context( - self.client_streams() - ) - client = ClientSession( - read_stream=self._read_stream, - write_stream=self._write_stream, - sampling_callback=self._sampling_callback if self.allow_sampling else None, - logging_callback=self.log_handler, - read_timeout_seconds=timedelta(seconds=self.read_timeout), - ) - self._client = await self._exit_stack.enter_async_context(client) - - with anyio.fail_after(self.timeout): - await self._client.initialize() - - if log_level := self.log_level: - await self._client.set_logging_level(log_level) + async with AsyncExitStack() as exit_stack: + self._read_stream, self._write_stream = await exit_stack.enter_async_context(self.client_streams()) + client = ClientSession( + read_stream=self._read_stream, + write_stream=self._write_stream, + sampling_callback=self._sampling_callback if self.allow_sampling else None, + logging_callback=self.log_handler, + read_timeout_seconds=timedelta(seconds=self.read_timeout), + ) + self._client = await exit_stack.enter_async_context(client) + + with anyio.fail_after(self.timeout): + await self._client.initialize() + + if log_level := self.log_level: + await self._client.set_logging_level(log_level) + + self._exit_stack = exit_stack.pop_all() self._running_count += 1 return self diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index 4b1511fae1..d2ddaa1258 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -43,9 +43,10 @@ def __post_init__(self): async def __aenter__(self) -> Self: async with self._enter_lock: if self._entered_count == 0: - self._exit_stack = AsyncExitStack() - for toolset in self.toolsets: - await self._exit_stack.enter_async_context(toolset) + async with AsyncExitStack() as exit_stack: + for toolset in self.toolsets: + await exit_stack.enter_async_context(toolset) + self._exit_stack = exit_stack.pop_all() self._entered_count += 1 return self diff --git a/tests/test_mcp.py b/tests/test_mcp.py index de77b3587e..1021b31512 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -91,6 +91,20 @@ async def test_reentrant_context_manager(): pass +async def test_context_manager_initialization_error() -> None: + """Test if streams are closed if client fails to initialize.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + from mcp.client.session import ClientSession + + with patch.object(ClientSession, 'initialize', side_effect=Exception): + with pytest.raises(Exception): + async with server: + pass + + assert server._read_stream._closed # pyright: ignore[reportPrivateUsage] + assert server._write_stream._closed # pyright: ignore[reportPrivateUsage] + + async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo') async with server: diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index eac0dc78a7..f188d3141a 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -3,6 +3,7 @@ import re from dataclasses import dataclass, replace from typing import TypeVar +from unittest.mock import AsyncMock import pytest from inline_snapshot import snapshot @@ -469,3 +470,27 @@ async def test_context_manager(): async with toolset: assert server1.is_running assert server2.is_running + + +class InitializationError(Exception): + pass + + +async def test_context_manager_failed_initialization(): + """Test if MCP servers stop if any MCP server fails to initialize.""" + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: # pragma: lax no cover + pytest.skip('mcp is not installed') + + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = AsyncMock() + server2.__aenter__.side_effect = InitializationError + + toolset = CombinedToolset([server1, server2]) + + with pytest.raises(InitializationError): + async with toolset: + pass + + assert server1.is_running is False