Skip to content

Prevent stdio connection hang for missing server path. #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@

import anyio
import anyio.lowlevel
from anyio.abc import Process
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field
@@ -38,6 +39,10 @@
)


class ProcessTerminatedEarlyError(Exception):
"""Raised when a process terminates unexpectedly."""


def get_default_environment() -> dict[str, str]:
"""
Returns a default environment object including only environment variables deemed
@@ -110,7 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
command = _get_executable_command(server.command)

# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
process: Process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=(
@@ -122,7 +127,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
cwd=server.cwd,
)

async def stdout_reader():
async def stdout_reader(done_event: anyio.Event):
assert process.stdout, "Opened process is missing stdout"

try:
@@ -146,6 +151,7 @@ async def stdout_reader():
await read_stream_writer.send(message)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()
done_event.set()

async def stdin_writer():
assert process.stdin, "Opened process is missing stdin"
@@ -163,20 +169,45 @@ async def stdin_writer():
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

process_error: str | None = None

async with (
anyio.create_task_group() as tg,
process,
):
tg.start_soon(stdout_reader)
stdout_done_event = anyio.Event()
tg.start_soon(stdout_reader, stdout_done_event)
tg.start_soon(stdin_writer)
try:
yield read_stream, write_stream
if stdout_done_event.is_set():
# The stdout reader exited before the calling code stopped listening
# (e.g. because of process error)
# Give the process a chance to exit if it was the reason for crashing
# so we can get exit code
with anyio.move_on_after(0.1) as scope:
await process.wait()
process_error = f"Process exited with code {process.returncode}."
if scope.cancelled_caught:
process_error = (
"Stdout reader exited (process did not exit immediately)."
)
finally:
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
# Clean up process to prevent any dangling orphaned processes
if sys.platform == "win32":
await terminate_windows_process(process)
else:
process.terminate()
if process.returncode is None:
if sys.platform == "win32":
await terminate_windows_process(process)
else:
process.terminate()

if process_error:
# Raise outside the task group so that the error is not wrapped in an
# ExceptionGroup
raise ProcessTerminatedEarlyError(process_error)


def _get_executable_command(command: str) -> str:
33 changes: 32 additions & 1 deletion tests/client/test_stdio.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import shutil

import pytest
from anyio import fail_after

from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.stdio import (
ProcessTerminatedEarlyError,
StdioServerParameters,
stdio_client,
)
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse

tee: str = shutil.which("tee") # type: ignore
uv: str = shutil.which("uv") # type: ignore


@pytest.mark.anyio
@@ -41,3 +47,28 @@ async def test_stdio_client():
assert read_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)


@pytest.mark.anyio
@pytest.mark.skipif(uv is None, reason="could not find uv command")
async def test_stdio_client_bad_path():
"""Check that the connection doesn't hang if process errors."""
server_parameters = StdioServerParameters(
command="uv", args=["run", "non-existent-file.py"]
)

with pytest.raises(ProcessTerminatedEarlyError):
try:
with fail_after(1):
async with stdio_client(server_parameters) as (
read_stream,
_,
):
# Try waiting for read_stream so that we don't exit before the
# process fails.
async with read_stream:
async for message in read_stream:
if isinstance(message, Exception):
raise message
except TimeoutError:
pytest.fail("The connection hung.")