diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 05fd3ce3..535432bc 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -14,6 +14,7 @@ from mcp.shared.exceptions import McpError from mcp.types import ( + CONNECTION_CLOSED, CancelledNotification, ClientNotification, ClientRequest, @@ -359,6 +360,13 @@ async def _receive_loop(self) -> None: ) ) + # after the read stream is closed, we need to send errors + # to any pending requests + for id, stream in self._response_streams.items(): + error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) + await stream.aclose() + async def _received_request( self, responder: RequestResponder[ReceiveRequestT, SendResultT] ) -> None: diff --git a/src/mcp/types.py b/src/mcp/types.py index bd71d51f..57028442 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -140,6 +140,10 @@ class JSONRPCResponse(BaseModel): model_config = ConfigDict(extra="allow") +# SDK error codes +CONNECTION_CLOSED = -32000 +# REQUEST_TIMEOUT = -32001 # the typescript sdk uses this + # Standard JSON-RPC error codes PARSE_ERROR = -32700 INVALID_REQUEST = -32600 diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 59cb30c8..eb4e004a 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -7,7 +7,10 @@ from mcp.client.session import ClientSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError -from mcp.shared.memory import create_connected_server_and_client_session +from mcp.shared.memory import ( + create_client_server_memory_streams, + create_connected_server_and_client_session, +) from mcp.types import ( CancelledNotification, CancelledNotificationParams, @@ -124,3 +127,57 @@ async def make_request(client_session): # Give cancellation time to process with anyio.fail_after(1): await ev_cancelled.wait() + + +@pytest.mark.anyio +async def test_connection_closed(): + """ + Test that pending requests are cancelled when the connection is closed remotely. + """ + + ev_closed = anyio.Event() + ev_response = anyio.Event() + + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def make_request(client_session): + """Send a request in a separate task""" + nonlocal ev_response + try: + # any request will do + await client_session.initialize() + pytest.fail("Request should have errored") + except McpError as e: + # Expected - request errored + assert "Connection closed" in str(e) + ev_response.set() + + async def mock_server(): + """Wait for a request, then close the connection""" + nonlocal ev_closed + # Wait for a request + await server_read.receive() + # Close the connection, as if the server exited + server_write.close() + server_read.close() + ev_closed.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + ) as client_session, + ): + tg.start_soon(make_request, client_session) + tg.start_soon(mock_server) + + with anyio.fail_after(1): + await ev_closed.wait() + with anyio.fail_after(1): + await ev_response.wait()