diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 976613eca6ee..f64d885650e5 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -6,18 +6,24 @@ edition = "2021" [dependencies] mcp-core = { path = "../mcp-core" } tokio = { version = "1", features = ["full"] } -reqwest = { version = "0.12.9", default-features = false, features = ["json", "stream", "rustls-tls"] } -reqwest-eventsource = "0.6.0" +reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "rustls-tls"] } +eventsource-client = "0.12.0" +futures = "0.3" futures-util = "0.3" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" clap = { version = "4.5", features = ["derive"] } async-trait = "0.1.83" url = "2.5.4" +thiserror = "1.0" anyhow = "1.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tokio-retry = "0.3" +tower = { version = "0.4", features = ["timeout", "util"] } +tower-service = "0.3" +tokio-util = { version = "0.7", features = ["io-util", "io"] } +rand = "0.8" [dev-dependencies] warp = "0.3" diff --git a/crates/mcp-client/README.md b/crates/mcp-client/README.md index 32e0c4c32fb8..a43c4c21002a 100644 --- a/crates/mcp-client/README.md +++ b/crates/mcp-client/README.md @@ -1,13 +1,11 @@ -## Testing stdio +## Testing stdio transport ```bash -cargo run -p mcp_client -- --mode git -cargo run -p mcp_client -- --mode echo - -cargo run -p mcp_client --bin stdio +cargo run -p mcp-client --example stdio ``` -## Testing SSE +## Testing SSE transport + +1. Start the MCP server in one terminal: `fastmcp run -t sse echo.py` +2. Run the client example in new terminal: `cargo run -p mcp-client --example sse` -1. Start the MCP server: `fastmcp run -t sse echo.py` -2. Run the client: `cargo run -p mcp_client --bin sse` diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs new file mode 100644 index 000000000000..926f13e18fb4 --- /dev/null +++ b/crates/mcp-client/examples/clients.rs @@ -0,0 +1,144 @@ +use mcp_client::{ + client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}, + service::TransportService, + transport::{SseTransport, StdioTransport}, +}; +use rand::Rng; +use rand::SeedableRng; +use std::sync::Arc; +use std::time::Duration; +use tower::ServiceBuilder; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env().add_directive("mcp_client=debug".parse().unwrap()), + ) + .init(); + + // Create two separate clients with stdio transport + let client1 = create_stdio_client("client1", "1.0.0")?; + let client2 = create_stdio_client("client2", "1.0.0")?; + let client3 = create_sse_client("client3", "1.0.0")?; + + // Initialize both clients + let mut clients: Vec> = Vec::new(); + clients.push(client1); + clients.push(client2); + clients.push(client3); + + // Initialize all clients + for (i, client) in clients.iter_mut().enumerate() { + let info = ClientInfo { + name: format!("example-client-{}", i + 1), + version: "1.0.0".to_string(), + }; + let capabilities = ClientCapabilities::default(); + + println!("\nInitializing client {}", i + 1); + let init_result = client.initialize(info, capabilities).await?; + println!("Client {} initialized: {:?}", i + 1, init_result); + } + + // List tools for all clients + for (i, client) in clients.iter_mut().enumerate() { + let tools = client.list_tools().await?; + println!("\nClient {} tools: {:?}", i + 1, tools); + } + + println!("\n\n----------------------------------\n\n"); + + // Wrap clients in Arc before spawning tasks + let clients = Arc::new(clients); + let mut handles = vec![]; + + for i in 0..20 { + let clients = Arc::clone(&clients); + let handle = tokio::spawn(async move { + // let mut rng = rand::thread_rng(); + let mut rng = rand::rngs::StdRng::from_entropy(); + tokio::time::sleep(Duration::from_millis(rng.gen_range(5..50))).await; + + // Randomly select an operation + match rng.gen_range(0..4) { + 0 => { + println!("\n{i}: Listing tools for client 1 (stdio)"); + match clients[0].list_tools().await { + Ok(tools) => { + println!(" {i}: -> Got tools, first one: {:?}", tools.tools.first()) + } + Err(e) => println!(" {i}: -> Error: {}", e), + } + } + 1 => { + println!("\n{i}: Listing tools for client 3 (sse)"); + match clients[2].list_tools().await { + Ok(tools) => { + println!(" {i}: -> Got tools, first one: {:?}", tools.tools.first()) + } + Err(e) => println!(" {i}: -> Error: {}", e), + } + } + 2 => { + println!("\n{i}: Calling tool for client 2 (stdio)"); + match clients[1] + .call_tool("git_status", serde_json::json!({ "repo_path": "." })) + .await + { + Ok(result) => println!( + " {i}: -> Tool execution result, is_error: {:?}", + result.is_error + ), + Err(e) => println!(" {i}: -> Error: {}", e), + } + } + 3 => { + println!("\n{i}: Calling tool for client 3 (sse)"); + match clients[2] + .call_tool( + "echo_tool", + serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), + ) + .await + { + Ok(result) => println!(" {i}: -> Tool execution result, is_error: {:?}", result.is_error), + Err(e) => println!(" {i}: -> Error: {}", e), + } + } + _ => unreachable!(), + } + Ok::<(), Box>(()) + }); + handles.push(handle); + } + + // Wait for all tasks to complete + for handle in handles { + handle.await.unwrap().unwrap(); + } + + Ok(()) +} + +fn create_stdio_client( + _name: &str, + _version: &str, +) -> Result, Box> { + let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()]); + // TODO: Add timeout middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); + Ok(Box::new(McpClientImpl::new(service))) +} + +fn create_sse_client( + _name: &str, + _version: &str, +) -> Result, Box> { + let transport = SseTransport::new("http://localhost:8000/sse"); + // TODO: Add timeout middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); + Ok(Box::new(McpClientImpl::new(service))) +} diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs new file mode 100644 index 000000000000..961aecafb2d4 --- /dev/null +++ b/crates/mcp-client/examples/sse.rs @@ -0,0 +1,59 @@ +use anyhow::Result; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}; +use mcp_client::{service::TransportService, transport::SseTransport}; +use std::time::Duration; +use tower::ServiceBuilder; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("eventsource_client=debug".parse().unwrap()), + ) + .init(); + + // Create the base transport + let transport = SseTransport::new("http://localhost:8000/sse"); + + // Build service + // TODO: Add timeout middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); + + // Create client + let client = McpClientImpl::new(service); + println!("Client created\n"); + + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}\n"); + + // Sleep for 100ms to allow the server to start - surprisingly this is required! + tokio::time::sleep(Duration::from_millis(100)).await; + + // List tools + let tools = client.list_tools().await?; + println!("Available tools: {tools:?}\n"); + + // Call tool + let tool_result = client + .call_tool( + "echo_tool", + serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), + ) + .await?; + println!("Tool result: {tool_result:?}"); + + Ok(()) +} diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs new file mode 100644 index 000000000000..0816e0e29b2b --- /dev/null +++ b/crates/mcp-client/examples/stdio.rs @@ -0,0 +1,53 @@ +use anyhow::Result; +use mcp_client::client::{ + ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientImpl, +}; +use mcp_client::{service::TransportService, transport::StdioTransport}; +use tower::ServiceBuilder; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<(), ClientError> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("eventsource_client=debug".parse().unwrap()), + ) + .init(); + + // Create the transport + let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()]); + + // Build service + // TODO: Add timeout middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); + + // Create client + let client = McpClientImpl::new(service); + + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}\n"); + + // List tools + let tools = client.list_tools().await?; + println!("Available tools: {tools:?}\n"); + + // Call tool 'git_status' with arguments = {"repo_path": "."} + let tool_result = client + .call_tool("git_status", serde_json::json!({ "repo_path": "." })) + .await?; + println!("Tool result: {tool_result:?}\n"); + + Ok(()) +} diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs new file mode 100644 index 000000000000..58226711382f --- /dev/null +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -0,0 +1,73 @@ +// This example shows how to use the mcp-client crate to interact with a server that has a simple counter tool. +// The server is started by running `cargo run -p mcp-server` in the root of the mcp-server crate. +use anyhow::Result; +use mcp_client::client::{ + ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientImpl, +}; +use mcp_client::{service::TransportService, transport::StdioTransport}; +use tower::ServiceBuilder; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<(), ClientError> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("eventsource_client=debug".parse().unwrap()), + ) + .init(); + + // Create the transport + let transport = StdioTransport::new( + "cargo", + vec!["run", "-p", "mcp-server"] + .into_iter() + .map(|s| s.to_string()) + .collect(), + ); + + // Build service + // TODO: Add timeout middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); + + // Create client + let client = McpClientImpl::new(service); + + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}\n"); + + // List tools + let tools = client.list_tools().await?; + println!("Available tools: {tools:?}\n"); + + // Call tool 'increment' tool 3 times + for _ in 0..3 { + let increment_result = client.call_tool("increment", serde_json::json!({})).await?; + println!("Tool result for 'increment': {increment_result:?}\n"); + } + + // Call tool 'get_value' + let get_value_result = client.call_tool("get_value", serde_json::json!({})).await?; + println!("Tool result for 'get_value': {get_value_result:?}\n"); + + // Call tool 'decrement' once + let decrement_result = client.call_tool("decrement", serde_json::json!({})).await?; + println!("Tool result for 'decrement': {decrement_result:?}\n"); + + // Call tool 'get_value' + let get_value_result = client.call_tool("get_value", serde_json::json!({})).await?; + println!("Tool result for 'get_value': {get_value_result:?}\n"); + + Ok(()) +} diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs new file mode 100644 index 000000000000..2aeb69b5e1c0 --- /dev/null +++ b/crates/mcp-client/src/client.rs @@ -0,0 +1,216 @@ +use std::sync::atomic::{AtomicU64, Ordering}; + +use mcp_core::protocol::{ + CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, + JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, +}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use thiserror::Error; +use tokio::sync::Mutex; +use tower::ServiceExt; // for Service::ready() + +/// Error type for MCP client operations. +#[derive(Debug, Error)] +pub enum Error { + #[error("Service error: {0}")] + Service(#[from] super::service::ServiceError), + + #[error("RPC error: code={code}, message={message}")] + RpcError { code: i32, message: String }, + + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + #[error("Unexpected response from server")] + UnexpectedResponse, + + #[error("Timeout or service not ready")] + NotReady, +} + +#[derive(Serialize, Deserialize)] +pub struct ClientInfo { + pub name: String, + pub version: String, +} + +#[derive(Serialize, Deserialize, Default)] +pub struct ClientCapabilities { + // Add fields as needed. For now, empty capabilities are fine. +} + +#[derive(Serialize, Deserialize)] +pub struct InitializeParams { + #[serde(rename = "protocolVersion")] + pub protocol_version: String, + pub capabilities: ClientCapabilities, + #[serde(rename = "clientInfo")] + pub client_info: ClientInfo, +} + +/// The MCP client trait defining the interface for MCP operations. +#[async_trait::async_trait] +pub trait McpClient: Send + Sync + 'static { + /// Initialize the connection with the server. + async fn initialize( + &self, + info: ClientInfo, + capabilities: ClientCapabilities, + ) -> Result; + + /// List available resources. + async fn list_resources(&self) -> Result; + + /// Read a resource's content. + async fn read_resource(&self, uri: &str) -> Result; + + /// List available tools. + async fn list_tools(&self) -> Result; + + /// Call a specific tool with arguments. + async fn call_tool(&self, name: &str, arguments: Value) -> Result; +} + +/// Standard implementation of the MCP client that sends requests via the provided service. +pub struct McpClientImpl { + service: Mutex, + next_id: AtomicU64, +} + +impl McpClientImpl +where + S: tower::Service< + JsonRpcMessage, + Response = JsonRpcMessage, + Error = super::service::ServiceError, + > + Send, + S::Future: Send, +{ + pub fn new(service: S) -> Self { + Self { + service: Mutex::new(service), + next_id: AtomicU64::new(1), + } + } + + /// Send a JSON-RPC request and check we don't get an error response. + async fn send_request(&self, method: &str, params: Value) -> Result + where + R: for<'de> Deserialize<'de>, + { + let mut service = self.service.lock().await; + service.ready().await.map_err(|_| Error::NotReady)?; + + let id = self.next_id.fetch_add(1, Ordering::SeqCst); + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: Some(id), + method: method.to_string(), + params: Some(params), + }); + + let response_msg = service.call(request).await?; + + match response_msg { + JsonRpcMessage::Response(JsonRpcResponse { + id, result, error, .. + }) => { + // Verify id matches + if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { + return Err(Error::UnexpectedResponse); + } + if let Some(err) = error { + Err(Error::RpcError { + code: err.code, + message: err.message, + }) + } else if let Some(r) = result { + Ok(serde_json::from_value(r)?) + } else { + Err(Error::UnexpectedResponse) + } + } + JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => { + if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { + return Err(Error::UnexpectedResponse); + } + Err(Error::RpcError { + code: error.code, + message: error.message, + }) + } + _ => { + // Requests/notifications not expected as a response + Err(Error::UnexpectedResponse) + } + } + } + + /// Send a JSON-RPC notification. + async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { + let mut service = self.service.lock().await; + service.ready().await.map_err(|_| Error::NotReady)?; + + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: method.to_string(), + params: Some(params), + }); + + service.call(notification).await?; + Ok(()) + } +} + +#[async_trait::async_trait] +impl McpClient for McpClientImpl +where + S: tower::Service< + JsonRpcMessage, + Response = JsonRpcMessage, + Error = super::service::ServiceError, + > + Send + + Sync + + 'static, + S::Future: Send, +{ + async fn initialize( + &self, + info: ClientInfo, + capabilities: ClientCapabilities, + ) -> Result { + let params = InitializeParams { + protocol_version: "1.0.0".into(), + client_info: info, + capabilities, + }; + let result: InitializeResult = self + .send_request("initialize", serde_json::to_value(params)?) + .await?; + + self.send_notification("notifications/initialized", serde_json::json!({})) + .await?; + + Ok(result) + } + + async fn list_resources(&self) -> Result { + self.send_request("resources/list", serde_json::json!({})) + .await + } + + async fn read_resource(&self, uri: &str) -> Result { + let params = serde_json::json!({ "uri": uri }); + self.send_request("resources/read", params).await + } + + async fn list_tools(&self) -> Result { + self.send_request("tools/list", serde_json::json!({})).await + } + + async fn call_tool(&self, name: &str, arguments: Value) -> Result { + let params = serde_json::json!({ "name": name, "arguments": arguments }); + self.send_request("tools/call", params).await + } +} diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index 3172f2944f96..c2ab27500d91 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -1,4 +1,3 @@ -pub mod session; -pub mod sse_transport; -pub mod stdio_transport; +pub mod client; +pub mod service; pub mod transport; diff --git a/crates/mcp-client/src/main.rs b/crates/mcp-client/src/main.rs deleted file mode 100644 index 70495fffd04a..000000000000 --- a/crates/mcp-client/src/main.rs +++ /dev/null @@ -1,91 +0,0 @@ -use anyhow::{anyhow, Result}; -use clap::Parser; -use mcp_client::{ - session::Session, - sse_transport::{SseTransport, SseTransportParams}, - stdio_transport::{StdioServerParams, StdioTransport}, - transport::Transport, -}; -use serde_json::json; -use tracing_subscriber::EnvFilter; - -#[derive(Parser)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Mode to run in: "git" or "echo" - #[arg(short, long, default_value = "git")] - mode: String, -} - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("reqwest_eventsource=debug".parse().unwrap()), - ) - .init(); - - let args = Args::parse(); - println!("Args - mode: {}", args.mode); - - // Create session based on mode - let transport: Box = match args.mode.as_str() { - "git" => Box::new(StdioTransport { - params: StdioServerParams { - command: "uvx".into(), - args: vec!["mcp-server-git".into()], - env: None, - }, - }), - "echo" => Box::new(SseTransport { - params: SseTransportParams { - url: "http://localhost:8000/sse".into(), - headers: None, - }, - }), - _ => return Err(anyhow!("Invalid mode. Use 'git' or 'echo'")), - }; - - let (read_stream, write_stream) = transport.connect().await?; - let mut session = Session::new(read_stream, write_stream).await?; - - // Initialize the connection - let init_result = session.initialize().await?; - println!("Initialized: {:?}", init_result); - - // List tools - let tools = session.list_tools().await?; - println!("Tools: {:?}", tools); - - if args.mode == "echo" { - // Call a tool (replace with actual tool name and arguments) - let call_result = session - .call_tool("echo_tool", Some(json!({"message": "Hello, world!"}))) - .await?; - println!("Call tool result: {:?}", call_result); - - // List available resources - let resources = session.list_resources().await?; - println!("Resources: {:?}", resources); - - // Read a resource (replace with actual URI) - if let Some(resource) = resources.resources.first() { - let read_result = session.read_resource(&resource.uri).await?; - println!("Read resource result: {:?}", read_result); - } - } else { - // Call a tool (replace with actual tool name and arguments) - let call_result = session - .call_tool("git_status", Some(json!({"repo_path": "."}))) - .await?; - println!("Call tool result: {:?}", call_result); - } - - session.shutdown().await?; - println!("Done!"); - - Ok(()) -} diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs new file mode 100644 index 000000000000..7ebe40ee25bf --- /dev/null +++ b/crates/mcp-client/src/service.rs @@ -0,0 +1,147 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{mpsc, Mutex}; +use tower::Service; + +use crate::transport::{Error as TransportError, MessageRouter, Transport}; +use mcp_core::protocol::JsonRpcMessage; + +#[derive(Debug, thiserror::Error)] +pub enum ServiceError { + #[error("Transport error: {0}")] + Transport(#[from] TransportError), + + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + #[error("Request timed out")] + Timeout(#[from] tower::timeout::error::Elapsed), + + #[error("Transport not initialized")] + NotInitialized, + + #[error("Transport already initialized")] + AlreadyInitialized, + + #[error("Other error: {0}")] + Other(String), + + #[error("Unexpected server response")] + UnexpectedResponse, +} + +struct TransportServiceInner { + transport: Arc, + router: Mutex>, + initialized: AtomicBool, +} + +impl TransportServiceInner { + async fn ensure_initialized(&self) -> Result { + if !self.initialized.load(Ordering::SeqCst) { + let mut router_guard = self.router.lock().await; + + // Double-check after acquiring lock + if !self.initialized.load(Ordering::SeqCst) { + // Start the transport + let transport_tx = self + .transport + .start() + .await + .map_err(ServiceError::Transport)?; + + // Create shutdown channel + let (shutdown_tx, _shutdown_rx) = mpsc::channel(1); + + // Create and store the router + let router = MessageRouter::new(transport_tx, shutdown_tx); + *router_guard = Some(router); + + self.initialized.store(true, Ordering::SeqCst); + } + } + + // Get a clone of the router + Ok(self + .router + .lock() + .await + .as_ref() + .ok_or(ServiceError::NotInitialized)? + .clone()) + } +} + +/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcMessages. +pub struct TransportService { + inner: Arc>, +} + +impl TransportService { + pub fn new(transport: T) -> Self { + Self { + inner: Arc::new(TransportServiceInner { + transport: Arc::new(transport), + router: Mutex::new(None), + initialized: AtomicBool::new(false), + }), + } + } +} + +impl Clone for TransportService { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + } + } +} + +impl Service for TransportService { + type Response = JsonRpcMessage; + type Error = ServiceError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // Always ready since we do lazy initialization in call() + Poll::Ready(Ok(())) + } + + fn call(&mut self, message: JsonRpcMessage) -> Self::Future { + let inner = Arc::clone(&self.inner); + + Box::pin(async move { + // Ensure transport is initialized + let router = inner.ensure_initialized().await?; + + match message { + JsonRpcMessage::Notification(notification) => { + router + .send_notification(notification) + .await + .map_err(ServiceError::Transport)?; + Ok(JsonRpcMessage::Nil) + } + JsonRpcMessage::Request(request) => router + .send_request(request) + .await + .map_err(ServiceError::Transport), + _ => Err(ServiceError::Other("Invalid message type".to_string())), + } + }) + } +} + +// https://spec.modelcontextprotocol.io/specification/basic/lifecycle/#shutdown +// impl Drop for TransportServiceInner { +// fn drop(&mut self) { +// if self.initialized.load(Ordering::SeqCst) { +// // Best effort cleanup in sync context +// // We can't create a new runtime here, so we'll just log a warning +// tracing::warn!("TransportService dropped while initialized - resources may leak"); +// } +// } +// } diff --git a/crates/mcp-client/src/session.rs b/crates/mcp-client/src/session.rs deleted file mode 100644 index 1946cea36549..000000000000 --- a/crates/mcp-client/src/session.rs +++ /dev/null @@ -1,544 +0,0 @@ -use crate::transport::{ReadStream, WriteStream}; -use anyhow::{anyhow, Context, Result}; -use mcp_core::protocol::*; -use serde::de::DeserializeOwned; -use serde_json::{json, Value}; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::sync::Mutex; -use tracing::debug; - -struct OutgoingMessage { - message: JsonRpcMessage, - response_tx: mpsc::Sender>>, -} - -pub struct Session { - request_tx: mpsc::Sender, - id_counter: AtomicU64, - shutdown_tx: mpsc::Sender<()>, - background_task: Arc>>>, - is_closed: Arc, -} - -impl Session { - pub async fn new(read_stream: ReadStream, write_stream: WriteStream) -> Result { - let (request_tx, mut request_rx) = mpsc::channel::(32); - let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - let is_closed = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let is_closed_clone = is_closed.clone(); - - // Spawn the background task - let background_task = Arc::new(Mutex::new(Some(tokio::spawn({ - async move { - let mut pending_requests: Vec<( - u64, - mpsc::Sender>>, - )> = Vec::new(); - let mut read_stream = read_stream; - let write_stream = write_stream; - - loop { - tokio::select! { - // Handle shutdown signal - Some(()) = shutdown_rx.recv() => { - // Notify all pending requests of shutdown - for (_, tx) in pending_requests { - let _ = tx.send(Err(anyhow!("Session shutdown"))).await; - } - break; - } - - // Handle outgoing messages - Some(outgoing) = request_rx.recv() => { - // If session is closed, reject new messages - if is_closed_clone.load(Ordering::SeqCst) { - let _ = outgoing.response_tx.send(Err(anyhow!("Session is closed"))).await; - continue; - } - - // Send the message - if let Err(e) = write_stream.send(outgoing.message.clone()).await { - debug!("Write error occurred: {}", e); - // let _ = outgoing.response_tx.send(Err(e.into())).await; - // On write error, mark session as closed - is_closed_clone.store(true, Ordering::SeqCst); - break; - } - - // For requests, store the response channel for later - if let JsonRpcMessage::Request(request) = outgoing.message { - if let Some(id) = request.id { - pending_requests.push((id, outgoing.response_tx)); - } - } else { - // For notifications, just confirm success - let _ = outgoing.response_tx.send(Ok(None)).await; - } - } - - // Handle incoming messages - Some(message_result) = read_stream.recv() => { - match message_result { - Ok(JsonRpcMessage::Response(response)) => { - if let Some(id) = response.id { - if let Some(pos) = pending_requests.iter().position(|(req_id, _)| *req_id == id) { - let (_, tx) = pending_requests.remove(pos); - let _ = tx.send(Ok(Some(response))).await; - } - } - } - Ok(JsonRpcMessage::Notification(_)) => { - // Handle incoming notifications if needed - } - Ok(_) => { - eprintln!("Unexpected message type"); - } - Err(e) => { - // On transport error, notify all pending requests and shutdown - eprintln!("Transport error: {}", e); - for (_, tx) in pending_requests { - let _ = tx.send(Err(anyhow!("{}", e))).await; - } - - // Mark session as closed - is_closed_clone.store(true, Ordering::SeqCst); - break; - } - } - } - } - } - } - })))); - - Ok(Self { - request_tx, - id_counter: AtomicU64::new(1), - shutdown_tx, - background_task, - is_closed, - }) - } - - pub async fn shutdown(&self) -> Result<()> { - // Mark session as closed - self.is_closed.store(true, Ordering::SeqCst); - - // Send shutdown signal - self.shutdown_tx - .send(()) - .await - .map_err(|e| anyhow!("Failed to shutdown session: {}", e))?; - - // Wait for background task to complete - if let Some(task) = self.background_task.lock().await.take() { - task.await - .map_err(|e| anyhow!("Background task failed: {}", e))?; - } - - Ok(()) - } - - async fn send_message(&self, message: JsonRpcMessage) -> Result> { - // Check if session is closed - if self.is_closed.load(Ordering::SeqCst) { - return Err(anyhow!("Session is closed")); - } - - let (response_tx, mut response_rx) = mpsc::channel(1); - - self.request_tx - .send(OutgoingMessage { - message, - response_tx, - }) - .await - .context("Failed to send message")?; - - response_rx - .recv() - .await - .context("Failed to receive response")? - } - - async fn rpc_call( - &self, - method: &str, - params: Option, - ) -> Result { - // Check if session is closed - if self.is_closed.load(Ordering::SeqCst) { - return Err(anyhow!("Session is closed")); - } - - let id = self.id_counter.fetch_add(1, Ordering::SeqCst); - let request = JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: Some(id), - method: method.to_string(), - params, - }; - - let response = self - .send_message(JsonRpcMessage::Request(request)) - .await? - .context("Expected response for request")?; - - match (response.error, response.result) { - (Some(error), _) => Err(anyhow!("RPC Error {}: {}", error.code, error.message)), - (_, Some(result)) => { - serde_json::from_value(result).context("Failed to deserialize result") - } - (None, None) => Err(anyhow!("No result in response")), - } - } - - async fn send_notification(&self, method: &str, params: Option) -> Result<()> { - // Check if session is closed - if self.is_closed.load(Ordering::SeqCst) { - return Err(anyhow!("Session is closed")); - } - - let notification = JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: method.to_string(), - params, - }; - - self.send_message(JsonRpcMessage::Notification(notification)) - .await?; - - Ok(()) - } - - pub async fn initialize(&mut self) -> Result { - let params = json!({ - "protocolVersion": "2024-11-05", - "capabilities": { - "sampling": null, - "experimental": null, - "roots": { - "listChanged": true - } - }, - "clientInfo": { - "name": "RustMCPClient", - "version": "0.1.0" - } - }); - - let result: InitializeResult = self.rpc_call("initialize", Some(params)).await?; - self.send_notification("notifications/initialized", None) - .await?; - Ok(result) - } - - pub async fn list_resources(&self) -> Result { - self.rpc_call("resources/list", Some(json!({}))).await - } - - pub async fn read_resource(&self, uri: &str) -> Result { - self.rpc_call("resources/read", Some(json!({ "uri": uri }))) - .await - } - - pub async fn list_tools(&self) -> Result { - self.rpc_call("tools/list", Some(json!({}))).await - } - - pub async fn call_tool(&self, name: &str, arguments: Option) -> Result { - self.rpc_call( - "tools/call", - Some(json!({ - "name": name, - "arguments": arguments.unwrap_or_else(|| json!({})), - })), - ) - .await - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::transport::{ReadStream, Transport, WriteStream}; - use anyhow::{anyhow, Result}; - use async_trait::async_trait; - use std::sync::atomic::Ordering; - use std::time::Duration; - use tokio::sync::mpsc; - use tokio::time::timeout; - - // Mock transport that simulates errors - struct MockTransport { - error_mode: ErrorMode, - } - - #[derive(Clone)] - enum ErrorMode { - ReadError, - WriteError, - ProcessTermination, - Nil, - } - - #[async_trait] - impl Transport for MockTransport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)> { - let (tx_read, rx_read) = mpsc::channel(100); - let (tx_write, mut rx_write) = mpsc::channel(100); - - let error_mode = self.error_mode.clone(); - - tokio::spawn(async move { - // For WriteError, don't wait for any writes, just drop the receiver to force an immediate failure. - // This ensures that the first attempt to send by the Session fails. - match error_mode { - ErrorMode::ReadError => { - // Wait a bit for the request to be sent and then send the error - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - let _ = tx_read.send(Err(anyhow!("Simulated read error"))).await; - } - ErrorMode::WriteError => { - // Immediately drop the rx_write side - drop(rx_write); - } - ErrorMode::ProcessTermination => { - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - let _ = tx_read.send(Err(anyhow!("Child process terminated"))).await; - } - ErrorMode::Nil => { - // Test with initialize and then list_resources - while let Some(message) = rx_write.recv().await { - match message { - JsonRpcMessage::Request(req) => { - // Send a successful response for initialization or other calls - if req.method == "initialize" { - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: req.id, - result: Some(json!({ - "protocolVersion": "2024-11-05", - "capabilities": { "resources": { "listChanged": false } }, - "serverInfo": { "name": "MockServer", "version": "1.0.0" } - })), - error: None, - }); - let _ = tx_read.send(Ok(response)).await; - } else if req.method == "resources/list" { - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: req.id, - result: Some( - json!({ "resources": [{ "uri": "file://res1", "name": "res1" }, { "uri": "file://res2", "name": "res2" }] }), - ), - error: None, - }); - let _ = tx_read.send(Ok(response)).await; - } else { - // Default success for other calls - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: req.id, - result: Some(json!({ "ok": true })), - error: None, - }); - let _ = tx_read.send(Ok(response)).await; - } - } - JsonRpcMessage::Notification(_notif) => { - // For notifications, no response is required. - } - _ => {} - } - } - } - } - }); - - Ok((rx_read, tx_write)) - } - } - - #[tokio::test] - async fn test_session_can_initialize_and_list_resources() -> Result<()> { - let transport = MockTransport { - error_mode: ErrorMode::Nil, - }; - - let (read_stream, write_stream) = transport.connect().await?; - let mut session = Session::new(read_stream, write_stream).await?; - - // Initialize the session - let init_result = session.initialize().await?; - assert_eq!(init_result.protocol_version, "2024-11-05"); - assert_eq!( - init_result.capabilities.resources.unwrap().list_changed, - Some(false) - ); - - // Now list resources - let list_result = session.list_resources().await?; - assert_eq!( - list_result - .resources - .iter() - .map(|r| &r.name) - .collect::>(), - vec!["res1", "res2"] - ); - - // Make another call - just to verify multiple calls work fine - let _: serde_json::Value = session.rpc_call("someMethod", Some(json!({}))).await?; - Ok(()) - } - - #[tokio::test] - async fn test_read_error_terminates_session() { - let transport = MockTransport { - error_mode: ErrorMode::ReadError, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // // Introduce a brief delay to ensure the request is fully sent and pending before the error occurs - // tokio::time::sleep(std::time::Duration::from_millis(20)).await; - - // Try to make an RPC call - should fail due to transport error - let result = session.list_resources().await; - assert!(result.is_err()); - - // Print the actual error message for debugging - let err = result.unwrap_err(); - println!("Actual error: {}", err); - assert!(err.to_string().contains("Simulated read error")); - - // Verify session is marked as closed - assert!( - session.is_closed.load(Ordering::SeqCst), - "Session did not close after error" - ); - - // Subsequent calls should fail immediately - let result = session.list_tools().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - } - - #[tokio::test] - async fn test_write_error_terminates_session() { - let transport = MockTransport { - error_mode: ErrorMode::WriteError, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // Try to make an RPC call - should fail due to transport error - let result = session.list_resources().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Failed to receive response")); - - // Verify session is marked as closed - assert!(session.is_closed.load(Ordering::SeqCst)); - println!("First call made"); - - // Subsequent calls should fail immediately - let result = session.list_tools().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - } - - #[tokio::test] - async fn test_process_termination_terminates_session() { - let transport = MockTransport { - error_mode: ErrorMode::ProcessTermination, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // Try to make an RPC call - should fail due to process termination - let result = session.list_resources().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Child process terminated")); - - // Verify session is marked as closed - assert!(session.is_closed.load(Ordering::SeqCst)); - - // Subsequent calls should fail immediately - let result = session.list_tools().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - } - - #[tokio::test] - async fn test_session_cleanup_on_drop() { - let transport = MockTransport { - error_mode: ErrorMode::ProcessTermination, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // Get a clone of the background task handle - let background_task = session.background_task.clone(); - - // Drop the session - drop(session); - - // Verify that the background task completes - let timeout_result = timeout(Duration::from_secs(1), async { - if let Some(task) = background_task.lock().await.take() { - task.await.unwrap(); - } - }) - .await; - - assert!(timeout_result.is_ok(), "Background task did not complete"); - } - - #[tokio::test] - async fn test_explicit_shutdown() -> Result<()> { - let transport = MockTransport { - error_mode: ErrorMode::Nil, - }; - - let (read_stream, write_stream) = transport.connect().await?; - let session = Session::new(read_stream, write_stream).await?; - - // Verify we can make calls before shutdown - let _: serde_json::Value = session.rpc_call("someMethod", Some(json!({}))).await?; - - // Shutdown the session - session.shutdown().await?; - - // Verify calls fail after shutdown - let result = session.list_resources().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - - Ok(()) - } -} diff --git a/crates/mcp-client/src/sse_transport.rs b/crates/mcp-client/src/sse_transport.rs deleted file mode 100644 index bc5ea885224b..000000000000 --- a/crates/mcp-client/src/sse_transport.rs +++ /dev/null @@ -1,229 +0,0 @@ -use crate::transport::{ReadStream, Transport, WriteStream}; -use anyhow::{anyhow, Context, Result}; -use async_trait::async_trait; -use futures_util::StreamExt; -use mcp_core::protocol::JsonRpcMessage; -use reqwest::{Client, Url}; -use reqwest_eventsource::{Event, EventSource}; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; -use tokio_retry::{ - strategy::{jitter, ExponentialBackoff}, - Retry, -}; -use tracing::{debug, error, info, warn}; - -pub struct SseTransportParams { - pub url: String, - pub headers: Option, -} - -pub struct SseTransport { - pub params: SseTransportParams, -} - -// Helper function to send a POST request with retry logic -async fn send_with_retry( - client: &Client, - endpoint: &str, - json: serde_json::Value, -) -> Result { - // Create retry strategy with exponential backoff - let retry_strategy = ExponentialBackoff::from_millis(100) // Start with 100ms - .factor(2) // Double the delay each time - .map(jitter) // Add randomness to prevent thundering herd - .take(3); // Maximum of 3 retries (4 attempts total) - - Retry::spawn(retry_strategy, || async { - let response = client.post(endpoint).json(&json).send().await?; - - // If we get a 5xx error or specific connection errors, we should retry - if response.status().is_server_error() - || matches!(response.error_for_status_ref(), Err(e) if e.is_connect()) - { - return Err(anyhow!("Server error: {}", response.status())); - } - - Ok(response) - }) - .await -} - -#[async_trait] -impl Transport for SseTransport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)> { - info!("Connecting to SSE endpoint: {}", self.params.url); - let (tx_read, rx_read) = mpsc::channel(100); - let (tx_write, mut rx_write) = mpsc::channel(100); - - let client = Client::new(); - let base_url = Url::parse(&self.params.url).context("Failed to parse SSE URL")?; - - // Create the event source request - let mut request_builder = client.get(base_url.clone()); - if let Some(headers) = &self.params.headers { - request_builder = headers - .iter() - .fold(request_builder, |req, (key, value)| req.header(key, value)); - } - - let event_source = EventSource::new(request_builder)?; - let client_for_post = client.clone(); - - // Shared state for the endpoint URL - let endpoint_url = Arc::new(Mutex::new(None::)); - let endpoint_url_reader = endpoint_url.clone(); - - // Spawn the SSE reader task - tokio::spawn({ - let tx_read = tx_read.clone(); - let base_url = base_url.clone(); - async move { - info!("Starting SSE reader task"); - let mut stream = event_source; - let mut got_endpoint = false; - - while let Some(event) = stream.next().await { - match event { - Ok(Event::Open) => { - info!("SSE connection opened"); - } - Ok(Event::Message(message)) => { - debug!("Received SSE event: {} - {}", message.event, message.data); - match message.event.as_str() { - "endpoint" => { - // Handle endpoint event - let endpoint = message.data; - info!("Received endpoint URL: {}", endpoint); - - // Join with base URL if relative - let endpoint_url_full = if endpoint.starts_with('/') { - match base_url.join(&endpoint) { - Ok(url) => url, - Err(e) => { - error!("Failed to join endpoint URL: {}", e); - let _ = tx_read.send(Err(e.into())).await; - break; - } - } - } else { - match Url::parse(&endpoint) { - Ok(url) => url, - Err(e) => { - error!("Failed to parse endpoint URL: {}", e); - let _ = tx_read.send(Err(e.into())).await; - break; - } - } - }; - - // Validate endpoint URL has same origin (scheme and host) - if base_url.scheme() != endpoint_url_full.scheme() - || base_url.host_str() != endpoint_url_full.host_str() - || base_url.port() != endpoint_url_full.port() - { - let error = format!( - "Endpoint origin does not match connection origin: {}", - endpoint_url_full - ); - error!("{}", error); - let _ = tx_read.send(Err(anyhow!(error))).await; - break; - } - - let endpoint_str = endpoint_url_full.to_string(); - info!("Using full endpoint URL: {}", endpoint_str); - let mut endpoint_guard = endpoint_url.lock().await; - *endpoint_guard = Some(endpoint_str); - got_endpoint = true; - debug!("Endpoint URL set successfully"); - } - "message" => { - if !got_endpoint { - warn!("Received message before endpoint URL"); - continue; - } - // Handle message event - match serde_json::from_str::(&message.data) { - Ok(msg) => { - debug!("Received server message: {:?}", msg); - if tx_read.send(Ok(msg)).await.is_err() { - error!("Failed to send message to read channel"); - break; - } - } - Err(e) => { - error!("Error parsing server message: {}", e); - if tx_read.send(Err(e.into())).await.is_err() { - error!("Failed to send error to read channel"); - break; - } - } - } - } - _ => { - debug!("Ignoring unknown event type: {}", message.event); - } - } - } - Err(e) => { - error!("SSE error: {}", e); - let _ = tx_read.send(Err(e.into())).await; - break; - } - } - } - info!("SSE reader task ended"); - } - }); - - // Spawn the writer task - tokio::spawn(async move { - info!("Starting writer task"); - // Wait for the endpoint URL before processing messages - let mut endpoint = None; - while endpoint.is_none() { - let guard = endpoint_url_reader.lock().await; - if let Some(url) = guard.as_ref() { - endpoint = Some(url.clone()); - break; - } - drop(guard); - debug!("Waiting for endpoint URL..."); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - - let endpoint = endpoint.unwrap(); - info!("Starting post writer with endpoint URL: {}", endpoint); - - while let Some(message) = rx_write.recv().await { - match serde_json::to_value(&message) { - Ok(json) => { - debug!("Sending client message: {:?}", json); - match send_with_retry(&client_for_post, &endpoint, json).await { - Ok(response) => { - if !response.status().is_success() { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - error!("Server returned error status {}: {}", status, text); - } else { - debug!("Message sent successfully: {}", response.status()); - } - } - Err(e) => { - error!("Failed to send message after retries: {}", e); - } - } - } - Err(e) => { - error!("Failed to serialize message: {}", e); - } - } - } - info!("Writer task ended"); - }); - - info!("SSE transport connected"); - Ok((rx_read, tx_write)) - } -} diff --git a/crates/mcp-client/src/stdio_transport.rs b/crates/mcp-client/src/stdio_transport.rs deleted file mode 100644 index 3b965ab83a41..000000000000 --- a/crates/mcp-client/src/stdio_transport.rs +++ /dev/null @@ -1,148 +0,0 @@ -use crate::transport::{ReadStream, Transport, WriteStream}; -use anyhow::{anyhow, Context, Result}; -use async_trait::async_trait; -use mcp_core::protocol::*; -use std::process::Stdio; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::process::{Child, Command}; -use tokio::sync::mpsc; - -pub struct StdioServerParams { - pub command: String, - pub args: Vec, - pub env: Option>, -} - -pub struct StdioTransport { - pub params: StdioServerParams, -} - -impl StdioTransport { - fn get_default_environment() -> std::collections::HashMap { - let default_vars = if cfg!(windows) { - vec!["APPDATA", "PATH", "TEMP", "USERNAME"] // Simplified list - } else { - vec!["HOME", "PATH", "SHELL", "USER"] // Simplified list - }; - - std::env::vars() - .filter(|(key, value)| default_vars.contains(&key.as_str()) && !value.starts_with("()")) - .collect() - } - - async fn monitor_child(mut child: Child, tx_read: mpsc::Sender>) { - match child.wait().await { - Ok(status) => { - let msg = if status.success() { - format!("Child process terminated normally with status: {}", status) - } else { - format!("Child process terminated with error status: {}", status) - }; - let _ = tx_read.send(Err(anyhow!(msg))).await; - } - Err(e) => { - let _ = tx_read - .send(Err(anyhow!("Child process error: {}", e))) - .await; - } - } - } -} - -#[async_trait] -impl Transport for StdioTransport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)> { - let mut child = Command::new(&self.params.command) - .args(&self.params.args) - .env_clear() - .envs( - self.params - .env - .clone() - .unwrap_or_else(Self::get_default_environment), - ) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::inherit()) - .spawn() - .context("Failed to spawn child process")?; - - let stdin = child.stdin.take().context("Failed to get stdin handle")?; - let stdout = child.stdout.take().context("Failed to get stdout handle")?; - - let (tx_read, rx_read) = mpsc::channel(100); - let (tx_write, mut rx_write) = mpsc::channel(100); - - // Clone tx_read for the child monitor - let tx_read_monitor = tx_read.clone(); - - // Spawn child process monitor - tokio::spawn(Self::monitor_child(child, tx_read_monitor)); - - // Spawn stdout reader task - let stdout_reader = BufReader::new(stdout); - tokio::spawn(async move { - let mut lines = stdout_reader.lines(); - while let Ok(Some(line)) = lines.next_line().await { - match serde_json::from_str::(&line) { - Ok(msg) => { - if tx_read.send(Ok(msg)).await.is_err() { - break; - } - } - Err(e) => { - let _ = tx_read.send(Err(e.into())).await; - } - } - } - }); - - // Spawn stdin writer task - let mut stdin = stdin; - tokio::spawn(async move { - while let Some(message) = rx_write.recv().await { - let json = serde_json::to_string(&message).expect("Failed to serialize message"); - if stdin - .write_all(format!("{}\n", json).as_bytes()) - .await - .is_err() - { - break; - } - } - }); - - Ok((rx_read, tx_write)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::time::Duration; - use tokio::time::timeout; - - #[tokio::test] - async fn test_process_termination() { - let transport = StdioTransport { - params: StdioServerParams { - command: "sleep".to_string(), - args: vec!["0.3".to_string()], - env: None, - }, - }; - let (mut rx, _tx) = transport.connect().await.unwrap(); - - // Try to receive a message - should get an error about process termination - match timeout(Duration::from_secs(1), rx.recv()).await { - Ok(Some(Err(e))) => { - assert!( - e.to_string().contains("Child process terminated normally"), - "Expected process termination error, got: {}", - e - ); - } - _ => panic!("Expected error, got a different message"), - } - } -} diff --git a/crates/mcp-client/src/transport.rs b/crates/mcp-client/src/transport.rs deleted file mode 100644 index 2ccca05e92c1..000000000000 --- a/crates/mcp-client/src/transport.rs +++ /dev/null @@ -1,14 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use mcp_core::protocol::JsonRpcMessage; -use tokio::sync::mpsc::{Receiver, Sender}; - -// Stream types for consistent interface -pub type ReadStream = Receiver>; -pub type WriteStream = Sender; - -// Common trait for transport implementations -#[async_trait] -pub trait Transport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)>; -} diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs new file mode 100644 index 000000000000..aa4030690dd9 --- /dev/null +++ b/crates/mcp-client/src/transport/mod.rs @@ -0,0 +1,114 @@ +use async_trait::async_trait; +use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification, JsonRpcRequest}; +use thiserror::Error; +use tokio::sync::{mpsc, oneshot}; + +/// A generic error type for transport operations. +#[derive(Debug, Error)] +pub enum Error { + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + #[error("Transport was not connected or is already closed")] + NotConnected, + + #[error("Invalid URL provided")] + InvalidUrl, + + #[error("Connection timeout")] + Timeout, + + #[error("Failed to send message")] + SendFailed, + + #[error("Channel closed")] + ChannelClosed, + + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + #[error("HTTP error: {status} - {message}")] + HttpError { status: u16, message: String }, + + #[error("SSE connection error: {0}")] + SseConnection(String), + + #[error("Connection closed by server")] + ConnectionClosed, + + #[error("Unexpected transport error: {0}")] + Other(String), +} + +/// A message that can be sent through the transport +#[derive(Debug)] +pub struct TransportMessage { + /// The JSON-RPC message to send + pub message: JsonRpcMessage, + /// Channel to receive the response on (None for notifications) + pub response_tx: Option>>, +} + +/// A generic asynchronous transport trait with channel-based communication +#[async_trait] +pub trait Transport: Send + Sync + 'static { + /// Start the transport and establish the underlying connection. + /// Returns channels for sending messages and receiving errors. + async fn start(&self) -> Result, Error>; + + /// Close the transport and free any resources. + async fn close(&self) -> Result<(), Error>; +} + +pub mod stdio; +pub use stdio::StdioTransport; + +pub mod sse; +pub use sse::SseTransport; + +/// A router that handles message distribution for a transport +#[derive(Clone)] +pub struct MessageRouter { + transport_tx: mpsc::Sender, + // shutdown_tx is unused, but we'll probably need it for shutdown + #[allow(dead_code)] + shutdown_tx: mpsc::Sender<()>, +} + +impl MessageRouter { + pub fn new( + transport_tx: mpsc::Sender, + shutdown_tx: mpsc::Sender<()>, + ) -> Self { + Self { + transport_tx, + shutdown_tx, + } + } + + /// Send a message and wait for a response + pub async fn send_request(&self, request: JsonRpcRequest) -> Result { + let (response_tx, response_rx) = oneshot::channel(); + + self.transport_tx + .send(TransportMessage { + message: JsonRpcMessage::Request(request), + response_tx: Some(response_tx), + }) + .await + .map_err(|_| Error::ChannelClosed)?; + + response_rx.await.map_err(|_| Error::ChannelClosed)? + } + + /// Send a notification (no response expected) + pub async fn send_notification(&self, notification: JsonRpcNotification) -> Result<(), Error> { + self.transport_tx + .send(TransportMessage { + message: JsonRpcMessage::Notification(notification), + response_tx: None, + }) + .await + .map_err(|_| Error::ChannelClosed) + } +} diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs new file mode 100644 index 000000000000..6bfd857a4c88 --- /dev/null +++ b/crates/mcp-client/src/transport/sse.rs @@ -0,0 +1,214 @@ +use async_trait::async_trait; +use eventsource_client::{Client, SSE}; +use futures::TryStreamExt; +use reqwest::Client as HttpClient; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::task::JoinHandle; +use tracing::warn; + +use super::{Error, Transport, TransportMessage}; +use mcp_core::protocol::JsonRpcMessage; + +/// A transport implementation that uses Server-Sent Events (SSE) for receiving messages +/// and HTTP POST for sending messages. +pub struct SseTransport { + sse_url: String, + http_client: HttpClient, + post_endpoint: Arc>>, + sse_handle: Arc>>>, + pending_requests: Arc< + Mutex>>>, + >, +} + +impl SseTransport { + /// Create a new SSE transport with the given SSE endpoint URL + pub fn new>(sse_url: S) -> Self { + Self { + sse_url: sse_url.into(), + http_client: HttpClient::new(), + post_endpoint: Arc::new(Mutex::new(None)), + sse_handle: Arc::new(Mutex::new(None)), + pending_requests: Arc::new(Mutex::new(std::collections::HashMap::new())), + } + } + + async fn handle_message( + message: JsonRpcMessage, + pending_requests: Arc< + Mutex< + std::collections::HashMap>>, + >, + >, + ) { + if let JsonRpcMessage::Response(response) = &message { + if let Some(id) = &response.id { + if let Some(tx) = pending_requests.lock().await.remove(&id.to_string()) { + let _ = tx.send(Ok(message)); + } + } + } + } + + async fn process_messages( + mut message_rx: mpsc::Receiver, + http_client: HttpClient, + post_endpoint: Arc>>, + sse_url: String, + pending_requests: Arc< + Mutex< + std::collections::HashMap>>, + >, + >, + ) { + // Set up SSE client + let client = match eventsource_client::ClientBuilder::for_url(&sse_url) { + Ok(builder) => builder.build(), + Err(e) => { + // Properly handle initial connection error + let mut pending = pending_requests.lock().await; + for (_, tx) in pending.drain() { + let _ = tx.send(Err(Error::SseConnection(e.to_string()))); + } + return; + } + }; + + let mut stream = client.stream(); + + // First, wait for the endpoint event + while let Ok(Some(event)) = stream.try_next().await { + match event { + SSE::Event(event) if event.event_type == "endpoint" => { + let base_url = sse_url.trim_end_matches('/').trim_end_matches("sse"); + let endpoint_path = event.data.trim_start_matches('/'); + let post_url = format!("{}{}", base_url, endpoint_path); + println!("Endpoint for POST requests: {}", post_url); + *post_endpoint.lock().await = Some(post_url); + break; + } + _ => continue, + } + } + + // Now handle all subsequent messages + let message_handler = tokio::spawn({ + let pending_requests = pending_requests.clone(); + async move { + while let Ok(Some(event)) = stream.try_next().await { + match event { + SSE::Event(event) if event.event_type == "message" => { + if let Ok(message) = serde_json::from_str::(&event.data) + { + Self::handle_message(message, pending_requests.clone()).await; + } + } + _ => continue, + } + } + } + }); + + // Process outgoing messages + while let Some(transport_msg) = message_rx.recv().await { + let post_url = match post_endpoint.lock().await.as_ref() { + Some(url) => url.clone(), + None => { + if let Some(response_tx) = transport_msg.response_tx { + let _ = response_tx.send(Err(Error::NotConnected)); + } + continue; + } + }; + + // Serialize message first + let message_str = match serde_json::to_string(&transport_msg.message) { + Ok(s) => s, + Err(e) => { + if let Some(response_tx) = transport_msg.response_tx { + let _ = response_tx.send(Err(Error::Serialization(e))); + } + continue; + } + }; + + // Store response channel if this is a request + if let Some(response_tx) = transport_msg.response_tx { + if let JsonRpcMessage::Request(request) = &transport_msg.message { + if let Some(id) = &request.id { + pending_requests + .lock() + .await + .insert(id.to_string(), response_tx); + } + } + } + + // Send message via HTTP POST + match http_client + .post(&post_url) + .header("Content-Type", "application/json") + .body(message_str) + .send() + .await + { + Ok(response) => { + if !response.status().is_success() { + let error = Error::HttpError { + status: response.status().as_u16(), + message: response.status().to_string(), + }; + // We don't handle the error directly as it will come through SSE, + // but we log it for debugging purposes + warn!("HTTP request failed with error: {}", error); + } + } + Err(e) => { + let error = Error::Other(format!("HTTP request failed: {}", e)); + // Transport errors will also be communicated through the SSE channel + warn!("HTTP request failed with error: {}", error); + } + } + } + + // Clean up + message_handler.abort(); + } +} + +#[async_trait] +impl Transport for SseTransport { + async fn start(&self) -> Result, Error> { + let (message_tx, message_rx) = mpsc::channel(32); + + let http_client = self.http_client.clone(); + let post_endpoint = self.post_endpoint.clone(); + let sse_url = self.sse_url.clone(); + let pending_requests = self.pending_requests.clone(); + + let handle = tokio::spawn(Self::process_messages( + message_rx, + http_client, + post_endpoint, + sse_url, + pending_requests, + )); + + *self.sse_handle.lock().await = Some(handle); + + Ok(message_tx) + } + + async fn close(&self) -> Result<(), Error> { + // Abort the SSE handler task + if let Some(handle) = self.sse_handle.lock().await.take() { + handle.abort(); + } + + // Clear any pending requests + self.pending_requests.lock().await.clear(); + + Ok(()) + } +} diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs new file mode 100644 index 000000000000..70db75a6d6db --- /dev/null +++ b/crates/mcp-client/src/transport/stdio.rs @@ -0,0 +1,201 @@ +use std::sync::Arc; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; + +use async_trait::async_trait; +use mcp_core::protocol::JsonRpcMessage; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::task::JoinHandle; + +use super::{Error, Transport, TransportMessage}; + +/// A `StdioTransport` uses a child process's stdin/stdout as a communication channel. +/// +/// It uses channels for message passing and handles responses asynchronously through a background task. +pub struct StdioTransport { + command: String, + args: Vec, + process: Arc>>, + reader_handle: Arc>>>, + pending_requests: Arc< + Mutex>>>, + >, +} + +impl StdioTransport { + /// Create a new `StdioTransport` configured to run the given command with arguments. + pub fn new>(command: S, args: Vec) -> Self { + Self { + command: command.into(), + args, + process: Arc::new(Mutex::new(None)), + reader_handle: Arc::new(Mutex::new(None)), + pending_requests: Arc::new(Mutex::new(std::collections::HashMap::new())), + } + } + + async fn spawn_process(&self) -> Result<(ChildStdin, ChildStdout), Error> { + let mut child = Command::new(&self.command) + .args(&self.args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child + .stdin + .take() + .ok_or(Error::Other("Failed to get stdin".into()))?; + let stdout = child + .stdout + .take() + .ok_or(Error::Other("Failed to get stdout".into()))?; + + *self.process.lock().await = Some(child); + + Ok((stdin, stdout)) + } + + async fn handle_message( + message: JsonRpcMessage, + pending_requests: Arc< + Mutex< + std::collections::HashMap>>, + >, + >, + ) { + if let JsonRpcMessage::Response(response) = &message { + if let Some(id) = &response.id { + if let Some(tx) = pending_requests.lock().await.remove(&id.to_string()) { + let _ = tx.send(Ok(message)); + } + } + } + } + + async fn process_messages( + mut message_rx: mpsc::Receiver, + mut stdin: ChildStdin, + stdout: ChildStdout, + pending_requests: Arc< + Mutex< + std::collections::HashMap>>, + >, + >, + ) { + // Set up async reader for stdout + let mut reader = BufReader::new(stdout); + + // Spawn stdout reader task + let pending_clone = pending_requests.clone(); + let reader_handle = tokio::spawn(async move { + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, // EOF + Ok(_) => { + if let Ok(message) = serde_json::from_str::(&line) { + Self::handle_message(message, pending_clone.clone()).await; + } + } + Err(e) => { + eprintln!("Error reading line: {}", e); + break; + } + } + } + }); + + // Process incoming messages + while let Some(transport_msg) = message_rx.recv().await { + let message_str = match serde_json::to_string(&transport_msg.message) { + Ok(s) => s, + Err(e) => { + if let Some(tx) = transport_msg.response_tx { + let _ = tx.send(Err(Error::Serialization(e))); + } + continue; + } + }; + + // Store response channel if this is a request + if let Some(response_tx) = transport_msg.response_tx { + if let JsonRpcMessage::Request(request) = &transport_msg.message { + if let Some(id) = &request.id { + pending_requests + .lock() + .await + .insert(id.to_string(), response_tx); + } + } + } + + // Write message to stdin + if let Err(_) = stdin + .write_all(format!("{}\n", message_str).as_bytes()) + .await + { + // Break with a specific error indicating write failure + let mut pending = pending_requests.lock().await; + for (_, tx) in pending.drain() { + let _ = tx.send(Err(Error::SendFailed)); + } + break; + } + if let Err(_) = stdin.flush().await { + // Break with a specific error indicating connection issues + let mut pending = pending_requests.lock().await; + for (_, tx) in pending.drain() { + let _ = tx.send(Err(Error::ConnectionClosed)); + } + break; + } + } + + // Clean up + reader_handle.abort(); + } +} + +#[async_trait] +impl Transport for StdioTransport { + async fn start(&self) -> Result, Error> { + let (stdin, stdout) = self.spawn_process().await?; + + let (message_tx, message_rx) = mpsc::channel(32); + + let pending_requests = self.pending_requests.clone(); + let handle = tokio::spawn(Self::process_messages( + message_rx, + stdin, + stdout, + pending_requests, + )); + + *self.reader_handle.lock().await = Some(handle); + + Ok(message_tx) + } + + async fn close(&self) -> Result<(), Error> { + // Kill the process + if let Some(mut process) = self.process.lock().await.take() { + let _ = process.kill().await; + } + + // Abort the reader task + if let Some(handle) = self.reader_handle.lock().await.take() { + handle.abort(); + let _ = handle.await; + } + + // Clear any pending requests + self.pending_requests.lock().await.clear(); + + Ok(()) + } +} + +// No Drop implementation needed - we'll handle cleanup in the TransportService diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index e178c41177c3..a98d8d79480f 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -47,6 +47,7 @@ pub enum JsonRpcMessage { Response(JsonRpcResponse), Notification(JsonRpcNotification), Error(JsonRpcError), + Nil, // used to respond to notifications } #[derive(Debug, Serialize, Deserialize)] @@ -54,7 +55,8 @@ struct JsonRpcRaw { jsonrpc: String, #[serde(skip_serializing_if = "Option::is_none")] id: Option, - method: String, + #[serde(skip_serializing_if = "Option::is_none")] + method: Option, #[serde(skip_serializing_if = "Option::is_none")] params: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -86,22 +88,34 @@ impl TryFrom for JsonRpcMessage { })); } - // If the method starts with "notifications/", it's a notification - if raw.method.starts_with("notifications/") { - return Ok(JsonRpcMessage::Notification(JsonRpcNotification { + // If we have a method, it's either a notification or request + if let Some(method) = raw.method { + if method.starts_with("notifications/") { + return Ok(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: raw.jsonrpc, + method, + params: raw.params, + })); + } + + return Ok(JsonRpcMessage::Request(JsonRpcRequest { jsonrpc: raw.jsonrpc, - method: raw.method, + id: raw.id, + method, params: raw.params, })); } - // Otherwise it's a request - Ok(JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: raw.jsonrpc, - id: raw.id, - method: raw.method, - params: raw.params, - })) + // If we have no method and no result/error, it's a nil response + if raw.id.is_none() && raw.result.is_none() && raw.error.is_none() { + return Ok(JsonRpcMessage::Nil); + } + + // If we get here, something is wrong with the message + Err(format!( + "Invalid JSON-RPC message format: id={:?}, method={:?}, result={:?}, error={:?}", + raw.id, raw.method, raw.result, raw.error + )) } } diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs index b054a98c59a4..4aa02130d0e7 100644 --- a/crates/mcp-server/src/lib.rs +++ b/crates/mcp-server/src/lib.rs @@ -191,8 +191,9 @@ where } JsonRpcMessage::Response(_) | JsonRpcMessage::Notification(_) + | JsonRpcMessage::Nil | JsonRpcMessage::Error(_) => { - // Ignore responses and notifications for now + // Ignore responses, notifications and nil messages for now continue; } }