From bd097bfce5e4d651dd3a43071b8249527b414c14 Mon Sep 17 00:00:00 2001 From: Tim Child Date: Mon, 31 Mar 2025 13:19:16 -0700 Subject: [PATCH 1/4] add test that checks if stdio connection hangs with bad connection params --- tests/client/test_stdio.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 95747ffd..5ceaca15 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,6 +1,7 @@ import shutil import pytest +from anyio import fail_after from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -41,3 +42,18 @@ async def test_stdio_client(): assert read_messages[1] == JSONRPCMessage( root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) ) + + +@pytest.mark.anyio +async def test_stdio_client_bad_path(): + """Check that the connection doesn't hang if process errors.""" + server_parameters = StdioServerParameters( + command="uv", args=["run", "non-existent-file.py"] + ) + + try: + with fail_after(1): + async with stdio_client(server_parameters) as (read_stream, write_stream): + pass + except TimeoutError: + pytest.fail("The connection hung.") From 8558120eac0ed2eaf16aa71e49b4344bbb152a2e Mon Sep 17 00:00:00 2001 From: Tim Child Date: Mon, 31 Mar 2025 14:18:09 -0700 Subject: [PATCH 2/4] fix process hanging on bad stdio connection params --- src/mcp/client/stdio/__init__.py | 31 ++++++++++++++++++++++++++----- tests/client/test_stdio.py | 29 ++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 83de57a2..3f0aff65 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -6,6 +6,7 @@ import anyio import anyio.lowlevel +from anyio.abc import Process from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.text import TextReceiveStream from pydantic import BaseModel, Field @@ -38,6 +39,10 @@ ) +class ProcessTerminatedEarlyError(Exception): + """Raised when a process terminates unexpectedly.""" + + def get_default_environment() -> dict[str, str]: """ Returns a default environment object including only environment variables deemed @@ -110,7 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder command = _get_executable_command(server.command) # Open process with stderr piped for capture - process = await _create_platform_compatible_process( + process: Process = await _create_platform_compatible_process( command=command, args=server.args, env=( @@ -163,20 +168,36 @@ async def stdin_writer(): except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() + process_error: str | None = None + async with ( anyio.create_task_group() as tg, process, ): tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) + # tg.start_soon(monitor_process, tg.cancel_scope) try: yield read_stream, write_stream finally: - # Clean up process to prevent any dangling orphaned processes - if sys.platform == "win32": - await terminate_windows_process(process) + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + + if process.returncode is not None and process.returncode != 0: + process_error = f"Process exited with code {process.returncode}." else: - process.terminate() + # Clean up process to prevent any dangling orphaned processes + if sys.platform == "win32": + await terminate_windows_process(process) + else: + process.terminate() + + if process_error: + # Raise outside the task group so that the error is not wrapped in an + # ExceptionGroup + raise ProcessTerminatedEarlyError(process_error) def _get_executable_command(command: str) -> str: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 5ceaca15..2799f838 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -3,7 +3,11 @@ import pytest from anyio import fail_after -from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.client.stdio import ( + ProcessTerminatedEarlyError, + StdioServerParameters, + stdio_client, +) from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore @@ -51,9 +55,20 @@ async def test_stdio_client_bad_path(): command="uv", args=["run", "non-existent-file.py"] ) - try: - with fail_after(1): - async with stdio_client(server_parameters) as (read_stream, write_stream): - pass - except TimeoutError: - pytest.fail("The connection hung.") + with pytest.raises(ProcessTerminatedEarlyError): + try: + with fail_after(1): + async with stdio_client(server_parameters) as ( + read_stream, + _, + ): + # Try waiting for read_stream so that we don't exit before the + # process fails. + async with read_stream: + async for message in read_stream: + if isinstance(message, Exception): + raise message + + pass + except TimeoutError: + pytest.fail("The connection hung.") From be41b81c52dda1cdaed6ea34874b09b527dbf387 Mon Sep 17 00:00:00 2001 From: Tim Child Date: Mon, 31 Mar 2025 14:19:43 -0700 Subject: [PATCH 3/4] make sure test only runs if `uv` available --- tests/client/test_stdio.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 2799f838..2dbc38b4 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -11,6 +11,7 @@ from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore +uv: str = shutil.which("uv") # type: ignore @pytest.mark.anyio @@ -49,6 +50,7 @@ async def test_stdio_client(): @pytest.mark.anyio +@pytest.mark.skipif(uv is None, reason="could not find uv command") async def test_stdio_client_bad_path(): """Check that the connection doesn't hang if process errors.""" server_parameters = StdioServerParameters( From 90f224d6e19521e08246c3362bfa4a1bc7b0f080 Mon Sep 17 00:00:00 2001 From: Tim Child Date: Mon, 7 Apr 2025 11:59:48 -0700 Subject: [PATCH 4/4] fix detection of failed process --- src/mcp/client/stdio/__init__.py | 26 ++++++++++++++++++-------- tests/client/test_stdio.py | 2 -- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 3f0aff65..d51d5c12 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -127,7 +127,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder cwd=server.cwd, ) - async def stdout_reader(): + async def stdout_reader(done_event: anyio.Event): assert process.stdout, "Opened process is missing stdout" try: @@ -151,6 +151,7 @@ async def stdout_reader(): await read_stream_writer.send(message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() + done_event.set() async def stdin_writer(): assert process.stdin, "Opened process is missing stdin" @@ -174,21 +175,30 @@ async def stdin_writer(): anyio.create_task_group() as tg, process, ): - tg.start_soon(stdout_reader) + stdout_done_event = anyio.Event() + tg.start_soon(stdout_reader, stdout_done_event) tg.start_soon(stdin_writer) - # tg.start_soon(monitor_process, tg.cancel_scope) try: yield read_stream, write_stream + if stdout_done_event.is_set(): + # The stdout reader exited before the calling code stopped listening + # (e.g. because of process error) + # Give the process a chance to exit if it was the reason for crashing + # so we can get exit code + with anyio.move_on_after(0.1) as scope: + await process.wait() + process_error = f"Process exited with code {process.returncode}." + if scope.cancelled_caught: + process_error = ( + "Stdout reader exited (process did not exit immediately)." + ) finally: await read_stream.aclose() await write_stream.aclose() await read_stream_writer.aclose() await write_stream_reader.aclose() - - if process.returncode is not None and process.returncode != 0: - process_error = f"Process exited with code {process.returncode}." - else: - # Clean up process to prevent any dangling orphaned processes + # Clean up process to prevent any dangling orphaned processes + if process.returncode is None: if sys.platform == "win32": await terminate_windows_process(process) else: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 2dbc38b4..ae968974 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -70,7 +70,5 @@ async def test_stdio_client_bad_path(): async for message in read_stream: if isinstance(message, Exception): raise message - - pass except TimeoutError: pytest.fail("The connection hung.")