From 3593523d93fdd225ecdcb785e18c9920a94bc004 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 26 May 2025 17:40:09 +0200 Subject: [PATCH] feat: Add MCP client --- README.md | 81 ++++ src/strands_agents_builder/strands.py | 33 +- src/strands_agents_builder/utils/mcp_utils.py | 206 ++++++++++ tests/test_strands_mcp_integration.py | 225 +++++++++++ tests/utils/test_mcp_utils.py | 366 ++++++++++++++++++ 5 files changed, 910 insertions(+), 1 deletion(-) create mode 100644 src/strands_agents_builder/utils/mcp_utils.py create mode 100644 tests/test_strands_mcp_integration.py create mode 100644 tests/utils/test_mcp_utils.py diff --git a/README.md b/README.md index 351bbd4..f185f58 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,9 @@ cat agent-spec.txt | strands "Build a specialized agent based on these specifica # Use with knowledge base to extend existing tools strands --kb YOUR_KB_ID "Load my previous calculator tool and enhance it with scientific functions" + +# Connect to MCP servers for extended capabilities +strands --mcp-config '[{"transport": "stdio", "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem"]}]' ``` ## Features @@ -33,6 +36,7 @@ strands --kb YOUR_KB_ID "Load my previous calculator tool and enhance it with sc - šŸŖ„ Nested agent capabilities with tool delegation - šŸ”§ Dynamic tool loading for extending functionality - šŸ–„ļø Environment variable management and customization +- šŸ”Œ MCP (Model Context Protocol) client for connecting to external tools and services ## Integrated Tools @@ -50,6 +54,7 @@ Strands comes with a comprehensive set of built-in tools: - **strands**: Create nested agent instances with specialized capabilities - **dialog**: Create interactive dialog interfaces - **use_aws**: Make AWS API calls through boto3 +- **mcp_client**: Connect to and interact with MCP (Model Context Protocol) servers ## Knowledge Base Integration @@ -166,6 +171,82 @@ You can then use it with strands by running: $ strands --model-provider custom_model --model-config ``` +## MCP (Model Context Protocol) Integration + +Strands now supports connecting to MCP servers, allowing you to extend your agent's capabilities with external tools and services. MCP provides a standardized way to connect AI assistants to various data sources and tools. + +### Quick Start with MCP + +```bash +# Connect to a single MCP server +strands --mcp-config '[{"transport": "stdio", "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/files"]}]' + +# Load MCP config from a file +strands --mcp-config mcp_config.json + +# Use environment variable for default config +export STRANDS_MCP_CONFIG_PATH=~/.config/mcp/servers.json +strands +``` + +### MCP Configuration Format + +Create an `mcp_config.json` file: + +```json +[ + { + "connection_id": "filesystem", + "transport": "stdio", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/home/user/documents"], + "auto_load_tools": true + }, + { + "connection_id": "github", + "transport": "stdio", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { + "GITHUB_TOKEN": "your-github-token" + } + }, + { + "connection_id": "web_api", + "transport": "sse", + "server_url": "http://localhost:8080/mcp" + } +] +``` + +### Supported MCP Transports + +- **stdio**: Connect to MCP servers via standard input/output +- **sse**: Connect to MCP servers via Server-Sent Events (HTTP) + +### Using MCP Tools + +Once connected, MCP tools are automatically loaded and available to your agent: + +```bash +# List available MCP connections +strands "Show me all active MCP connections" + +# Use MCP filesystem tools +strands "List all files in my documents folder" + +# Use MCP GitHub tools +strands "Create a new issue in my repository about improving documentation" +``` + +### MCP Connection Management + +Strands automatically: +- Connects to all configured MCP servers on startup +- Loads available tools from each server (if `auto_load_tools` is true) +- Disconnects cleanly when exiting +- Shows connection status during initialization + ## Custom System Prompts ```bash diff --git a/src/strands_agents_builder/strands.py b/src/strands_agents_builder/strands.py index 4b67e52..17ba594 100644 --- a/src/strands_agents_builder/strands.py +++ b/src/strands_agents_builder/strands.py @@ -20,6 +20,7 @@ image_reader, journal, load_tool, + mcp_client, nova_reels, python_repl, retrieve, @@ -33,7 +34,7 @@ from strands_tools.utils.user_input import get_user_input from strands_agents_builder.handlers.callback_handler import callback_handler -from strands_agents_builder.utils import model_utils +from strands_agents_builder.utils import mcp_utils, model_utils from strands_agents_builder.utils.kb_utils import load_system_prompt, store_conversation_in_kb from strands_agents_builder.utils.welcome_utils import render_goodbye_message, render_welcome_message @@ -69,6 +70,12 @@ def main(): default="{}", help="Model config as JSON string or path", ) + parser.add_argument( + "--mcp-config", + type=mcp_utils.load_config, + default="[]", + help="MCP config as JSON string or path to JSON file. Can specify multiple MCP servers.", + ) args = parser.parse_args() # Get knowledge_base_id from args or environment variable @@ -103,6 +110,7 @@ def main(): store_in_kb, strand, welcome, + mcp_client, ] agent = Agent( @@ -112,6 +120,21 @@ def main(): callback_handler=callback_handler, ) + # Initialize MCP connections if configured + if args.mcp_config: + print("\nšŸ”Œ Initializing MCP connections...") + mcp_results = mcp_utils.initialize_mcp_connections(args.mcp_config, agent) + + # Summary of MCP initialization + successful = sum(1 for success in mcp_results.values() if success) + total = len(mcp_results) + if successful == total: + print(f"āœ… All {total} MCP connection(s) initialized successfully\n") + elif successful > 0: + print(f"āš ļø {successful}/{total} MCP connection(s) initialized\n") + else: + print("āŒ Failed to initialize any MCP connections\n") + # Process query or enter interactive mode if args.query: query = " ".join(args.query) @@ -135,6 +158,10 @@ def main(): try: user_input = get_user_input("\n~ ") if user_input.lower() in ["exit", "quit"]: + # Disconnect all MCP connections before exiting + if args.mcp_config: + print("\nšŸ”Œ Disconnecting MCP connections...") + mcp_utils.disconnect_all(agent) render_goodbye_message() break if user_input.startswith("!"): @@ -174,6 +201,10 @@ def main(): # Store conversation in knowledge base store_conversation_in_kb(agent, user_input, response, knowledge_base_id) except (KeyboardInterrupt, EOFError): + # Disconnect all MCP connections before exiting + if args.mcp_config: + print("\n\nšŸ”Œ Disconnecting MCP connections...") + mcp_utils.disconnect_all(agent) render_goodbye_message() break except Exception as e: diff --git a/src/strands_agents_builder/utils/mcp_utils.py b/src/strands_agents_builder/utils/mcp_utils.py new file mode 100644 index 0000000..e31a57d --- /dev/null +++ b/src/strands_agents_builder/utils/mcp_utils.py @@ -0,0 +1,206 @@ +"""Utilities for loading and managing MCP clients in strands.""" + +import json +import os +from typing import Any, Dict, List + + +def load_config(config: str) -> List[Dict[str, Any]]: + """Load MCP configuration from a JSON string or file. + + The configuration should be a list of MCP server configurations. + Each server configuration should contain: + - connection_id: Unique identifier for the connection (optional, auto-generated if missing) + - transport: Transport type (stdio or sse) + - command: Command for stdio transport (required for stdio) + - args: Arguments for stdio command (optional) + - server_url: URL for SSE transport (required for sse) + - auto_load_tools: Whether to automatically load tools (default: True) + + Args: + config: A JSON string or path to a JSON file containing MCP configurations. + If empty string or '[]', checks STRANDS_MCP_CONFIG_PATH environment variable. + + Returns: + List of parsed MCP server configurations. + + Examples: + # From JSON file + config = load_config("mcp_config.json") + + # From JSON string (connection_id is optional) + config = load_config('[{"transport": "stdio", "command": "node", "args": ["server.js"]}]') + """ + if not config or config == "[]": + # Check for default config path in environment + default_path = os.getenv("STRANDS_MCP_CONFIG_PATH") + if default_path and os.path.exists(default_path): + config = default_path + else: + return [] + + if config.endswith(".json"): + with open(config) as fp: + data = json.load(fp) + else: + data = json.loads(config) + + # Handle Amazon Q MCP format + if isinstance(data, dict) and "mcpServers" in data: + servers = [] + for server_id, server_config in data["mcpServers"].items(): + # Skip disabled servers + if server_config.get("disabled", False): + continue + + # Convert to our format + converted = { + "connection_id": server_id, + "transport": "stdio", # Amazon Q format uses stdio by default + "command": server_config.get("command"), + "args": server_config.get("args", []), + "auto_load_tools": True, + } + + # Add environment variables if present + if "env" in server_config: + converted["env"] = server_config["env"] + + servers.append(converted) + data = servers + + # Ensure it's a list + if isinstance(data, dict): + data = [data] + + return data + + +def initialize_mcp_connections(configs: List[Dict[str, Any]], agent) -> Dict[str, bool]: + """Initialize MCP connections based on provided configurations. + + Args: + configs: List of MCP server configurations. + agent: The strands agent instance to use for MCP client calls. + + Returns: + Dictionary mapping connection_id to success status. + """ + results = {} + + for i, config in enumerate(configs): + connection_id = config.get("connection_id") + if not connection_id: + # Auto-generate connection ID based on transport and command/url + transport = config.get("transport", "stdio") + if transport == "stdio": + command = config.get("command", "unknown") + # Use the command name as basis for ID + base_name = os.path.basename(command).replace(".", "_") + connection_id = f"mcp_{base_name}_{i}" + else: # sse + server_url = config.get("server_url", "") + # Extract hostname or use index + try: + from urllib.parse import urlparse + + parsed = urlparse(server_url) + host = parsed.hostname or "server" + host = host.replace(".", "_").replace("-", "_") + connection_id = f"mcp_{host}_{i}" + except Exception: + connection_id = f"mcp_sse_{i}" + + print(f"šŸ“ Auto-generated connection_id: {connection_id}") + config["connection_id"] = connection_id + + try: + # Connect to the MCP server + connect_params = {"action": "connect", "connection_id": connection_id, "kwargs": {}} + + # Add transport-specific parameters + if "transport" in config: + connect_params["transport"] = config["transport"] + + if config.get("transport") == "stdio": + if "command" in config: + connect_params["command"] = config["command"] + if "args" in config: + connect_params["args"] = config["args"] + if "env" in config: + connect_params["env"] = config["env"] + elif config.get("transport") == "sse": + if "server_url" in config: + connect_params["server_url"] = config["server_url"] + + # Connect to the server + result = agent.tool.mcp_client(**connect_params) + + if result.get("status") == "success": + print(f"āœ“ Connected to MCP server: {connection_id}") + + # Auto-load tools if specified (default: True) + if config.get("auto_load_tools", True): + load_result = agent.tool.mcp_client(action="load_tools", connection_id=connection_id, kwargs={}) + if load_result.get("status") == "success": + print(f" āœ“ Loaded tools from {connection_id}") + else: + print(f" āœ— Failed to load tools from {connection_id}") + + results[connection_id] = True + else: + print(f"āœ— Failed to connect to MCP server: {connection_id}") + error_msg = result.get("content", [{}])[0].get("text", "Unknown error") + print(f" Error: {error_msg}") + results[connection_id] = False + + except Exception as e: + print(f"āœ— Error connecting to MCP server {connection_id}: {str(e)}") + results[connection_id] = False + + return results + + +def list_active_connections(agent) -> List[str]: + """List all active MCP connections. + + Args: + agent: The strands agent instance to use for MCP client calls. + + Returns: + List of active connection IDs. + """ + try: + result = agent.tool.mcp_client(action="list_connections", kwargs={}) + if result.get("status") == "success": + # Parse the response to extract connection IDs + content = result.get("content", [{}])[0].get("text", "") + if "No active MCP connections" in content: + return [] + + # Extract connection IDs from the formatted output + connections = [] + lines = content.split("\n") + for line in lines: + if "Connection:" in line: + conn_id = line.split("Connection:")[1].strip() + connections.append(conn_id) + return connections + return [] + except Exception: + return [] + + +def disconnect_all(agent) -> None: + """Disconnect all active MCP connections. + + Args: + agent: The strands agent instance to use for MCP client calls. + """ + connections = list_active_connections(agent) + for connection_id in connections: + try: + agent.tool.mcp_client(action="disconnect", connection_id=connection_id, kwargs={}) + print(f"āœ“ Disconnected from MCP server: {connection_id}") + except Exception as e: + print(f"āœ— Error disconnecting from {connection_id}: {str(e)}") diff --git a/tests/test_strands_mcp_integration.py b/tests/test_strands_mcp_integration.py new file mode 100644 index 0000000..7ec0d0f --- /dev/null +++ b/tests/test_strands_mcp_integration.py @@ -0,0 +1,225 @@ +"""Tests for MCP integration in strands.py.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from strands_agents_builder import strands + + +class TestMCPIntegrationInStrands: + """Tests for MCP integration in the main strands module.""" + + @pytest.fixture + def mock_dependencies(self): + """Mock all dependencies for strands main.""" + with ( + patch.object(strands, "model_utils") as mock_model_utils, + patch.object(strands, "mcp_utils") as mock_mcp_utils, + patch.object(strands, "render_welcome_message") as mock_welcome, + patch.object(strands, "render_goodbye_message") as mock_goodbye, + patch.object(strands, "get_user_input") as mock_input, + patch.object(strands, "store_conversation_in_kb") as mock_store, + patch.object(strands, "load_system_prompt") as mock_load_prompt, + patch.object(strands, "Agent") as mock_agent_class, + ): + # Setup mock model + mock_model = MagicMock() + mock_model_utils.get_model.return_value = mock_model + + # Setup mock system prompt + mock_load_prompt.return_value = "Test system prompt" + + # Setup mock agent + mock_agent = MagicMock() + mock_agent_class.return_value = mock_agent + mock_agent.tool = MagicMock() + mock_agent.tool.mcp_client = MagicMock() + mock_agent.tool.welcome = MagicMock(return_value={"status": "success", "content": [{"text": "Welcome"}]}) + + yield { + "model_utils": mock_model_utils, + "mcp_utils": mock_mcp_utils, + "render_welcome_message": mock_welcome, + "render_goodbye_message": mock_goodbye, + "get_user_input": mock_input, + "store_conversation": mock_store, + "load_system_prompt": mock_load_prompt, + "Agent": mock_agent_class, + "agent": mock_agent, + } + + def test_mcp_config_argument_parsing(self, mock_dependencies): + """Test that --mcp-config argument is parsed correctly.""" + test_config = [{"transport": "stdio", "command": "test"}] + + # Mock mcp_utils.load_config to return test config + mock_dependencies["mcp_utils"].load_config.return_value = test_config + + # Mock user input to exit immediately + mock_dependencies["get_user_input"].return_value = "exit" + + # Test with JSON string + with patch("sys.argv", ["strands", "--mcp-config", json.dumps(test_config)]): + strands.main() + + # Verify load_config was called with the JSON string + mock_dependencies["mcp_utils"].load_config.assert_called_with(json.dumps(test_config)) + + def test_mcp_config_from_file(self, mock_dependencies, tmp_path): + """Test loading MCP config from file.""" + test_config = [{"transport": "sse", "server_url": "http://localhost:8080"}] + config_file = tmp_path / "mcp_config.json" + config_file.write_text(json.dumps(test_config)) + + mock_dependencies["mcp_utils"].load_config.return_value = test_config + mock_dependencies["get_user_input"].return_value = "exit" + + with patch("sys.argv", ["strands", "--mcp-config", str(config_file)]): + strands.main() + + mock_dependencies["mcp_utils"].load_config.assert_called_with(str(config_file)) + + def test_mcp_initialization_success(self, mock_dependencies): + """Test successful MCP initialization.""" + test_config = [ + {"connection_id": "server1", "transport": "stdio", "command": "test1"}, + {"connection_id": "server2", "transport": "stdio", "command": "test2"}, + ] + + mock_dependencies["mcp_utils"].load_config.return_value = test_config + mock_dependencies["mcp_utils"].initialize_mcp_connections.return_value = {"server1": True, "server2": True} + mock_dependencies["get_user_input"].return_value = "exit" + + with patch("sys.argv", ["strands", "--mcp-config", json.dumps(test_config)]): + strands.main() + + # Verify initialization was called + mock_dependencies["mcp_utils"].initialize_mcp_connections.assert_called_once_with( + test_config, mock_dependencies["agent"] + ) + + def test_mcp_initialization_partial_success(self, mock_dependencies): + """Test partial MCP initialization success.""" + test_config = [ + {"connection_id": "server1", "transport": "stdio", "command": "test1"}, + {"connection_id": "server2", "transport": "stdio", "command": "test2"}, + ] + + mock_dependencies["mcp_utils"].load_config.return_value = test_config + mock_dependencies["mcp_utils"].initialize_mcp_connections.return_value = {"server1": True, "server2": False} + mock_dependencies["get_user_input"].return_value = "exit" + + with patch("sys.argv", ["strands", "--mcp-config", json.dumps(test_config)]): + with patch("builtins.print") as mock_print: + strands.main() + + # Check that partial success message was printed + print_calls = [str(call) for call in mock_print.call_args_list] + assert any("1/2 MCP connection(s) initialized" in str(call) for call in print_calls) + + def test_mcp_initialization_all_failed(self, mock_dependencies): + """Test when all MCP connections fail.""" + test_config = [{"connection_id": "server1", "transport": "stdio", "command": "test"}] + + mock_dependencies["mcp_utils"].load_config.return_value = test_config + mock_dependencies["mcp_utils"].initialize_mcp_connections.return_value = {"server1": False} + mock_dependencies["get_user_input"].return_value = "exit" + + with patch("sys.argv", ["strands", "--mcp-config", json.dumps(test_config)]): + with patch("builtins.print") as mock_print: + strands.main() + + # Check that failure message was printed + print_calls = [str(call) for call in mock_print.call_args_list] + assert any("Failed to initialize any MCP connections" in str(call) for call in print_calls) + + def test_mcp_disconnect_on_exit(self, mock_dependencies): + """Test that MCP connections are disconnected on exit.""" + test_config = [{"connection_id": "server1", "transport": "stdio", "command": "test"}] + + mock_dependencies["mcp_utils"].load_config.return_value = test_config + mock_dependencies["mcp_utils"].initialize_mcp_connections.return_value = {"server1": True} + mock_dependencies["get_user_input"].return_value = "exit" + + with patch("sys.argv", ["strands", "--mcp-config", json.dumps(test_config)]): + strands.main() + + # Verify disconnect_all was called + mock_dependencies["mcp_utils"].disconnect_all.assert_called_once_with(mock_dependencies["agent"]) + + def test_mcp_disconnect_on_keyboard_interrupt(self, mock_dependencies): + """Test that MCP connections are disconnected on KeyboardInterrupt.""" + test_config = [{"connection_id": "server1", "transport": "stdio", "command": "test"}] + + mock_dependencies["mcp_utils"].load_config.return_value = test_config + mock_dependencies["mcp_utils"].initialize_mcp_connections.return_value = {"server1": True} + mock_dependencies["get_user_input"].side_effect = KeyboardInterrupt() + + with patch("sys.argv", ["strands", "--mcp-config", json.dumps(test_config)]): + strands.main() + + # Verify disconnect_all was called + mock_dependencies["mcp_utils"].disconnect_all.assert_called_once_with(mock_dependencies["agent"]) + + def test_no_mcp_config(self, mock_dependencies): + """Test running without MCP config.""" + mock_dependencies["mcp_utils"].load_config.return_value = [] + mock_dependencies["get_user_input"].return_value = "exit" + + with patch("sys.argv", ["strands"]): + strands.main() + + # Verify MCP initialization was not called + mock_dependencies["mcp_utils"].initialize_mcp_connections.assert_not_called() + mock_dependencies["mcp_utils"].disconnect_all.assert_not_called() + + def test_mcp_config_with_query_mode(self, mock_dependencies): + """Test MCP config works with query mode (non-interactive).""" + test_config = [{"connection_id": "server1", "transport": "stdio", "command": "test"}] + + mock_dependencies["mcp_utils"].load_config.return_value = test_config + mock_dependencies["mcp_utils"].initialize_mcp_connections.return_value = {"server1": True} + + # Mock agent response - agent is called directly in query mode + mock_dependencies["agent"].return_value = {"message": "Test response"} + + with patch("sys.argv", ["strands", "--mcp-config", json.dumps(test_config), "Test query"]): + strands.main() + + # Verify MCP was initialized + mock_dependencies["mcp_utils"].initialize_mcp_connections.assert_called_once() + + # Verify agent was called with the query + mock_dependencies["agent"].assert_called_once_with("Test query") + + # Verify no disconnect_all in query mode (non-interactive) + mock_dependencies["mcp_utils"].disconnect_all.assert_not_called() + + def test_mcp_config_load_with_default_empty_list(self, mock_dependencies): + """Test that default MCP config is empty list.""" + # Don't provide --mcp-config argument + mock_dependencies["get_user_input"].return_value = "exit" + + with patch("sys.argv", ["strands"]): + strands.main() + + # Verify load_config was called with default "[]" + mock_dependencies["mcp_utils"].load_config.assert_called_with("[]") + + def test_mcp_tool_added_to_agent(self, mock_dependencies): + """Test that mcp_client tool is added to agent tools.""" + mock_dependencies["get_user_input"].return_value = "exit" + + with patch("sys.argv", ["strands"]): + strands.main() + + # Get the tools argument passed to Agent + agent_call = mock_dependencies["Agent"].call_args + tools = agent_call[1]["tools"] + + # Verify mcp_client is in the tools list + from strands_tools import mcp_client + + assert mcp_client in tools diff --git a/tests/utils/test_mcp_utils.py b/tests/utils/test_mcp_utils.py new file mode 100644 index 0000000..5c4fe8e --- /dev/null +++ b/tests/utils/test_mcp_utils.py @@ -0,0 +1,366 @@ +"""Tests for mcp_utils module.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from strands_agents_builder.utils import mcp_utils + + +class TestLoadConfig: + """Tests for the load_config function.""" + + def test_load_config_empty_string(self): + """Test loading config with empty string returns empty list.""" + result = mcp_utils.load_config("") + assert result == [] + + def test_load_config_empty_brackets(self): + """Test loading config with empty brackets returns empty list.""" + result = mcp_utils.load_config("[]") + assert result == [] + + def test_load_config_from_json_string(self): + """Test loading config from JSON string.""" + config_str = '[{"transport": "stdio", "command": "node", "args": ["server.js"]}]' + result = mcp_utils.load_config(config_str) + assert len(result) == 1 + assert result[0]["transport"] == "stdio" + assert result[0]["command"] == "node" + assert result[0]["args"] == ["server.js"] + + def test_load_config_from_json_file(self, tmp_path): + """Test loading config from JSON file.""" + config_data = [ + { + "connection_id": "test_server", + "transport": "stdio", + "command": "python", + "args": ["mcp_server.py"], + "auto_load_tools": True, + } + ] + config_file = tmp_path / "mcp_config.json" + config_file.write_text(json.dumps(config_data)) + + result = mcp_utils.load_config(str(config_file)) + assert len(result) == 1 + assert result[0]["connection_id"] == "test_server" + assert result[0]["transport"] == "stdio" + assert result[0]["command"] == "python" + + def test_load_config_single_dict_to_list(self): + """Test that single dict config is converted to list.""" + config_str = '{"transport": "sse", "server_url": "http://localhost:8080"}' + result = mcp_utils.load_config(config_str) + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["transport"] == "sse" + assert result[0]["server_url"] == "http://localhost:8080" + + def test_load_config_from_env_path(self, tmp_path, monkeypatch): + """Test loading config from STRANDS_MCP_CONFIG_PATH environment variable.""" + config_data = [{"transport": "stdio", "command": "test"}] + config_file = tmp_path / "env_config.json" + config_file.write_text(json.dumps(config_data)) + + monkeypatch.setenv("STRANDS_MCP_CONFIG_PATH", str(config_file)) + result = mcp_utils.load_config("") + assert len(result) == 1 + assert result[0]["command"] == "test" + + def test_load_config_amazon_q_format(self): + """Test loading config in Amazon Q MCP format.""" + config_str = json.dumps( + { + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "path/to/files"], + "disabled": False, + }, + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": {"GITHUB_TOKEN": "test_token"}, + "disabled": False, + }, + "disabled_server": {"command": "test", "disabled": True}, + } + } + ) + + result = mcp_utils.load_config(config_str) + assert len(result) == 2 # Only enabled servers + + # Check filesystem server + filesystem = next(s for s in result if s["connection_id"] == "filesystem") + assert filesystem["transport"] == "stdio" + assert filesystem["command"] == "npx" + assert filesystem["args"] == ["-y", "@modelcontextprotocol/server-filesystem", "path/to/files"] + assert filesystem["auto_load_tools"] is True + + # Check github server with env + github = next(s for s in result if s["connection_id"] == "github") + assert github["env"] == {"GITHUB_TOKEN": "test_token"} + + # Ensure disabled server is not included + assert not any(s["connection_id"] == "disabled_server" for s in result) + + +class TestInitializeMCPConnections: + """Tests for the initialize_mcp_connections function.""" + + @pytest.fixture + def mock_agent(self): + """Create a mock agent with mcp_client tool.""" + agent = MagicMock() + agent.tool = MagicMock() + agent.tool.mcp_client = MagicMock() + return agent + + def test_initialize_single_connection_success(self, mock_agent): + """Test successful initialization of a single MCP connection.""" + configs = [ + { + "connection_id": "test_server", + "transport": "stdio", + "command": "node", + "args": ["server.js"], + "auto_load_tools": True, + } + ] + + # Mock successful responses + mock_agent.tool.mcp_client.side_effect = [ + {"status": "success", "content": [{"text": "Connected"}]}, # connect + {"status": "success", "content": [{"text": "Tools loaded"}]}, # load_tools + ] + + results = mcp_utils.initialize_mcp_connections(configs, mock_agent) + + assert results == {"test_server": True} + assert mock_agent.tool.mcp_client.call_count == 2 + + # Check connect call + connect_call = mock_agent.tool.mcp_client.call_args_list[0] + assert connect_call[1]["action"] == "connect" + assert connect_call[1]["connection_id"] == "test_server" + assert connect_call[1]["transport"] == "stdio" + assert connect_call[1]["command"] == "node" + assert connect_call[1]["args"] == ["server.js"] + + def test_initialize_connection_failure(self, mock_agent): + """Test failed initialization of MCP connection.""" + configs = [ + { + "connection_id": "failing_server", + "transport": "stdio", + "command": "invalid", + } + ] + + mock_agent.tool.mcp_client.return_value = {"status": "error", "content": [{"text": "Connection failed"}]} + + results = mcp_utils.initialize_mcp_connections(configs, mock_agent) + + assert results == {"failing_server": False} + assert mock_agent.tool.mcp_client.call_count == 1 + + def test_initialize_auto_generate_connection_id_stdio(self, mock_agent): + """Test auto-generation of connection ID for stdio transport.""" + configs = [{"transport": "stdio", "command": "python", "args": ["mcp_server.py"]}] + + mock_agent.tool.mcp_client.return_value = {"status": "success", "content": [{"text": "Connected"}]} + + results = mcp_utils.initialize_mcp_connections(configs, mock_agent) + + # Check that connection_id was auto-generated + assert len(results) == 1 + connection_id = list(results.keys())[0] + assert connection_id.startswith("mcp_python_") + assert results[connection_id] is True + + def test_initialize_auto_generate_connection_id_sse(self, mock_agent): + """Test auto-generation of connection ID for SSE transport.""" + configs = [{"transport": "sse", "server_url": "http://example.com/mcp"}] + + mock_agent.tool.mcp_client.return_value = {"status": "success", "content": [{"text": "Connected"}]} + + results = mcp_utils.initialize_mcp_connections(configs, mock_agent) + + # Check that connection_id was auto-generated from hostname + assert len(results) == 1 + connection_id = list(results.keys())[0] + assert connection_id.startswith("mcp_example_") + assert results[connection_id] is True + + def test_initialize_multiple_connections(self, mock_agent): + """Test initialization of multiple MCP connections.""" + configs = [ + {"connection_id": "server1", "transport": "stdio", "command": "node"}, + {"connection_id": "server2", "transport": "sse", "server_url": "http://localhost:8080"}, + ] + + # First server succeeds, second fails + mock_agent.tool.mcp_client.side_effect = [ + {"status": "success", "content": [{"text": "Connected"}]}, # server1 connect + {"status": "error", "content": [{"text": "Failed"}]}, # server2 connect + ] + + results = mcp_utils.initialize_mcp_connections(configs, mock_agent) + + assert results == {"server1": True, "server2": False} + + def test_initialize_with_env_variables(self, mock_agent): + """Test initialization with environment variables.""" + configs = [ + { + "connection_id": "github_server", + "transport": "stdio", + "command": "npx", + "args": ["@modelcontextprotocol/server-github"], + "env": {"GITHUB_TOKEN": "test_token"}, + } + ] + + mock_agent.tool.mcp_client.return_value = {"status": "success", "content": [{"text": "Connected"}]} + + results = mcp_utils.initialize_mcp_connections(configs, mock_agent) + + assert results == {"github_server": True} + + # Check env was passed + connect_call = mock_agent.tool.mcp_client.call_args_list[0] + assert connect_call[1]["env"] == {"GITHUB_TOKEN": "test_token"} + + def test_initialize_auto_load_tools_false(self, mock_agent): + """Test initialization with auto_load_tools set to False.""" + configs = [{"connection_id": "no_autoload", "transport": "stdio", "command": "test", "auto_load_tools": False}] + + mock_agent.tool.mcp_client.return_value = {"status": "success", "content": [{"text": "Connected"}]} + + results = mcp_utils.initialize_mcp_connections(configs, mock_agent) + + assert results == {"no_autoload": True} + # Should only call connect, not load_tools + assert mock_agent.tool.mcp_client.call_count == 1 + + def test_initialize_exception_handling(self, mock_agent): + """Test exception handling during initialization.""" + configs = [{"connection_id": "exception_server", "transport": "stdio", "command": "test"}] + + mock_agent.tool.mcp_client.side_effect = Exception("Test exception") + + results = mcp_utils.initialize_mcp_connections(configs, mock_agent) + + assert results == {"exception_server": False} + + +class TestListActiveConnections: + """Tests for the list_active_connections function.""" + + @pytest.fixture + def mock_agent(self): + """Create a mock agent with mcp_client tool.""" + agent = MagicMock() + agent.tool = MagicMock() + agent.tool.mcp_client = MagicMock() + return agent + + def test_list_active_connections_success(self, mock_agent): + """Test listing active connections successfully.""" + mock_agent.tool.mcp_client.return_value = { + "status": "success", + "content": [{"text": "Active MCP connections:\n\nConnection: server1\nConnection: server2"}], + } + + result = mcp_utils.list_active_connections(mock_agent) + + assert result == ["server1", "server2"] + mock_agent.tool.mcp_client.assert_called_once_with(action="list_connections", kwargs={}) + + def test_list_active_connections_empty(self, mock_agent): + """Test listing when no active connections.""" + mock_agent.tool.mcp_client.return_value = { + "status": "success", + "content": [{"text": "No active MCP connections"}], + } + + result = mcp_utils.list_active_connections(mock_agent) + + assert result == [] + + def test_list_active_connections_error(self, mock_agent): + """Test listing connections when error occurs.""" + mock_agent.tool.mcp_client.return_value = { + "status": "error", + "content": [{"text": "Error listing connections"}], + } + + result = mcp_utils.list_active_connections(mock_agent) + + assert result == [] + + def test_list_active_connections_exception(self, mock_agent): + """Test listing connections when exception is raised.""" + mock_agent.tool.mcp_client.side_effect = Exception("Test exception") + + result = mcp_utils.list_active_connections(mock_agent) + + assert result == [] + + +class TestDisconnectAll: + """Tests for the disconnect_all function.""" + + @pytest.fixture + def mock_agent(self): + """Create a mock agent with mcp_client tool.""" + agent = MagicMock() + agent.tool = MagicMock() + agent.tool.mcp_client = MagicMock() + return agent + + def test_disconnect_all_success(self, mock_agent): + """Test disconnecting all connections successfully.""" + # Mock list_active_connections to return two connections + with patch.object(mcp_utils, "list_active_connections", return_value=["server1", "server2"]): + # Mock successful disconnections + mock_agent.tool.mcp_client.return_value = {"status": "success", "content": [{"text": "Disconnected"}]} + + mcp_utils.disconnect_all(mock_agent) + + # Should call disconnect for each connection + assert mock_agent.tool.mcp_client.call_count == 2 + + # Check disconnect calls + calls = mock_agent.tool.mcp_client.call_args_list + assert calls[0][1]["action"] == "disconnect" + assert calls[0][1]["connection_id"] == "server1" + assert calls[1][1]["action"] == "disconnect" + assert calls[1][1]["connection_id"] == "server2" + + def test_disconnect_all_no_connections(self, mock_agent): + """Test disconnecting when no active connections.""" + with patch.object(mcp_utils, "list_active_connections", return_value=[]): + mcp_utils.disconnect_all(mock_agent) + + # Should not call disconnect + mock_agent.tool.mcp_client.assert_not_called() + + def test_disconnect_all_with_errors(self, mock_agent): + """Test disconnecting with some errors.""" + with patch.object(mcp_utils, "list_active_connections", return_value=["server1", "server2"]): + # First disconnect succeeds, second raises exception + mock_agent.tool.mcp_client.side_effect = [ + {"status": "success", "content": [{"text": "Disconnected"}]}, + Exception("Connection error"), + ] + + # Should not raise exception + mcp_utils.disconnect_all(mock_agent) + + # Should still attempt both disconnections + assert mock_agent.tool.mcp_client.call_count == 2