From 5d55bc3486f977764a49b9f4f00be22d37b22b4e Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Mon, 16 Dec 2024 23:26:38 -0500 Subject: [PATCH 01/21] mcp-client using tower service trait --- crates/mcp-client/Cargo.toml | 3 + crates/mcp-client/src/client.rs | 163 +++++++ crates/mcp-client/src/lib.rs | 5 +- crates/mcp-client/src/main.rs | 117 ++--- crates/mcp-client/src/service.rs | 74 +++ crates/mcp-client/src/session.rs | 544 ----------------------- crates/mcp-client/src/sse_transport.rs | 229 ---------- crates/mcp-client/src/stdio_transport.rs | 198 --------- crates/mcp-client/src/transport.rs | 14 - crates/mcp-client/src/transport/mod.rs | 42 ++ crates/mcp-client/src/transport/stdio.rs | 94 ++++ crates/mcp-core/src/tool.rs | 1 + 12 files changed, 426 insertions(+), 1058 deletions(-) create mode 100644 crates/mcp-client/src/client.rs create mode 100644 crates/mcp-client/src/service.rs delete mode 100644 crates/mcp-client/src/session.rs delete mode 100644 crates/mcp-client/src/sse_transport.rs delete mode 100644 crates/mcp-client/src/stdio_transport.rs delete mode 100644 crates/mcp-client/src/transport.rs create mode 100644 crates/mcp-client/src/transport/mod.rs create mode 100644 crates/mcp-client/src/transport/stdio.rs diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 6de3d390332e..a666e6522b53 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -14,10 +14,13 @@ serde_json = "1.0" clap = { version = "4.5", features = ["derive"] } async-trait = "0.1.83" url = "2.5.4" +thiserror = "1.0" anyhow = "1.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tokio-retry = "0.3" +tower = { version = "0.4", features = ["timeout", "util"] } +tower-service = "0.3" [dev-dependencies] warp = "0.3" diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs new file mode 100644 index 000000000000..0cdab4391dd4 --- /dev/null +++ b/crates/mcp-client/src/client.rs @@ -0,0 +1,163 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use thiserror::Error; +use tower::ServiceExt; // for Service::ready() + +use mcp_core::protocol::{ + InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, + ListResourcesResult, ReadResourceResult, +}; + +/// Error type for MCP client operations. +#[derive(Debug, Error)] +pub enum Error { + #[error("Service error: {0}")] + Service(#[from] super::service::ServiceError), + + #[error("RPC error: code={code}, message={message}")] + RpcError { code: i32, message: String }, + + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + #[error("Unexpected response from server")] + UnexpectedResponse, + + #[error("Timeout or service not ready")] + NotReady, +} + +#[derive(Serialize, Deserialize)] +pub struct ClientInfo { + pub name: String, + pub version: String, +} + +#[derive(Serialize, Deserialize, Default)] +pub struct ClientCapabilities { + // Add fields as needed. For now, empty capabilities are fine. +} + +#[derive(Serialize, Deserialize)] +pub struct InitializeParams { + #[serde(rename = "protocolVersion")] + pub protocol_version: String, + pub capabilities: ClientCapabilities, + #[serde(rename = "clientInfo")] + pub client_info: ClientInfo, +} + +/// The MCP client that sends requests via the provided service. +pub struct McpClient { + service: S, + next_id: u64, +} + +impl McpClient +where + S: tower::Service< + JsonRpcRequest, + Response = JsonRpcMessage, + Error = super::service::ServiceError, + > + Send, + S::Future: Send, +{ + pub fn new(service: S) -> Self { + Self { + service, + next_id: 1, + } + } + + /// Send a JSON-RPC request and wait for a response. + async fn send_message(&mut self, method: &str, params: Value) -> Result + where + T: for<'de> Deserialize<'de>, + { + self.service.ready().await.map_err(|_| Error::NotReady)?; + + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: Some(self.next_id), + method: method.to_string(), + params: Some(params), + }; + + self.next_id += 1; + + let response_msg = self.service.call(request).await?; + + match response_msg { + JsonRpcMessage::Response(JsonRpcResponse { + id, result, error, .. + }) => { + // Verify id matches + if id != Some(self.next_id - 1) { + return Err(Error::UnexpectedResponse); + } + if let Some(err) = error { + Err(Error::RpcError { + code: err.code, + message: err.message, + }) + } else if let Some(r) = result { + Ok(serde_json::from_value(r)?) + } else { + Err(Error::UnexpectedResponse) + } + } + JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => { + if id != Some(self.next_id - 1) { + return Err(Error::UnexpectedResponse); + } + Err(Error::RpcError { + code: error.code, + message: error.message, + }) + } + _ => { + // Requests/notifications not expected as a response + Err(Error::UnexpectedResponse) + } + } + } + + // /// Send a JSON-RPC notification. + // pub async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { + // let notification = mcp_core::protocol::JsonRpcNotification { + // jsonrpc: "2.0".to_string(), + // method: method.to_string(), + // params: Some(params), + // }; + // let msg = serde_json::to_string(¬ification)?; + // let mut transport = self.transport.lock().await; + // transport.send(msg).await + // } + + /// Initialize the connection with the server. + pub async fn initialize( + &mut self, + info: ClientInfo, + capabilities: ClientCapabilities, + ) -> Result { + let params = InitializeParams { + protocol_version: "1.0.0".into(), + client_info: info, + capabilities, + }; + self.send_message("initialize", serde_json::to_value(params)?) + .await + } + + /// List available resources. + pub async fn list_resources(&mut self) -> Result { + self.send_message("resources/list", serde_json::json!({})) + .await + } + + /// Read a resource's content. + pub async fn read_resource(&mut self, uri: &str) -> Result { + let params = serde_json::json!({ "uri": uri }); + self.send_message("resources/read", params).await + } +} diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index 3172f2944f96..c2ab27500d91 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -1,4 +1,3 @@ -pub mod session; -pub mod sse_transport; -pub mod stdio_transport; +pub mod client; +pub mod service; pub mod transport; diff --git a/crates/mcp-client/src/main.rs b/crates/mcp-client/src/main.rs index 70495fffd04a..bd4490fd0543 100644 --- a/crates/mcp-client/src/main.rs +++ b/crates/mcp-client/src/main.rs @@ -1,24 +1,31 @@ -use anyhow::{anyhow, Result}; -use clap::Parser; -use mcp_client::{ - session::Session, - sse_transport::{SseTransport, SseTransportParams}, - stdio_transport::{StdioServerParams, StdioTransport}, - transport::Transport, -}; -use serde_json::json; +use anyhow::Result; +use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; +use mcp_client::{service::TransportService, transport::StdioTransport}; +use tower::ServiceBuilder; use tracing_subscriber::EnvFilter; -#[derive(Parser)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Mode to run in: "git" or "echo" - #[arg(short, long, default_value = "git")] - mode: String, -} +// use mcp_client::{ +// service::{ServiceError}, +// transport::{Error as TransportError}, +// }; +// use std::time::Duration; +// use tower::timeout::error::Elapsed; + +// fn convert_box_error(err: Box) -> ServiceError { +// if let Some(elapsed) = err.downcast_ref::() { +// ServiceError::Transport(TransportError::Io( +// std::io::Error::new( +// std::io::ErrorKind::TimedOut, +// format!("Timeout elapsed: {}", elapsed), +// ), +// )) +// } else { +// ServiceError::Other(err.to_string()) +// } +// } #[tokio::main] -async fn main() -> Result<()> { +async fn main() -> Result<(), ClientError> { // Initialize logging tracing_subscriber::fmt() .with_env_filter( @@ -28,64 +35,34 @@ async fn main() -> Result<()> { ) .init(); - let args = Args::parse(); - println!("Args - mode: {}", args.mode); - - // Create session based on mode - let transport: Box = match args.mode.as_str() { - "git" => Box::new(StdioTransport { - params: StdioServerParams { - command: "uvx".into(), - args: vec!["mcp-server-git".into()], - env: None, - }, - }), - "echo" => Box::new(SseTransport { - params: SseTransportParams { - url: "http://localhost:8000/sse".into(), - headers: None, - }, - }), - _ => return Err(anyhow!("Invalid mode. Use 'git' or 'echo'")), - }; + // Create the base transport + let transport = StdioTransport::new("uvx", ["mcp-server-git"]); - let (read_stream, write_stream) = transport.connect().await?; - let mut session = Session::new(read_stream, write_stream).await?; + // Build service with middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); - // Initialize the connection - let init_result = session.initialize().await?; - println!("Initialized: {:?}", init_result); + // Create client + let mut client = McpClient::new(service); - // List tools - let tools = session.list_tools().await?; - println!("Tools: {:?}", tools); - - if args.mode == "echo" { - // Call a tool (replace with actual tool name and arguments) - let call_result = session - .call_tool("echo_tool", Some(json!({"message": "Hello, world!"}))) - .await?; - println!("Call tool result: {:?}", call_result); - - // List available resources - let resources = session.list_resources().await?; - println!("Resources: {:?}", resources); + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}"); - // Read a resource (replace with actual URI) - if let Some(resource) = resources.resources.first() { - let read_result = session.read_resource(&resource.uri).await?; - println!("Read resource result: {:?}", read_result); - } - } else { - // Call a tool (replace with actual tool name and arguments) - let call_result = session - .call_tool("git_status", Some(json!({"repo_path": "."}))) - .await?; - println!("Call tool result: {:?}", call_result); - } + // List resources + let resources = client.list_resources().await?; + println!("Available resources: {resources:?}"); - session.shutdown().await?; - println!("Done!"); + // Read a resource + let content = client.read_resource("file:///example.txt".into()).await?; + println!("Content: {content:?}"); Ok(()) } diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs new file mode 100644 index 000000000000..f422bc9244d5 --- /dev/null +++ b/crates/mcp-client/src/service.rs @@ -0,0 +1,74 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::Mutex; +use tower::Service; + +use crate::transport::{Error as TransportError, Transport}; +use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; + +#[derive(Debug, thiserror::Error)] +pub enum ServiceError { + #[error("Transport error: {0}")] + Transport(#[from] TransportError), + + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + #[error("Other error: {0}")] + Other(String), + + #[error("Unexpected server response")] + UnexpectedResponse, +} + +/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcRequests and JsonRpcMessages. +pub struct TransportService { + transport: Arc>, + initialized: AtomicBool, +} + +impl TransportService { + pub fn new(transport: T) -> Self { + Self { + transport: Arc::new(Mutex::new(transport)), + initialized: AtomicBool::new(false), + } + } +} + +impl Service for TransportService { + type Response = JsonRpcMessage; + type Error = ServiceError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // Always ready. We do on-demand initialization in call(). + Poll::Ready(Ok(())) + } + + fn call(&mut self, request: JsonRpcRequest) -> Self::Future { + let transport = Arc::clone(&self.transport); + let started = self.initialized.load(Ordering::SeqCst); + + Box::pin(async move { + let mut transport = transport.lock().await; + + // Initialize (start) transport on the first call. + if !started { + transport.start().await?; + } + + // Serialize request to JSON line + let msg = serde_json::to_string(&request)?; + transport.send(msg).await?; + + let line = transport.receive().await?; + let response_msg: JsonRpcMessage = serde_json::from_str(&line)?; + + Ok(response_msg) + }) + } +} diff --git a/crates/mcp-client/src/session.rs b/crates/mcp-client/src/session.rs deleted file mode 100644 index 1946cea36549..000000000000 --- a/crates/mcp-client/src/session.rs +++ /dev/null @@ -1,544 +0,0 @@ -use crate::transport::{ReadStream, WriteStream}; -use anyhow::{anyhow, Context, Result}; -use mcp_core::protocol::*; -use serde::de::DeserializeOwned; -use serde_json::{json, Value}; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::sync::Mutex; -use tracing::debug; - -struct OutgoingMessage { - message: JsonRpcMessage, - response_tx: mpsc::Sender>>, -} - -pub struct Session { - request_tx: mpsc::Sender, - id_counter: AtomicU64, - shutdown_tx: mpsc::Sender<()>, - background_task: Arc>>>, - is_closed: Arc, -} - -impl Session { - pub async fn new(read_stream: ReadStream, write_stream: WriteStream) -> Result { - let (request_tx, mut request_rx) = mpsc::channel::(32); - let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - let is_closed = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let is_closed_clone = is_closed.clone(); - - // Spawn the background task - let background_task = Arc::new(Mutex::new(Some(tokio::spawn({ - async move { - let mut pending_requests: Vec<( - u64, - mpsc::Sender>>, - )> = Vec::new(); - let mut read_stream = read_stream; - let write_stream = write_stream; - - loop { - tokio::select! { - // Handle shutdown signal - Some(()) = shutdown_rx.recv() => { - // Notify all pending requests of shutdown - for (_, tx) in pending_requests { - let _ = tx.send(Err(anyhow!("Session shutdown"))).await; - } - break; - } - - // Handle outgoing messages - Some(outgoing) = request_rx.recv() => { - // If session is closed, reject new messages - if is_closed_clone.load(Ordering::SeqCst) { - let _ = outgoing.response_tx.send(Err(anyhow!("Session is closed"))).await; - continue; - } - - // Send the message - if let Err(e) = write_stream.send(outgoing.message.clone()).await { - debug!("Write error occurred: {}", e); - // let _ = outgoing.response_tx.send(Err(e.into())).await; - // On write error, mark session as closed - is_closed_clone.store(true, Ordering::SeqCst); - break; - } - - // For requests, store the response channel for later - if let JsonRpcMessage::Request(request) = outgoing.message { - if let Some(id) = request.id { - pending_requests.push((id, outgoing.response_tx)); - } - } else { - // For notifications, just confirm success - let _ = outgoing.response_tx.send(Ok(None)).await; - } - } - - // Handle incoming messages - Some(message_result) = read_stream.recv() => { - match message_result { - Ok(JsonRpcMessage::Response(response)) => { - if let Some(id) = response.id { - if let Some(pos) = pending_requests.iter().position(|(req_id, _)| *req_id == id) { - let (_, tx) = pending_requests.remove(pos); - let _ = tx.send(Ok(Some(response))).await; - } - } - } - Ok(JsonRpcMessage::Notification(_)) => { - // Handle incoming notifications if needed - } - Ok(_) => { - eprintln!("Unexpected message type"); - } - Err(e) => { - // On transport error, notify all pending requests and shutdown - eprintln!("Transport error: {}", e); - for (_, tx) in pending_requests { - let _ = tx.send(Err(anyhow!("{}", e))).await; - } - - // Mark session as closed - is_closed_clone.store(true, Ordering::SeqCst); - break; - } - } - } - } - } - } - })))); - - Ok(Self { - request_tx, - id_counter: AtomicU64::new(1), - shutdown_tx, - background_task, - is_closed, - }) - } - - pub async fn shutdown(&self) -> Result<()> { - // Mark session as closed - self.is_closed.store(true, Ordering::SeqCst); - - // Send shutdown signal - self.shutdown_tx - .send(()) - .await - .map_err(|e| anyhow!("Failed to shutdown session: {}", e))?; - - // Wait for background task to complete - if let Some(task) = self.background_task.lock().await.take() { - task.await - .map_err(|e| anyhow!("Background task failed: {}", e))?; - } - - Ok(()) - } - - async fn send_message(&self, message: JsonRpcMessage) -> Result> { - // Check if session is closed - if self.is_closed.load(Ordering::SeqCst) { - return Err(anyhow!("Session is closed")); - } - - let (response_tx, mut response_rx) = mpsc::channel(1); - - self.request_tx - .send(OutgoingMessage { - message, - response_tx, - }) - .await - .context("Failed to send message")?; - - response_rx - .recv() - .await - .context("Failed to receive response")? - } - - async fn rpc_call( - &self, - method: &str, - params: Option, - ) -> Result { - // Check if session is closed - if self.is_closed.load(Ordering::SeqCst) { - return Err(anyhow!("Session is closed")); - } - - let id = self.id_counter.fetch_add(1, Ordering::SeqCst); - let request = JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: Some(id), - method: method.to_string(), - params, - }; - - let response = self - .send_message(JsonRpcMessage::Request(request)) - .await? - .context("Expected response for request")?; - - match (response.error, response.result) { - (Some(error), _) => Err(anyhow!("RPC Error {}: {}", error.code, error.message)), - (_, Some(result)) => { - serde_json::from_value(result).context("Failed to deserialize result") - } - (None, None) => Err(anyhow!("No result in response")), - } - } - - async fn send_notification(&self, method: &str, params: Option) -> Result<()> { - // Check if session is closed - if self.is_closed.load(Ordering::SeqCst) { - return Err(anyhow!("Session is closed")); - } - - let notification = JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: method.to_string(), - params, - }; - - self.send_message(JsonRpcMessage::Notification(notification)) - .await?; - - Ok(()) - } - - pub async fn initialize(&mut self) -> Result { - let params = json!({ - "protocolVersion": "2024-11-05", - "capabilities": { - "sampling": null, - "experimental": null, - "roots": { - "listChanged": true - } - }, - "clientInfo": { - "name": "RustMCPClient", - "version": "0.1.0" - } - }); - - let result: InitializeResult = self.rpc_call("initialize", Some(params)).await?; - self.send_notification("notifications/initialized", None) - .await?; - Ok(result) - } - - pub async fn list_resources(&self) -> Result { - self.rpc_call("resources/list", Some(json!({}))).await - } - - pub async fn read_resource(&self, uri: &str) -> Result { - self.rpc_call("resources/read", Some(json!({ "uri": uri }))) - .await - } - - pub async fn list_tools(&self) -> Result { - self.rpc_call("tools/list", Some(json!({}))).await - } - - pub async fn call_tool(&self, name: &str, arguments: Option) -> Result { - self.rpc_call( - "tools/call", - Some(json!({ - "name": name, - "arguments": arguments.unwrap_or_else(|| json!({})), - })), - ) - .await - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::transport::{ReadStream, Transport, WriteStream}; - use anyhow::{anyhow, Result}; - use async_trait::async_trait; - use std::sync::atomic::Ordering; - use std::time::Duration; - use tokio::sync::mpsc; - use tokio::time::timeout; - - // Mock transport that simulates errors - struct MockTransport { - error_mode: ErrorMode, - } - - #[derive(Clone)] - enum ErrorMode { - ReadError, - WriteError, - ProcessTermination, - Nil, - } - - #[async_trait] - impl Transport for MockTransport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)> { - let (tx_read, rx_read) = mpsc::channel(100); - let (tx_write, mut rx_write) = mpsc::channel(100); - - let error_mode = self.error_mode.clone(); - - tokio::spawn(async move { - // For WriteError, don't wait for any writes, just drop the receiver to force an immediate failure. - // This ensures that the first attempt to send by the Session fails. - match error_mode { - ErrorMode::ReadError => { - // Wait a bit for the request to be sent and then send the error - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - let _ = tx_read.send(Err(anyhow!("Simulated read error"))).await; - } - ErrorMode::WriteError => { - // Immediately drop the rx_write side - drop(rx_write); - } - ErrorMode::ProcessTermination => { - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - let _ = tx_read.send(Err(anyhow!("Child process terminated"))).await; - } - ErrorMode::Nil => { - // Test with initialize and then list_resources - while let Some(message) = rx_write.recv().await { - match message { - JsonRpcMessage::Request(req) => { - // Send a successful response for initialization or other calls - if req.method == "initialize" { - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: req.id, - result: Some(json!({ - "protocolVersion": "2024-11-05", - "capabilities": { "resources": { "listChanged": false } }, - "serverInfo": { "name": "MockServer", "version": "1.0.0" } - })), - error: None, - }); - let _ = tx_read.send(Ok(response)).await; - } else if req.method == "resources/list" { - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: req.id, - result: Some( - json!({ "resources": [{ "uri": "file://res1", "name": "res1" }, { "uri": "file://res2", "name": "res2" }] }), - ), - error: None, - }); - let _ = tx_read.send(Ok(response)).await; - } else { - // Default success for other calls - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: req.id, - result: Some(json!({ "ok": true })), - error: None, - }); - let _ = tx_read.send(Ok(response)).await; - } - } - JsonRpcMessage::Notification(_notif) => { - // For notifications, no response is required. - } - _ => {} - } - } - } - } - }); - - Ok((rx_read, tx_write)) - } - } - - #[tokio::test] - async fn test_session_can_initialize_and_list_resources() -> Result<()> { - let transport = MockTransport { - error_mode: ErrorMode::Nil, - }; - - let (read_stream, write_stream) = transport.connect().await?; - let mut session = Session::new(read_stream, write_stream).await?; - - // Initialize the session - let init_result = session.initialize().await?; - assert_eq!(init_result.protocol_version, "2024-11-05"); - assert_eq!( - init_result.capabilities.resources.unwrap().list_changed, - Some(false) - ); - - // Now list resources - let list_result = session.list_resources().await?; - assert_eq!( - list_result - .resources - .iter() - .map(|r| &r.name) - .collect::>(), - vec!["res1", "res2"] - ); - - // Make another call - just to verify multiple calls work fine - let _: serde_json::Value = session.rpc_call("someMethod", Some(json!({}))).await?; - Ok(()) - } - - #[tokio::test] - async fn test_read_error_terminates_session() { - let transport = MockTransport { - error_mode: ErrorMode::ReadError, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // // Introduce a brief delay to ensure the request is fully sent and pending before the error occurs - // tokio::time::sleep(std::time::Duration::from_millis(20)).await; - - // Try to make an RPC call - should fail due to transport error - let result = session.list_resources().await; - assert!(result.is_err()); - - // Print the actual error message for debugging - let err = result.unwrap_err(); - println!("Actual error: {}", err); - assert!(err.to_string().contains("Simulated read error")); - - // Verify session is marked as closed - assert!( - session.is_closed.load(Ordering::SeqCst), - "Session did not close after error" - ); - - // Subsequent calls should fail immediately - let result = session.list_tools().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - } - - #[tokio::test] - async fn test_write_error_terminates_session() { - let transport = MockTransport { - error_mode: ErrorMode::WriteError, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // Try to make an RPC call - should fail due to transport error - let result = session.list_resources().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Failed to receive response")); - - // Verify session is marked as closed - assert!(session.is_closed.load(Ordering::SeqCst)); - println!("First call made"); - - // Subsequent calls should fail immediately - let result = session.list_tools().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - } - - #[tokio::test] - async fn test_process_termination_terminates_session() { - let transport = MockTransport { - error_mode: ErrorMode::ProcessTermination, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // Try to make an RPC call - should fail due to process termination - let result = session.list_resources().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Child process terminated")); - - // Verify session is marked as closed - assert!(session.is_closed.load(Ordering::SeqCst)); - - // Subsequent calls should fail immediately - let result = session.list_tools().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - } - - #[tokio::test] - async fn test_session_cleanup_on_drop() { - let transport = MockTransport { - error_mode: ErrorMode::ProcessTermination, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // Get a clone of the background task handle - let background_task = session.background_task.clone(); - - // Drop the session - drop(session); - - // Verify that the background task completes - let timeout_result = timeout(Duration::from_secs(1), async { - if let Some(task) = background_task.lock().await.take() { - task.await.unwrap(); - } - }) - .await; - - assert!(timeout_result.is_ok(), "Background task did not complete"); - } - - #[tokio::test] - async fn test_explicit_shutdown() -> Result<()> { - let transport = MockTransport { - error_mode: ErrorMode::Nil, - }; - - let (read_stream, write_stream) = transport.connect().await?; - let session = Session::new(read_stream, write_stream).await?; - - // Verify we can make calls before shutdown - let _: serde_json::Value = session.rpc_call("someMethod", Some(json!({}))).await?; - - // Shutdown the session - session.shutdown().await?; - - // Verify calls fail after shutdown - let result = session.list_resources().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - - Ok(()) - } -} diff --git a/crates/mcp-client/src/sse_transport.rs b/crates/mcp-client/src/sse_transport.rs deleted file mode 100644 index bc5ea885224b..000000000000 --- a/crates/mcp-client/src/sse_transport.rs +++ /dev/null @@ -1,229 +0,0 @@ -use crate::transport::{ReadStream, Transport, WriteStream}; -use anyhow::{anyhow, Context, Result}; -use async_trait::async_trait; -use futures_util::StreamExt; -use mcp_core::protocol::JsonRpcMessage; -use reqwest::{Client, Url}; -use reqwest_eventsource::{Event, EventSource}; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; -use tokio_retry::{ - strategy::{jitter, ExponentialBackoff}, - Retry, -}; -use tracing::{debug, error, info, warn}; - -pub struct SseTransportParams { - pub url: String, - pub headers: Option, -} - -pub struct SseTransport { - pub params: SseTransportParams, -} - -// Helper function to send a POST request with retry logic -async fn send_with_retry( - client: &Client, - endpoint: &str, - json: serde_json::Value, -) -> Result { - // Create retry strategy with exponential backoff - let retry_strategy = ExponentialBackoff::from_millis(100) // Start with 100ms - .factor(2) // Double the delay each time - .map(jitter) // Add randomness to prevent thundering herd - .take(3); // Maximum of 3 retries (4 attempts total) - - Retry::spawn(retry_strategy, || async { - let response = client.post(endpoint).json(&json).send().await?; - - // If we get a 5xx error or specific connection errors, we should retry - if response.status().is_server_error() - || matches!(response.error_for_status_ref(), Err(e) if e.is_connect()) - { - return Err(anyhow!("Server error: {}", response.status())); - } - - Ok(response) - }) - .await -} - -#[async_trait] -impl Transport for SseTransport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)> { - info!("Connecting to SSE endpoint: {}", self.params.url); - let (tx_read, rx_read) = mpsc::channel(100); - let (tx_write, mut rx_write) = mpsc::channel(100); - - let client = Client::new(); - let base_url = Url::parse(&self.params.url).context("Failed to parse SSE URL")?; - - // Create the event source request - let mut request_builder = client.get(base_url.clone()); - if let Some(headers) = &self.params.headers { - request_builder = headers - .iter() - .fold(request_builder, |req, (key, value)| req.header(key, value)); - } - - let event_source = EventSource::new(request_builder)?; - let client_for_post = client.clone(); - - // Shared state for the endpoint URL - let endpoint_url = Arc::new(Mutex::new(None::)); - let endpoint_url_reader = endpoint_url.clone(); - - // Spawn the SSE reader task - tokio::spawn({ - let tx_read = tx_read.clone(); - let base_url = base_url.clone(); - async move { - info!("Starting SSE reader task"); - let mut stream = event_source; - let mut got_endpoint = false; - - while let Some(event) = stream.next().await { - match event { - Ok(Event::Open) => { - info!("SSE connection opened"); - } - Ok(Event::Message(message)) => { - debug!("Received SSE event: {} - {}", message.event, message.data); - match message.event.as_str() { - "endpoint" => { - // Handle endpoint event - let endpoint = message.data; - info!("Received endpoint URL: {}", endpoint); - - // Join with base URL if relative - let endpoint_url_full = if endpoint.starts_with('/') { - match base_url.join(&endpoint) { - Ok(url) => url, - Err(e) => { - error!("Failed to join endpoint URL: {}", e); - let _ = tx_read.send(Err(e.into())).await; - break; - } - } - } else { - match Url::parse(&endpoint) { - Ok(url) => url, - Err(e) => { - error!("Failed to parse endpoint URL: {}", e); - let _ = tx_read.send(Err(e.into())).await; - break; - } - } - }; - - // Validate endpoint URL has same origin (scheme and host) - if base_url.scheme() != endpoint_url_full.scheme() - || base_url.host_str() != endpoint_url_full.host_str() - || base_url.port() != endpoint_url_full.port() - { - let error = format!( - "Endpoint origin does not match connection origin: {}", - endpoint_url_full - ); - error!("{}", error); - let _ = tx_read.send(Err(anyhow!(error))).await; - break; - } - - let endpoint_str = endpoint_url_full.to_string(); - info!("Using full endpoint URL: {}", endpoint_str); - let mut endpoint_guard = endpoint_url.lock().await; - *endpoint_guard = Some(endpoint_str); - got_endpoint = true; - debug!("Endpoint URL set successfully"); - } - "message" => { - if !got_endpoint { - warn!("Received message before endpoint URL"); - continue; - } - // Handle message event - match serde_json::from_str::(&message.data) { - Ok(msg) => { - debug!("Received server message: {:?}", msg); - if tx_read.send(Ok(msg)).await.is_err() { - error!("Failed to send message to read channel"); - break; - } - } - Err(e) => { - error!("Error parsing server message: {}", e); - if tx_read.send(Err(e.into())).await.is_err() { - error!("Failed to send error to read channel"); - break; - } - } - } - } - _ => { - debug!("Ignoring unknown event type: {}", message.event); - } - } - } - Err(e) => { - error!("SSE error: {}", e); - let _ = tx_read.send(Err(e.into())).await; - break; - } - } - } - info!("SSE reader task ended"); - } - }); - - // Spawn the writer task - tokio::spawn(async move { - info!("Starting writer task"); - // Wait for the endpoint URL before processing messages - let mut endpoint = None; - while endpoint.is_none() { - let guard = endpoint_url_reader.lock().await; - if let Some(url) = guard.as_ref() { - endpoint = Some(url.clone()); - break; - } - drop(guard); - debug!("Waiting for endpoint URL..."); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - - let endpoint = endpoint.unwrap(); - info!("Starting post writer with endpoint URL: {}", endpoint); - - while let Some(message) = rx_write.recv().await { - match serde_json::to_value(&message) { - Ok(json) => { - debug!("Sending client message: {:?}", json); - match send_with_retry(&client_for_post, &endpoint, json).await { - Ok(response) => { - if !response.status().is_success() { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - error!("Server returned error status {}: {}", status, text); - } else { - debug!("Message sent successfully: {}", response.status()); - } - } - Err(e) => { - error!("Failed to send message after retries: {}", e); - } - } - } - Err(e) => { - error!("Failed to serialize message: {}", e); - } - } - } - info!("Writer task ended"); - }); - - info!("SSE transport connected"); - Ok((rx_read, tx_write)) - } -} diff --git a/crates/mcp-client/src/stdio_transport.rs b/crates/mcp-client/src/stdio_transport.rs deleted file mode 100644 index 67f0fbf12053..000000000000 --- a/crates/mcp-client/src/stdio_transport.rs +++ /dev/null @@ -1,198 +0,0 @@ -use crate::transport::{ReadStream, Transport, WriteStream}; -use anyhow::{anyhow, Context, Result}; -use async_trait::async_trait; -use mcp_core::protocol::*; -use std::process::Stdio; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::process::{Child, Command}; -use tokio::sync::mpsc; - -pub struct StdioServerParams { - pub command: String, - pub args: Vec, - pub env: Option>, -} - -pub struct StdioTransport { - pub params: StdioServerParams, -} - -impl StdioTransport { - fn get_default_environment() -> std::collections::HashMap { - let default_vars = if cfg!(windows) { - vec!["APPDATA", "PATH", "TEMP", "USERNAME"] // Simplified list - } else { - vec!["HOME", "PATH", "SHELL", "USER"] // Simplified list - }; - - std::env::vars() - .filter(|(key, value)| default_vars.contains(&key.as_str()) && !value.starts_with("()")) - .collect() - } - - async fn monitor_child(mut child: Child, tx_read: mpsc::Sender>) { - match child.wait().await { - Ok(status) => { - let msg = if status.success() { - format!("Child process terminated normally with status: {}", status) - } else { - format!("Child process terminated with error status: {}", status) - }; - let _ = tx_read.send(Err(anyhow!(msg))).await; - } - Err(e) => { - let _ = tx_read - .send(Err(anyhow!("Child process error: {}", e))) - .await; - } - } - } -} - -#[async_trait] -impl Transport for StdioTransport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)> { - let mut child = Command::new(&self.params.command) - .args(&self.params.args) - .env_clear() - .envs( - self.params - .env - .clone() - .unwrap_or_else(Self::get_default_environment), - ) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::inherit()) - .spawn() - .context("Failed to spawn child process")?; - - let stdin = child.stdin.take().context("Failed to get stdin handle")?; - let stdout = child.stdout.take().context("Failed to get stdout handle")?; - - let (tx_read, rx_read) = mpsc::channel(100); - let (tx_write, mut rx_write) = mpsc::channel(100); - - // Clone tx_read for the child monitor - let tx_read_monitor = tx_read.clone(); - - // Spawn child process monitor - tokio::spawn(Self::monitor_child(child, tx_read_monitor)); - - // Spawn stdout reader task - let stdout_reader = BufReader::new(stdout); - tokio::spawn(async move { - let mut lines = stdout_reader.lines(); - while let Ok(Some(line)) = lines.next_line().await { - match serde_json::from_str::(&line) { - Ok(msg) => { - if tx_read.send(Ok(msg)).await.is_err() { - break; - } - } - Err(e) => { - let _ = tx_read.send(Err(e.into())).await; - } - } - } - }); - - // Spawn stdin writer task - let mut stdin = stdin; - tokio::spawn(async move { - while let Some(message) = rx_write.recv().await { - let json = serde_json::to_string(&message).expect("Failed to serialize message"); - if stdin - .write_all(format!("{}\n", json).as_bytes()) - .await - .is_err() - { - break; - } - } - }); - - Ok((rx_read, tx_write)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - use std::time::Duration; - use tokio::time::timeout; - - #[tokio::test] - async fn test_stdio_transport() { - let transport = StdioTransport { - params: StdioServerParams { - command: "tee".to_string(), // tee will echo back what it receives - args: vec![], - env: None, - }, - }; - - let (mut rx, tx) = transport.connect().await.unwrap(); - - // Create test messages - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: Some(1), - method: "ping".to_string(), - params: None, - }); - - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: Some(2), - result: Some(json!({})), - error: None, - }); - - // Send messages - tx.send(request.clone()).await.unwrap(); - tx.send(response.clone()).await.unwrap(); - - // Receive and verify messages - let mut read_messages = Vec::new(); - - // Use timeout to avoid hanging if messages aren't received - for _ in 0..2 { - match timeout(Duration::from_secs(1), rx.recv()).await { - Ok(Some(Ok(msg))) => read_messages.push(msg), - Ok(Some(Err(e))) => panic!("Received error: {}", e), - Ok(None) => break, - Err(_) => panic!("Timeout waiting for message"), - } - } - - assert_eq!(read_messages.len(), 2, "Expected 2 messages"); - assert_eq!(read_messages[0], request); - assert_eq!(read_messages[1], response); - } - - #[tokio::test] - async fn test_process_termination() { - let transport = StdioTransport { - params: StdioServerParams { - command: "sleep".to_string(), - args: vec!["0.3".to_string()], - env: None, - }, - }; - let (mut rx, _tx) = transport.connect().await.unwrap(); - - // Try to receive a message - should get an error about process termination - match timeout(Duration::from_secs(1), rx.recv()).await { - Ok(Some(Err(e))) => { - assert!( - e.to_string().contains("Child process terminated normally"), - "Expected process termination error, got: {}", - e - ); - } - _ => panic!("Expected error, got a different message"), - } - } -} diff --git a/crates/mcp-client/src/transport.rs b/crates/mcp-client/src/transport.rs deleted file mode 100644 index 2ccca05e92c1..000000000000 --- a/crates/mcp-client/src/transport.rs +++ /dev/null @@ -1,14 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use mcp_core::protocol::JsonRpcMessage; -use tokio::sync::mpsc::{Receiver, Sender}; - -// Stream types for consistent interface -pub type ReadStream = Receiver>; -pub type WriteStream = Sender; - -// Common trait for transport implementations -#[async_trait] -pub trait Transport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)>; -} diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs new file mode 100644 index 000000000000..55e17d10a3fc --- /dev/null +++ b/crates/mcp-client/src/transport/mod.rs @@ -0,0 +1,42 @@ +use async_trait::async_trait; +use thiserror::Error; + +/// A generic error type for transport operations. +#[derive(Debug, Error)] +pub enum Error { + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + #[error("Transport was not connected or is already closed")] + NotConnected, + + #[error("Unexpected transport error: {0}")] + Other(String), +} + +/// A generic asynchronous transport trait. +/// +/// Implementations are expected to handle: +/// - starting the underlying communication channel (e.g., launching a child process, connecting a socket) +/// - sending JSON-RPC messages as strings +/// - receiving JSON-RPC messages as strings +/// - closing the transport cleanly +#[async_trait] +pub trait Transport: Send + 'static { + /// Start the transport and establish the underlying connection. + async fn start(&mut self) -> Result<(), Error>; + + /// Send a raw JSON-encoded message through the transport. + async fn send(&mut self, msg: String) -> Result<(), Error>; + + /// Receive a raw JSON-encoded message from the transport. + /// + /// This should return a single line representing one JSON message. + async fn receive(&mut self) -> Result; + + /// Close the transport and free any resources. + async fn close(&mut self) -> Result<(), Error>; +} + +pub mod stdio; +pub use stdio::StdioTransport; diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs new file mode 100644 index 000000000000..cb223b4f66c2 --- /dev/null +++ b/crates/mcp-client/src/transport/stdio.rs @@ -0,0 +1,94 @@ +use super::{Error, Transport}; +use async_trait::async_trait; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; + +/// A `StdioTransport` uses a child process’s stdin/stdout as a communication channel. +/// +/// It starts the specified command with arguments and uses its stdin/stdout to send/receive +/// JSON-RPC messages line by line. This is useful for running MCP servers as subprocesses. +pub struct StdioTransport { + command: String, + args: Vec, + child: Option, + stdin: Option, + stdout: Option>, +} + +impl StdioTransport { + /// Create a new `StdioTransport` configured to run the given command with arguments. + /// + /// The transport will not start until `start()` is called. + pub fn new(command: S, args: I) -> Self + where + S: Into, + I: IntoIterator, + { + Self { + command: command.into(), + args: args.into_iter().map(Into::into).collect(), + child: None, + stdin: None, + stdout: None, + } + } +} + +#[async_trait] +impl Transport for StdioTransport { + async fn start(&mut self) -> Result<(), Error> { + if self.child.is_some() { + return Ok(()); // Already started + } + + let mut cmd = Command::new(&self.command); + cmd.args(&self.args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()); + + let mut child = cmd.spawn()?; + + let stdin = child.stdin.take().ok_or(Error::NotConnected)?; + let stdout = child.stdout.take().ok_or(Error::NotConnected)?; + + self.stdin = Some(stdin); + self.stdout = Some(BufReader::new(stdout)); + self.child = Some(child); + + Ok(()) + } + + async fn send(&mut self, msg: String) -> Result<(), Error> { + let stdin = self.stdin.as_mut().ok_or(Error::NotConnected)?; + // Write the message followed by a newline + stdin.write_all(msg.as_bytes()).await?; + stdin.write_all(b"\n").await?; + stdin.flush().await?; + Ok(()) + } + + async fn receive(&mut self) -> Result { + let stdout = self.stdout.as_mut().ok_or(Error::NotConnected)?; + let mut line = String::new(); + let n = stdout.read_line(&mut line).await?; + if n == 0 { + // End of stream + return Err(Error::NotConnected); + } + Ok(line) + } + + async fn close(&mut self) -> Result<(), Error> { + // Drop stdin to signal EOF + self.stdin.take(); + self.stdout.take(); + + if let Some(mut child) = self.child.take() { + // Wait for child to exit + let _status = child.wait().await?; + } + + Ok(()) + } +} diff --git a/crates/mcp-core/src/tool.rs b/crates/mcp-core/src/tool.rs index 6401b9632671..adb99ce12fca 100644 --- a/crates/mcp-core/src/tool.rs +++ b/crates/mcp-core/src/tool.rs @@ -5,6 +5,7 @@ use serde_json::Value; /// A tool that can be used by a model. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct Tool { /// The name of the tool pub name: String, From 9ac0f5ee42f7c022edccde4adfb1f2f03d5f1c38 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 00:15:22 -0500 Subject: [PATCH 02/21] working: send initialized notification during initialization --- crates/mcp-client/src/client.rs | 60 ++++++++++++++++-------- crates/mcp-client/src/main.rs | 42 ++++------------- crates/mcp-client/src/service.rs | 7 ++- crates/mcp-client/src/transport/mod.rs | 33 +++++++++++-- crates/mcp-client/src/transport/stdio.rs | 47 +++++++++++-------- crates/mcp-core/src/protocol.rs | 1 + 6 files changed, 114 insertions(+), 76 deletions(-) diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 0cdab4391dd4..3cb33277c3a6 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -5,8 +5,12 @@ use tower::ServiceExt; // for Service::ready() use mcp_core::protocol::{ InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, - ListResourcesResult, ReadResourceResult, + ListResourcesResult, ListToolsResult, ReadResourceResult, }; +use std::sync::Arc; +use tokio::sync::Mutex; + +use crate::transport::{Error as TransportError, Transport}; /// Error type for MCP client operations. #[derive(Debug, Error)] @@ -48,12 +52,13 @@ pub struct InitializeParams { } /// The MCP client that sends requests via the provided service. -pub struct McpClient { +pub struct McpClient { service: S, + transport: Arc>, next_id: u64, } -impl McpClient +impl McpClient where S: tower::Service< JsonRpcRequest, @@ -61,18 +66,20 @@ where Error = super::service::ServiceError, > + Send, S::Future: Send, + T: Transport, { - pub fn new(service: S) -> Self { + pub fn new(service: S, transport: Arc>) -> Self { Self { service, + transport, next_id: 1, } } /// Send a JSON-RPC request and wait for a response. - async fn send_message(&mut self, method: &str, params: Value) -> Result + async fn send_message(&mut self, method: &str, params: Value) -> Result where - T: for<'de> Deserialize<'de>, + R: for<'de> Deserialize<'de>, { self.service.ready().await.map_err(|_| Error::NotReady)?; @@ -122,17 +129,21 @@ where } } - // /// Send a JSON-RPC notification. - // pub async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { - // let notification = mcp_core::protocol::JsonRpcNotification { - // jsonrpc: "2.0".to_string(), - // method: method.to_string(), - // params: Some(params), - // }; - // let msg = serde_json::to_string(¬ification)?; - // let mut transport = self.transport.lock().await; - // transport.send(msg).await - // } + /// Send a JSON-RPC notification. + pub async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { + let notification = mcp_core::protocol::JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: method.to_string(), + params: Some(params), + }; + let msg = serde_json::to_string(¬ification)?; + let transport = self.transport.lock().await; + // transport.send(msg).await + transport + .send(msg) + .await + .map_err(|e: TransportError| Error::Service(e.into())) + } /// Initialize the connection with the server. pub async fn initialize( @@ -145,8 +156,14 @@ where client_info: info, capabilities, }; - self.send_message("initialize", serde_json::to_value(params)?) - .await + let result: InitializeResult = self + .send_message("initialize", serde_json::to_value(params)?) + .await?; + + self.send_notification("notifications/initialized", serde_json::json!({})) + .await?; + + Ok(result) } /// List available resources. @@ -160,4 +177,9 @@ where let params = serde_json::json!({ "uri": uri }); self.send_message("resources/read", params).await } + + /// List tools + pub async fn list_tools(&mut self) -> Result { + self.send_message("tools/list", serde_json::json!({})).await + } } diff --git a/crates/mcp-client/src/main.rs b/crates/mcp-client/src/main.rs index bd4490fd0543..cdf86fe9b3c4 100644 --- a/crates/mcp-client/src/main.rs +++ b/crates/mcp-client/src/main.rs @@ -1,29 +1,11 @@ use anyhow::Result; use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; use mcp_client::{service::TransportService, transport::StdioTransport}; +use std::sync::Arc; +use tokio::sync::Mutex; use tower::ServiceBuilder; use tracing_subscriber::EnvFilter; -// use mcp_client::{ -// service::{ServiceError}, -// transport::{Error as TransportError}, -// }; -// use std::time::Duration; -// use tower::timeout::error::Elapsed; - -// fn convert_box_error(err: Box) -> ServiceError { -// if let Some(elapsed) = err.downcast_ref::() { -// ServiceError::Transport(TransportError::Io( -// std::io::Error::new( -// std::io::ErrorKind::TimedOut, -// format!("Timeout elapsed: {}", elapsed), -// ), -// )) -// } else { -// ServiceError::Other(err.to_string()) -// } -// } - #[tokio::main] async fn main() -> Result<(), ClientError> { // Initialize logging @@ -35,14 +17,14 @@ async fn main() -> Result<(), ClientError> { ) .init(); - // Create the base transport - let transport = StdioTransport::new("uvx", ["mcp-server-git"]); + // Create the base transport as Arc> + let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); // Build service with middleware - let service = ServiceBuilder::new().service(TransportService::new(transport)); + let service = ServiceBuilder::new().service(TransportService::new(Arc::clone(&transport))); // Create client - let mut client = McpClient::new(service); + let mut client = McpClient::new(service, Arc::clone(&transport)); // Initialize let server_info = client @@ -54,15 +36,11 @@ async fn main() -> Result<(), ClientError> { ClientCapabilities::default(), ) .await?; - println!("Connected to server: {server_info:?}"); - - // List resources - let resources = client.list_resources().await?; - println!("Available resources: {resources:?}"); + println!("Connected to server: {server_info:?}\n"); - // Read a resource - let content = client.read_resource("file:///example.txt".into()).await?; - println!("Content: {content:?}"); + // List tools + let tools = client.list_tools().await?; + println!("Available tools: {tools:?}"); Ok(()) } diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index f422bc9244d5..6748f896abdb 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -37,6 +37,11 @@ impl TransportService { initialized: AtomicBool::new(false), } } + + /// Provides a clone of the transport handle for external access (e.g., for sending notifications). + pub fn get_transport_handle(&self) -> Arc> { + Arc::clone(&self.transport) + } } impl Service for TransportService { @@ -54,7 +59,7 @@ impl Service for TransportService { let started = self.initialized.load(Ordering::SeqCst); Box::pin(async move { - let mut transport = transport.lock().await; + let transport = transport.lock().await; // Initialize (start) transport on the first call. if !started { diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 55e17d10a3fc..f346af833b97 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,5 +1,7 @@ use async_trait::async_trait; +use std::sync::Arc; use thiserror::Error; +use tokio::sync::Mutex; /// A generic error type for transport operations. #[derive(Debug, Error)] @@ -24,18 +26,41 @@ pub enum Error { #[async_trait] pub trait Transport: Send + 'static { /// Start the transport and establish the underlying connection. - async fn start(&mut self) -> Result<(), Error>; + async fn start(&self) -> Result<(), Error>; /// Send a raw JSON-encoded message through the transport. - async fn send(&mut self, msg: String) -> Result<(), Error>; + async fn send(&self, msg: String) -> Result<(), Error>; /// Receive a raw JSON-encoded message from the transport. /// /// This should return a single line representing one JSON message. - async fn receive(&mut self) -> Result; + async fn receive(&self) -> Result; /// Close the transport and free any resources. - async fn close(&mut self) -> Result<(), Error>; + async fn close(&self) -> Result<(), Error>; +} + +#[async_trait] +impl Transport for Arc> { + async fn start(&self) -> Result<(), Error> { + let transport = self.lock().await; + transport.start().await + } + + async fn send(&self, msg: String) -> Result<(), Error> { + let transport = self.lock().await; + transport.send(msg).await + } + + async fn receive(&self) -> Result { + let transport = self.lock().await; + transport.receive().await + } + + async fn close(&self) -> Result<(), Error> { + let transport = self.lock().await; + transport.close().await + } } pub mod stdio; diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index cb223b4f66c2..cac9275745fb 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -2,6 +2,7 @@ use super::{Error, Transport}; use async_trait::async_trait; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::sync::Mutex; /// A `StdioTransport` uses a child process’s stdin/stdout as a communication channel. /// @@ -10,9 +11,9 @@ use tokio::process::{Child, ChildStdin, ChildStdout, Command}; pub struct StdioTransport { command: String, args: Vec, - child: Option, - stdin: Option, - stdout: Option>, + child: Mutex>, + stdin: Mutex>, + stdout: Mutex>>, } impl StdioTransport { @@ -27,17 +28,17 @@ impl StdioTransport { Self { command: command.into(), args: args.into_iter().map(Into::into).collect(), - child: None, - stdin: None, - stdout: None, + child: Mutex::new(None), + stdin: Mutex::new(None), + stdout: Mutex::new(None), } } } #[async_trait] impl Transport for StdioTransport { - async fn start(&mut self) -> Result<(), Error> { - if self.child.is_some() { + async fn start(&self) -> Result<(), Error> { + if self.child.lock().await.is_some() { return Ok(()); // Already started } @@ -52,15 +53,16 @@ impl Transport for StdioTransport { let stdin = child.stdin.take().ok_or(Error::NotConnected)?; let stdout = child.stdout.take().ok_or(Error::NotConnected)?; - self.stdin = Some(stdin); - self.stdout = Some(BufReader::new(stdout)); - self.child = Some(child); + *self.stdin.lock().await = Some(stdin); + *self.stdout.lock().await = Some(BufReader::new(stdout)); + *self.child.lock().await = Some(child); Ok(()) } - async fn send(&mut self, msg: String) -> Result<(), Error> { - let stdin = self.stdin.as_mut().ok_or(Error::NotConnected)?; + async fn send(&self, msg: String) -> Result<(), Error> { + let mut stdin = self.stdin.lock().await; + let stdin = stdin.as_mut().ok_or(Error::NotConnected)?; // Write the message followed by a newline stdin.write_all(msg.as_bytes()).await?; stdin.write_all(b"\n").await?; @@ -68,8 +70,9 @@ impl Transport for StdioTransport { Ok(()) } - async fn receive(&mut self) -> Result { - let stdout = self.stdout.as_mut().ok_or(Error::NotConnected)?; + async fn receive(&self) -> Result { + let mut stdout = self.stdout.lock().await; + let stdout = stdout.as_mut().ok_or(Error::NotConnected)?; let mut line = String::new(); let n = stdout.read_line(&mut line).await?; if n == 0 { @@ -79,14 +82,18 @@ impl Transport for StdioTransport { Ok(line) } - async fn close(&mut self) -> Result<(), Error> { + async fn close(&self) -> Result<(), Error> { + let mut child = self.child.lock().await; + let mut stdin = self.stdin.lock().await; + let mut stdout = self.stdout.lock().await; + // Drop stdin to signal EOF - self.stdin.take(); - self.stdout.take(); + *stdin = None; + *stdout = None; - if let Some(mut child) = self.child.take() { + if let Some(mut c) = child.take() { // Wait for child to exit - let _status = child.wait().await?; + let _status = c.wait().await?; } Ok(()) diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index 259050a20354..4402ff135d96 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -23,6 +23,7 @@ pub struct JsonRpcResponse { pub struct JsonRpcNotification { pub jsonrpc: String, pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] pub params: Option, } From a9c51dbfc98abf315f8262e304d34f8670e8ceac Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 00:16:07 -0500 Subject: [PATCH 03/21] update README --- crates/mcp-client/README.md | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/crates/mcp-client/README.md b/crates/mcp-client/README.md index 32e0c4c32fb8..395b4ba93327 100644 --- a/crates/mcp-client/README.md +++ b/crates/mcp-client/README.md @@ -1,13 +1,5 @@ ## Testing stdio ```bash -cargo run -p mcp_client -- --mode git -cargo run -p mcp_client -- --mode echo - -cargo run -p mcp_client --bin stdio +cargo run -p mcp-client ``` - -## Testing SSE - -1. Start the MCP server: `fastmcp run -t sse echo.py` -2. Run the client: `cargo run -p mcp_client --bin sse` From 58abcf5b73b7e845541c7f8b744fb70597c60987 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 08:23:22 -0500 Subject: [PATCH 04/21] implement Drop trait to close transport when its out of scope --- crates/mcp-client/src/service.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 6748f896abdb..98e3bb8476a4 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -25,7 +25,7 @@ pub enum ServiceError { } /// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcRequests and JsonRpcMessages. -pub struct TransportService { +pub struct TransportService { transport: Arc>, initialized: AtomicBool, } @@ -71,9 +71,20 @@ impl Service for TransportService { transport.send(msg).await?; let line = transport.receive().await?; - let response_msg: JsonRpcMessage = serde_json::from_str(&line)?; + let response: JsonRpcMessage = serde_json::from_str(&line)?; - Ok(response_msg) + Ok(response) }) } } + +impl Drop for TransportService { + fn drop(&mut self) { + if self.initialized.load(Ordering::SeqCst) { + // Create a new runtime for cleanup if needed + let rt = tokio::runtime::Runtime::new().unwrap(); + let transport = rt.block_on(self.transport.lock()); + let _ = rt.block_on(transport.close()); + } + } +} From 6ee7c8216911f696390bd6f3f5bcf73101d2c953 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 08:50:17 -0500 Subject: [PATCH 05/21] add timeout middleware to the service --- crates/mcp-client/src/client.rs | 1 - crates/mcp-client/src/main.rs | 22 ++++++++++++++++++---- crates/mcp-client/src/service.rs | 3 +++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 3cb33277c3a6..d568cf80fdb3 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -138,7 +138,6 @@ where }; let msg = serde_json::to_string(¬ification)?; let transport = self.transport.lock().await; - // transport.send(msg).await transport .send(msg) .await diff --git a/crates/mcp-client/src/main.rs b/crates/mcp-client/src/main.rs index cdf86fe9b3c4..90408350526f 100644 --- a/crates/mcp-client/src/main.rs +++ b/crates/mcp-client/src/main.rs @@ -1,9 +1,14 @@ use anyhow::Result; use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; -use mcp_client::{service::TransportService, transport::StdioTransport}; +use mcp_client::{ + service::{ServiceError, TransportService}, + transport::StdioTransport, +}; use std::sync::Arc; +use std::time::Duration; use tokio::sync::Mutex; -use tower::ServiceBuilder; +use tower::timeout::TimeoutLayer; +use tower::{ServiceBuilder, ServiceExt}; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -20,8 +25,17 @@ async fn main() -> Result<(), ClientError> { // Create the base transport as Arc> let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); - // Build service with middleware - let service = ServiceBuilder::new().service(TransportService::new(Arc::clone(&transport))); + // Build service with middleware including timeout + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(Duration::from_secs(30))) + .service(TransportService::new(Arc::clone(&transport))) + .map_err(|e: Box| { + if e.is::() { + ServiceError::Timeout(tower::timeout::error::Elapsed::new()) + } else { + ServiceError::Other(e.to_string()) + } + }); // Create client let mut client = McpClient::new(service, Arc::clone(&transport)); diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 98e3bb8476a4..83a9dd67c1ef 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -17,6 +17,9 @@ pub enum ServiceError { #[error("Serialization error: {0}")] Serialization(#[from] serde_json::Error), + #[error("Request timed out")] + Timeout(#[from] tower::timeout::error::Elapsed), + #[error("Other error: {0}")] Other(String), From 1a7dde6f14592047ff3d1f2498a816128f02a468 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 10:45:30 -0500 Subject: [PATCH 06/21] add call_tool method, move to example dir --- crates/mcp-client/README.md | 4 ++-- crates/mcp-client/{src/main.rs => examples/stdio.rs} | 8 +++++++- crates/mcp-client/src/client.rs | 9 +++++++-- 3 files changed, 16 insertions(+), 5 deletions(-) rename crates/mcp-client/{src/main.rs => examples/stdio.rs} (87%) diff --git a/crates/mcp-client/README.md b/crates/mcp-client/README.md index 395b4ba93327..05559abf7568 100644 --- a/crates/mcp-client/README.md +++ b/crates/mcp-client/README.md @@ -1,5 +1,5 @@ -## Testing stdio +## Testing ```bash -cargo run -p mcp-client +cargo run -p mcp-client --example stdio ``` diff --git a/crates/mcp-client/src/main.rs b/crates/mcp-client/examples/stdio.rs similarity index 87% rename from crates/mcp-client/src/main.rs rename to crates/mcp-client/examples/stdio.rs index 90408350526f..04f6d7f65b9f 100644 --- a/crates/mcp-client/src/main.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -54,7 +54,13 @@ async fn main() -> Result<(), ClientError> { // List tools let tools = client.list_tools().await?; - println!("Available tools: {tools:?}"); + println!("Available tools: {tools:?}\n"); + + // Call tool 'git_status' wtih arguments = {"repo_path": "."} + let tool_result = client + .call_tool("git_status", serde_json::json!({ "repo_path": "." })) + .await?; + println!("Tool result: {tool_result:?}"); Ok(()) } diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index d568cf80fdb3..e864e140a468 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -4,8 +4,7 @@ use thiserror::Error; use tower::ServiceExt; // for Service::ready() use mcp_core::protocol::{ - InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, - ListResourcesResult, ListToolsResult, ReadResourceResult, + CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult }; use std::sync::Arc; use tokio::sync::Mutex; @@ -181,4 +180,10 @@ where pub async fn list_tools(&mut self) -> Result { self.send_message("tools/list", serde_json::json!({})).await } + + // Call tool + pub async fn call_tool(&mut self, name: &str, arguments: Value) -> Result { + let params = serde_json::json!({ "name": name, "arguments": arguments }); + self.send_message("tools/call", params).await + } } From 0faf107780e65411b45a10a736ad8ec8808bae8e Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 10:56:05 -0500 Subject: [PATCH 07/21] working: add SSE transport and example --- crates/mcp-client/README.md | 8 +- crates/mcp-client/examples/sse.rs | 73 ++++++++++++ crates/mcp-client/src/client.rs | 9 +- crates/mcp-client/src/transport/mod.rs | 12 ++ crates/mcp-client/src/transport/sse.rs | 155 +++++++++++++++++++++++++ 5 files changed, 254 insertions(+), 3 deletions(-) create mode 100644 crates/mcp-client/examples/sse.rs create mode 100644 crates/mcp-client/src/transport/sse.rs diff --git a/crates/mcp-client/README.md b/crates/mcp-client/README.md index 05559abf7568..a43c4c21002a 100644 --- a/crates/mcp-client/README.md +++ b/crates/mcp-client/README.md @@ -1,5 +1,11 @@ -## Testing +## Testing stdio transport ```bash cargo run -p mcp-client --example stdio ``` + +## Testing SSE transport + +1. Start the MCP server in one terminal: `fastmcp run -t sse echo.py` +2. Run the client example in new terminal: `cargo run -p mcp-client --example sse` + diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs new file mode 100644 index 000000000000..d26aeb665ed1 --- /dev/null +++ b/crates/mcp-client/examples/sse.rs @@ -0,0 +1,73 @@ +use anyhow::Result; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient}; +use mcp_client::{ + service::{ServiceError, TransportService}, + transport::SseTransport, +}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use tower::timeout::TimeoutLayer; +use tower::{ServiceBuilder, ServiceExt}; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("reqwest_eventsource=debug".parse().unwrap()), + ) + .init(); + + // Create the base transport as Arc> + let transport = Arc::new(Mutex::new(SseTransport::new("http://localhost:8000/sse")?)); + + // Build service with middleware including timeout + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(Duration::from_secs(30))) + .service(TransportService::new(Arc::clone(&transport))) + .map_err(|e: Box| { + if e.is::() { + ServiceError::Timeout(tower::timeout::error::Elapsed::new()) + } else { + ServiceError::Other(e.to_string()) + } + }); + + // Create client + let mut client = McpClient::new(service, Arc::clone(&transport)); + println!("Client created\n"); + + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}\n"); + + // Sleep for 100ms to allow the server to start - surprisingly this is required! + tokio::time::sleep(Duration::from_millis(100)).await; + + // List tools + let tools = client.list_tools().await?; + println!("Available tools: {tools:?}\n"); + + // Call tool + let tool_result = client + .call_tool( + "echo_tool", + serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), + ) + .await?; + println!("Tool result: {tool_result:?}"); + + Ok(()) +} diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index e864e140a468..38ecf6c886bd 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -4,7 +4,8 @@ use thiserror::Error; use tower::ServiceExt; // for Service::ready() use mcp_core::protocol::{ - CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult + CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, + JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, }; use std::sync::Arc; use tokio::sync::Mutex; @@ -182,7 +183,11 @@ where } // Call tool - pub async fn call_tool(&mut self, name: &str, arguments: Value) -> Result { + pub async fn call_tool( + &mut self, + name: &str, + arguments: Value, + ) -> Result { let params = serde_json::json!({ "name": name, "arguments": arguments }); self.send_message("tools/call", params).await } diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index f346af833b97..12da4003a951 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -12,6 +12,15 @@ pub enum Error { #[error("Transport was not connected or is already closed")] NotConnected, + #[error("Invalid URL provided")] + InvalidUrl, + + #[error("Connection timeout")] + Timeout, + + #[error("Failed to send message")] + SendFailed, + #[error("Unexpected transport error: {0}")] Other(String), } @@ -65,3 +74,6 @@ impl Transport for Arc> { pub mod stdio; pub use stdio::StdioTransport; + +pub mod sse; +pub use sse::SseTransport; diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs new file mode 100644 index 000000000000..7945197f3b86 --- /dev/null +++ b/crates/mcp-client/src/transport/sse.rs @@ -0,0 +1,155 @@ +use super::{Error, Transport}; +use async_trait::async_trait; +use futures_util::StreamExt; +use reqwest::{Client, Url}; +use reqwest_eventsource::{Event, EventSource}; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; +use tracing::{debug, error, info}; + +pub struct SseTransport { + connection_url: Url, + endpoint: Arc>>, + http_client: Client, + event_source: Arc>>, + message_rx: Arc>>>, + message_tx: mpsc::Sender, +} + +impl SseTransport { + pub fn new(url: &str) -> Result { + let (message_tx, message_rx) = mpsc::channel(100); + + Ok(Self { + connection_url: Url::parse(url).map_err(|_| Error::InvalidUrl)?, + endpoint: Arc::new(Mutex::new(None)), + http_client: Client::new(), + event_source: Arc::new(Mutex::new(None)), + message_rx: Arc::new(Mutex::new(Some(message_rx))), + message_tx, + }) + } +} + +/// Constructs the endpoint URL by removing "/sse" from the connection URL +/// and appending the given suffix. +fn construct_endpoint_url(base_url: &Url, url_suffix: &str) -> Result { + let trimmed_base = base_url.as_str().trim_end_matches("/sse"); + let trimmed_base = trimmed_base.trim_end_matches('/'); + let trimmed_suffix = url_suffix.trim_start_matches('/'); + let full_url = format!("{}/{}", trimmed_base, trimmed_suffix); + Url::parse(&full_url) +} + +#[async_trait] +impl Transport for SseTransport { + async fn start(&self) -> Result<(), Error> { + if self.event_source.lock().await.is_some() { + return Ok(()); + } + + let event_source = EventSource::get(self.connection_url.as_str()); + let message_tx = self.message_tx.clone(); + let endpoint = self.endpoint.clone(); + + // Store event source + *self.event_source.lock().await = Some(event_source); + + // Create a new event source for the task + let mut stream = EventSource::get(self.connection_url.as_str()); + + let connection_url = self.connection_url.clone(); + let cloned_connection_url = connection_url.clone(); + + // Spawn a task to handle incoming events + tokio::spawn(async move { + while let Some(event) = stream.next().await { + match event { + Ok(Event::Open) => { + // Connection established + info!("\nSSE connection opened"); + } + Ok(Event::Message(message)) => { + debug!("Received SSE event: {} - {}", message.event, message.data); + // Check if this is an endpoint event + if message.event == "endpoint" { + let url_suffix = &message.data; + debug!("Received endpoint URL suffix: {}", url_suffix); + match construct_endpoint_url(&cloned_connection_url, url_suffix) { + Ok(url) => { + info!("Endpoint URL: {}", url); + let mut endpoint_guard = endpoint.lock().await; + *endpoint_guard = Some(url); + } + Err(e) => { + error!("Failed to construct endpoint URL: {}", e); + // Optionally, handle the error (e.g., retry, notify, etc.) + } + } + } else { + // Regular message + // Assuming message.data is the message payload + if let Err(e) = message_tx.send(message.data).await { + error!("Failed to send message: {}", e); + } + } + } + Err(e) => { + error!("EventSource error: {}", e); + break; + } + } + } + }); + + // Wait for endpoint URL: every 100ms, check if the endpoint is set upto 30s timeout + let timeout = tokio::time::sleep(std::time::Duration::from_secs(30)); + tokio::pin!(timeout); + + loop { + tokio::select! { + _ = &mut timeout => { + return Err(Error::Timeout); + } + _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { + let endpoint_guard = self.endpoint.lock().await; + if endpoint_guard.is_some() { + break; + } + } + } + } + + Ok(()) + } + + async fn send(&self, msg: String) -> Result<(), Error> { + let endpoint = { + let endpoint_guard = self.endpoint.lock().await; + endpoint_guard.as_ref().ok_or(Error::NotConnected)?.clone() + }; + + self.http_client + .post(endpoint) + .header("Content-Type", "application/json") + .body(msg) + .send() + .await + .map_err(|_| Error::SendFailed)?; + + Ok(()) + } + + async fn receive(&self) -> Result { + let mut rx_guard = self.message_rx.lock().await; + let rx = rx_guard.as_mut().ok_or(Error::NotConnected)?; + + rx.recv().await.ok_or(Error::NotConnected) + } + + async fn close(&self) -> Result<(), Error> { + *self.event_source.lock().await = None; + *self.endpoint.lock().await = None; + Ok(()) + } +} From c073d19461cd39768d6dd60c8a611dfb1f71d8d7 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Wed, 18 Dec 2024 15:43:47 -0500 Subject: [PATCH 08/21] Remove transport field in McpClient and let Service handle json rpc msgs --- crates/mcp-client/examples/sse.rs | 2 +- crates/mcp-client/examples/stdio.rs | 2 +- crates/mcp-client/src/client.rs | 40 +++++++++++---------------- crates/mcp-client/src/service.rs | 42 ++++++++++++++++++++--------- 4 files changed, 48 insertions(+), 38 deletions(-) diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index d26aeb665ed1..89d3560f28b2 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -38,7 +38,7 @@ async fn main() -> Result<()> { }); // Create client - let mut client = McpClient::new(service, Arc::clone(&transport)); + let mut client = McpClient::new(service); println!("Client created\n"); // Initialize diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 04f6d7f65b9f..3b26af0d5122 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -38,7 +38,7 @@ async fn main() -> Result<(), ClientError> { }); // Create client - let mut client = McpClient::new(service, Arc::clone(&transport)); + let mut client = McpClient::new(service); // Initialize let server_info = client diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 38ecf6c886bd..28a7a506ec6a 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -4,13 +4,9 @@ use thiserror::Error; use tower::ServiceExt; // for Service::ready() use mcp_core::protocol::{ - CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, - JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, + CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, + JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, }; -use std::sync::Arc; -use tokio::sync::Mutex; - -use crate::transport::{Error as TransportError, Transport}; /// Error type for MCP client operations. #[derive(Debug, Error)] @@ -52,26 +48,23 @@ pub struct InitializeParams { } /// The MCP client that sends requests via the provided service. -pub struct McpClient { +pub struct McpClient { service: S, - transport: Arc>, next_id: u64, } -impl McpClient +impl McpClient where S: tower::Service< - JsonRpcRequest, + JsonRpcMessage, Response = JsonRpcMessage, Error = super::service::ServiceError, > + Send, S::Future: Send, - T: Transport, { - pub fn new(service: S, transport: Arc>) -> Self { + pub fn new(service: S) -> Self { Self { service, - transport, next_id: 1, } } @@ -83,12 +76,12 @@ where { self.service.ready().await.map_err(|_| Error::NotReady)?; - let request = JsonRpcRequest { + let request = JsonRpcMessage::Request(JsonRpcRequest { jsonrpc: "2.0".to_string(), id: Some(self.next_id), method: method.to_string(), params: Some(params), - }; + }); self.next_id += 1; @@ -130,18 +123,17 @@ where } /// Send a JSON-RPC notification. - pub async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { - let notification = mcp_core::protocol::JsonRpcNotification { + pub async fn send_notification(&mut self, method: &str, params: Value) -> Result<(), Error> { + self.service.ready().await.map_err(|_| Error::NotReady)?; + + let notification = JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), method: method.to_string(), params: Some(params), - }; - let msg = serde_json::to_string(¬ification)?; - let transport = self.transport.lock().await; - transport - .send(msg) - .await - .map_err(|e: TransportError| Error::Service(e.into())) + }); + + self.service.call(notification).await?; + Ok(()) } /// Initialize the connection with the server. diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 83a9dd67c1ef..589f6681cee6 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -7,7 +7,7 @@ use tokio::sync::Mutex; use tower::Service; use crate::transport::{Error as TransportError, Transport}; -use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; +use mcp_core::protocol::{JsonRpcMessage, JsonRpcResponse}; #[derive(Debug, thiserror::Error)] pub enum ServiceError { @@ -27,7 +27,7 @@ pub enum ServiceError { UnexpectedResponse, } -/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcRequests and JsonRpcMessages. +/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcMessages and JsonRpcMessages. pub struct TransportService { transport: Arc>, initialized: AtomicBool, @@ -47,7 +47,7 @@ impl TransportService { } } -impl Service for TransportService { +impl Service for TransportService { type Response = JsonRpcMessage; type Error = ServiceError; type Future = Pin> + Send>>; @@ -57,7 +57,7 @@ impl Service for TransportService { Poll::Ready(Ok(())) } - fn call(&mut self, request: JsonRpcRequest) -> Self::Future { + fn call(&mut self, message: JsonRpcMessage) -> Self::Future { let transport = Arc::clone(&self.transport); let started = self.initialized.load(Ordering::SeqCst); @@ -69,14 +69,32 @@ impl Service for TransportService { transport.start().await?; } - // Serialize request to JSON line - let msg = serde_json::to_string(&request)?; - transport.send(msg).await?; - - let line = transport.receive().await?; - let response: JsonRpcMessage = serde_json::from_str(&line)?; - - Ok(response) + match message { + JsonRpcMessage::Notification(notification) => { + // Serialize notification + let msg = serde_json::to_string(¬ification)?; + transport.send(msg).await?; + // For notifications, the protocol does not require a response + // So we return an empty response here and this is not checked upstream + let response: JsonRpcMessage = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: None, + result: None, + error: None, + }); + + Ok(response) + } + JsonRpcMessage::Request(request) => { + // Serialize request & wait for response + let msg = serde_json::to_string(&request)?; + transport.send(msg).await?; + let line = transport.receive().await?; + let response: JsonRpcMessage = serde_json::from_str(&line)?; + Ok(response) + } + _ => return Err(ServiceError::Other("Invalid message type".to_string())), + } }) } } From 0a0efc32474c87c52732ae61e4e7f44105a5cfa3 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Wed, 18 Dec 2024 17:01:58 -0500 Subject: [PATCH 09/21] add example to create collection of clients --- crates/mcp-client/examples/clients.rs | 81 +++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 crates/mcp-client/examples/clients.rs diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs new file mode 100644 index 000000000000..d6e2c6afc873 --- /dev/null +++ b/crates/mcp-client/examples/clients.rs @@ -0,0 +1,81 @@ +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; + +use mcp_client::{ + client::{ClientCapabilities, ClientInfo, McpClient}, + service::{ServiceError, TransportService}, + transport::StdioTransport, +}; +use tower::{ServiceBuilder, ServiceExt}; +use tower::timeout::TimeoutLayer; +use tracing_subscriber::EnvFilter; +use tower::util::BoxService; +use mcp_core::protocol::JsonRpcMessage; + +// Define a type alias for the boxed service using BoxService +type BoxedService = BoxService; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("reqwest_eventsource=debug".parse().unwrap()), + ) + .init(); + + // Create two separate clients with stdio transport + let client1 = create_client("client1", "1.0.0")?; + let client2 = create_client("client2", "1.0.0")?; + + // Initialize both clients + let mut clients: Vec> = vec![client1, client2]; + + // Initialize all clients + for (i, client) in clients.iter_mut().enumerate() { + let info = ClientInfo { + name: format!("example-client-{}", i + 1), + version: "1.0.0".to_string(), + }; + let capabilities = ClientCapabilities::default(); + + println!("\nInitializing client {}", i + 1); + let init_result = client.initialize(info, capabilities).await?; + println!("Client {} initialized: {:?}", i + 1, init_result); + } + + // List tools for each client + for (i, client) in clients.iter_mut().enumerate() { + let tools = client.list_tools().await?; + println!("\nClient {} tools: {:?}", i + 1, tools); + } + + Ok(()) +} + +fn create_client( + _name: &str, + _version: &str, +) -> Result>, Box> { + // Create the transport + let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); + + // Build service with middleware including timeout + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(Duration::from_secs(30))) + .service(TransportService::new(Arc::clone(&transport))) + .map_err(|e: Box| { + if e.is::() { + ServiceError::Timeout(tower::timeout::error::Elapsed::new()) + } else { + ServiceError::Other(e.to_string()) + } + }) + .boxed(); // Box the service to create a BoxService + + // Create the client + Ok(McpClient::new(service)) +} From 40d5beb53a73e836fcc8d9237e943cb688c0f27c Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 10:42:16 -0500 Subject: [PATCH 10/21] make McpClient a trait and current version (McpClientImpl) an implementation * checks out the changes from 'kalvin/mcp-client-trait' branch Co-authored-by: kalvinnchau --- crates/mcp-client/examples/clients.rs | 50 +++++++++++++++------ crates/mcp-client/examples/sse.rs | 4 +- crates/mcp-client/examples/stdio.rs | 6 ++- crates/mcp-client/src/client.rs | 62 +++++++++++++++++++-------- 4 files changed, 86 insertions(+), 36 deletions(-) diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index d6e2c6afc873..43d31c4c283c 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -3,18 +3,15 @@ use std::time::Duration; use tokio::sync::Mutex; use mcp_client::{ - client::{ClientCapabilities, ClientInfo, McpClient}, + client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}, service::{ServiceError, TransportService}, - transport::StdioTransport, + transport::{SseTransport, StdioTransport}, }; -use tower::{ServiceBuilder, ServiceExt}; +use mcp_core::protocol::JsonRpcMessage; use tower::timeout::TimeoutLayer; -use tracing_subscriber::EnvFilter; use tower::util::BoxService; -use mcp_core::protocol::JsonRpcMessage; - -// Define a type alias for the boxed service using BoxService -type BoxedService = BoxService; +use tower::{ServiceBuilder, ServiceExt}; +use tracing_subscriber::EnvFilter; #[tokio::main] async fn main() -> Result<(), Box> { @@ -30,9 +27,13 @@ async fn main() -> Result<(), Box> { // Create two separate clients with stdio transport let client1 = create_client("client1", "1.0.0")?; let client2 = create_client("client2", "1.0.0")?; + let client3 = create_sse_client("client3", "1.0.0")?; // Initialize both clients - let mut clients: Vec> = vec![client1, client2]; + let mut clients: Vec> = Vec::new(); + clients.push(client1); + clients.push(client2); + clients.push(client3); // Initialize all clients for (i, client) in clients.iter_mut().enumerate() { @@ -59,7 +60,7 @@ async fn main() -> Result<(), Box> { fn create_client( _name: &str, _version: &str, -) -> Result>, Box> { +) -> Result, Box> { // Create the transport let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); @@ -73,9 +74,30 @@ fn create_client( } else { ServiceError::Other(e.to_string()) } - }) - .boxed(); // Box the service to create a BoxService + }); + + Ok(Box::new(McpClientImpl::new(service))) +} + +fn create_sse_client( + _name: &str, + _version: &str, +) -> Result, Box> { + let transport = Arc::new(Mutex::new( + SseTransport::new("http://localhost:8000/sse").unwrap(), + )); + + // Build service with middleware including timeout + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(Duration::from_secs(30))) + .service(TransportService::new(Arc::clone(&transport))) + .map_err(|e: Box| { + if e.is::() { + ServiceError::Timeout(tower::timeout::error::Elapsed::new()) + } else { + ServiceError::Other(e.to_string()) + } + }); - // Create the client - Ok(McpClient::new(service)) + Ok(Box::new(McpClientImpl::new(service))) } diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index 89d3560f28b2..3d7e570d4415 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient}; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}; use mcp_client::{ service::{ServiceError, TransportService}, transport::SseTransport, @@ -38,7 +38,7 @@ async fn main() -> Result<()> { }); // Create client - let mut client = McpClient::new(service); + let mut client = McpClientImpl::new(service); println!("Client created\n"); // Initialize diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 3b26af0d5122..7300512ddafb 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -1,5 +1,7 @@ use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; +use mcp_client::client::{ + ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientImpl, +}; use mcp_client::{ service::{ServiceError, TransportService}, transport::StdioTransport, @@ -38,7 +40,7 @@ async fn main() -> Result<(), ClientError> { }); // Create client - let mut client = McpClient::new(service); + let mut client = McpClientImpl::new(service); // Initialize let server_info = client diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 28a7a506ec6a..270385f4e4d6 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -47,13 +47,36 @@ pub struct InitializeParams { pub client_info: ClientInfo, } -/// The MCP client that sends requests via the provided service. -pub struct McpClient { +/// The MCP client trait defining the interface for MCP operations. +#[async_trait::async_trait] +pub trait McpClient { + /// Initialize the connection with the server. + async fn initialize( + &mut self, + info: ClientInfo, + capabilities: ClientCapabilities, + ) -> Result; + + /// List available resources. + async fn list_resources(&mut self) -> Result; + + /// Read a resource's content. + async fn read_resource(&mut self, uri: &str) -> Result; + + /// List available tools. + async fn list_tools(&mut self) -> Result; + + /// Call a specific tool with arguments. + async fn call_tool(&mut self, name: &str, arguments: Value) -> Result; +} + +/// Standard implementation of the MCP client that sends requests via the provided service. +pub struct McpClientImpl { service: S, next_id: u64, } -impl McpClient +impl McpClientImpl where S: tower::Service< JsonRpcMessage, @@ -123,7 +146,7 @@ where } /// Send a JSON-RPC notification. - pub async fn send_notification(&mut self, method: &str, params: Value) -> Result<(), Error> { + async fn send_notification(&mut self, method: &str, params: Value) -> Result<(), Error> { self.service.ready().await.map_err(|_| Error::NotReady)?; let notification = JsonRpcMessage::Notification(JsonRpcNotification { @@ -135,9 +158,20 @@ where self.service.call(notification).await?; Ok(()) } +} - /// Initialize the connection with the server. - pub async fn initialize( +#[async_trait::async_trait] +impl McpClient for McpClientImpl +where + S: tower::Service< + JsonRpcMessage, + Response = JsonRpcMessage, + Error = super::service::ServiceError, + > + Send + + Sync, + S::Future: Send, +{ + async fn initialize( &mut self, info: ClientInfo, capabilities: ClientCapabilities, @@ -157,29 +191,21 @@ where Ok(result) } - /// List available resources. - pub async fn list_resources(&mut self) -> Result { + async fn list_resources(&mut self) -> Result { self.send_message("resources/list", serde_json::json!({})) .await } - /// Read a resource's content. - pub async fn read_resource(&mut self, uri: &str) -> Result { + async fn read_resource(&mut self, uri: &str) -> Result { let params = serde_json::json!({ "uri": uri }); self.send_message("resources/read", params).await } - /// List tools - pub async fn list_tools(&mut self) -> Result { + async fn list_tools(&mut self) -> Result { self.send_message("tools/list", serde_json::json!({})).await } - // Call tool - pub async fn call_tool( - &mut self, - name: &str, - arguments: Value, - ) -> Result { + async fn call_tool(&mut self, name: &str, arguments: Value) -> Result { let params = serde_json::json!({ "name": name, "arguments": arguments }); self.send_message("tools/call", params).await } From 67ea15e85f0b068dbcf687d66bc9496332e8757a Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 10:46:47 -0500 Subject: [PATCH 11/21] Add JsonRpcMessage::Nil for responding to notifications --- crates/mcp-client/examples/clients.rs | 2 -- crates/mcp-client/src/service.rs | 11 ++--------- crates/mcp-core/src/protocol.rs | 1 + 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index 43d31c4c283c..df8fac9c526e 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -7,9 +7,7 @@ use mcp_client::{ service::{ServiceError, TransportService}, transport::{SseTransport, StdioTransport}, }; -use mcp_core::protocol::JsonRpcMessage; use tower::timeout::TimeoutLayer; -use tower::util::BoxService; use tower::{ServiceBuilder, ServiceExt}; use tracing_subscriber::EnvFilter; diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 589f6681cee6..76b081720c6c 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -7,7 +7,7 @@ use tokio::sync::Mutex; use tower::Service; use crate::transport::{Error as TransportError, Transport}; -use mcp_core::protocol::{JsonRpcMessage, JsonRpcResponse}; +use mcp_core::protocol::JsonRpcMessage; #[derive(Debug, thiserror::Error)] pub enum ServiceError { @@ -76,14 +76,7 @@ impl Service for TransportService { transport.send(msg).await?; // For notifications, the protocol does not require a response // So we return an empty response here and this is not checked upstream - let response: JsonRpcMessage = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: None, - result: None, - error: None, - }); - - Ok(response) + Ok(JsonRpcMessage::Nil) } JsonRpcMessage::Request(request) => { // Serialize request & wait for response diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index 4402ff135d96..87f846fe82ab 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -41,6 +41,7 @@ pub enum JsonRpcMessage { Response(JsonRpcResponse), Notification(JsonRpcNotification), Error(JsonRpcError), + Nil, // used to respond to notifications } // Standard JSON-RPC error codes From d3d0a3b17fd3f4229f1c03ecbbe1d03872cb3464 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 12:38:51 -0500 Subject: [PATCH 12/21] stage: transport only connects; message router uses channels to send/receive msgs --- crates/mcp-client/src/service.rs | 78 ++++--- crates/mcp-client/src/transport/mod.rs | 105 +++++---- crates/mcp-client/src/transport/sse.rs | 258 +++++++++++++---------- crates/mcp-client/src/transport/stdio.rs | 196 +++++++++++------ 4 files changed, 386 insertions(+), 251 deletions(-) diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 76b081720c6c..cac16e3e17d0 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -3,10 +3,10 @@ use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, Mutex, oneshot}; use tower::Service; -use crate::transport::{Error as TransportError, Transport}; +use crate::transport::{Error as TransportError, Transport, TransportMessage, MessageRouter}; use mcp_core::protocol::JsonRpcMessage; #[derive(Debug, thiserror::Error)] @@ -27,30 +27,50 @@ pub enum ServiceError { UnexpectedResponse, } -/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcMessages and JsonRpcMessages. +/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcMessages. pub struct TransportService { - transport: Arc>, + transport: Arc, + router: Arc>>, initialized: AtomicBool, } impl TransportService { pub fn new(transport: T) -> Self { Self { - transport: Arc::new(Mutex::new(transport)), + transport: Arc::new(transport), + router: Arc::new(Mutex::new(None)), initialized: AtomicBool::new(false), } } - /// Provides a clone of the transport handle for external access (e.g., for sending notifications). - pub fn get_transport_handle(&self) -> Arc> { - Arc::clone(&self.transport) + async fn ensure_initialized(&self) -> Result, ServiceError> { + if !self.initialized.load(Ordering::SeqCst) { + let mut router_guard = self.router.lock().await; + + // Double-check after acquiring lock + if !self.initialized.load(Ordering::SeqCst) { + // Start the transport + let transport_tx = self.transport.start().await?; + + // Create shutdown channel + let (shutdown_tx, _shutdown_rx) = mpsc::channel(1); + + // Create and store the router + let router = MessageRouter::new(transport_tx, shutdown_tx); + *router_guard = Some(router); + + self.initialized.store(true, Ordering::SeqCst); + } + } + + Ok(Arc::new(self.router.lock().await.as_ref().unwrap().clone())) } } impl Service for TransportService { type Response = JsonRpcMessage; type Error = ServiceError; - type Future = Pin> + Send>>; + type Future = Pin> + Send>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { // Always ready. We do on-demand initialization in call(). @@ -59,34 +79,24 @@ impl Service for TransportService { fn call(&mut self, message: JsonRpcMessage) -> Self::Future { let transport = Arc::clone(&self.transport); - let started = self.initialized.load(Ordering::SeqCst); + let router = Arc::clone(&self.router); Box::pin(async move { - let transport = transport.lock().await; - - // Initialize (start) transport on the first call. - if !started { - transport.start().await?; - } + let router = match router.lock().await.as_ref() { + Some(router) => router.clone(), + None => return Err(ServiceError::Other("Transport not initialized".to_string())), + }; match message { JsonRpcMessage::Notification(notification) => { - // Serialize notification - let msg = serde_json::to_string(¬ification)?; - transport.send(msg).await?; - // For notifications, the protocol does not require a response - // So we return an empty response here and this is not checked upstream + router.send_notification(JsonRpcMessage::Notification(notification)).await?; Ok(JsonRpcMessage::Nil) } JsonRpcMessage::Request(request) => { - // Serialize request & wait for response - let msg = serde_json::to_string(&request)?; - transport.send(msg).await?; - let line = transport.receive().await?; - let response: JsonRpcMessage = serde_json::from_str(&line)?; - Ok(response) + router.send_request(JsonRpcMessage::Request(request)).await + .map_err(|e| ServiceError::Transport(e)) } - _ => return Err(ServiceError::Other("Invalid message type".to_string())), + _ => Err(ServiceError::Other("Invalid message type".to_string())), } }) } @@ -97,8 +107,14 @@ impl Drop for TransportService { if self.initialized.load(Ordering::SeqCst) { // Create a new runtime for cleanup if needed let rt = tokio::runtime::Runtime::new().unwrap(); - let transport = rt.block_on(self.transport.lock()); - let _ = rt.block_on(transport.close()); + + // Request shutdown through the router if it exists + if let Some(router) = rt.block_on(self.router.lock()).as_ref() { + let _ = rt.block_on(router.shutdown()); + } + + // Close the transport + let _ = rt.block_on(self.transport.close()); } } -} +} \ No newline at end of file diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 12da4003a951..422329b6f5e2 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,7 +1,8 @@ use async_trait::async_trait; +use mcp_core::protocol::JsonRpcMessage; use std::sync::Arc; use thiserror::Error; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, oneshot}; /// A generic error type for transport operations. #[derive(Debug, Error)] @@ -21,59 +22,87 @@ pub enum Error { #[error("Failed to send message")] SendFailed, + #[error("Channel closed")] + ChannelClosed, + #[error("Unexpected transport error: {0}")] Other(String), } -/// A generic asynchronous transport trait. -/// -/// Implementations are expected to handle: -/// - starting the underlying communication channel (e.g., launching a child process, connecting a socket) -/// - sending JSON-RPC messages as strings -/// - receiving JSON-RPC messages as strings -/// - closing the transport cleanly +/// A message that can be sent through the transport +#[derive(Debug)] +pub struct TransportMessage { + /// The JSON-RPC message to send + pub message: JsonRpcMessage, + /// Channel to receive the response on (None for notifications) + pub response_tx: Option>>, +} + +/// A generic asynchronous transport trait with channel-based communication #[async_trait] -pub trait Transport: Send + 'static { +pub trait Transport: Send + Sync + 'static { /// Start the transport and establish the underlying connection. - async fn start(&self) -> Result<(), Error>; - - /// Send a raw JSON-encoded message through the transport. - async fn send(&self, msg: String) -> Result<(), Error>; - - /// Receive a raw JSON-encoded message from the transport. - /// - /// This should return a single line representing one JSON message. - async fn receive(&self) -> Result; + /// Returns channels for sending messages and receiving errors. + async fn start(&self) -> Result, Error>; /// Close the transport and free any resources. async fn close(&self) -> Result<(), Error>; } -#[async_trait] -impl Transport for Arc> { - async fn start(&self) -> Result<(), Error> { - let transport = self.lock().await; - transport.start().await +pub mod stdio; +pub use stdio::StdioTransport; + +pub mod sse; +pub use sse::SseTransport; + +/// A router that handles message distribution for a transport +pub struct MessageRouter { + transport_tx: mpsc::Sender, + shutdown_tx: mpsc::Sender<()>, +} + +impl MessageRouter { + pub fn new( + transport_tx: mpsc::Sender, + shutdown_tx: mpsc::Sender<()>, + ) -> Self { + Self { + transport_tx, + shutdown_tx, + } } - async fn send(&self, msg: String) -> Result<(), Error> { - let transport = self.lock().await; - transport.send(msg).await + /// Send a message and wait for a response + pub async fn send_request(&self, request: JsonRpcMessage) -> Result { + let (response_tx, response_rx) = oneshot::channel(); + + self.transport_tx + .send(TransportMessage { + message: request, + response_tx: Some(response_tx), + }) + .await + .map_err(|_| Error::ChannelClosed)?; + + response_rx.await.map_err(|_| Error::ChannelClosed)? } - async fn receive(&self) -> Result { - let transport = self.lock().await; - transport.receive().await + /// Send a notification (no response expected) + pub async fn send_notification(&self, notification: JsonRpcMessage) -> Result<(), Error> { + self.transport_tx + .send(TransportMessage { + message: notification, + response_tx: None, + }) + .await + .map_err(|_| Error::ChannelClosed) } - async fn close(&self) -> Result<(), Error> { - let transport = self.lock().await; - transport.close().await + /// Request shutdown of the transport + pub async fn shutdown(&self) -> Result<(), Error> { + self.shutdown_tx + .send(()) + .await + .map_err(|_| Error::ChannelClosed) } } - -pub mod stdio; -pub use stdio::StdioTransport; - -pub mod sse; -pub use sse::SseTransport; diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 7945197f3b86..072927600caf 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -1,155 +1,185 @@ -use super::{Error, Transport}; +use std::sync::Arc; use async_trait::async_trait; use futures_util::StreamExt; -use reqwest::{Client, Url}; -use reqwest_eventsource::{Event, EventSource}; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; -use tracing::{debug, error, info}; +use reqwest::{Client as HttpClient, Url}; +use tokio::sync::{mpsc, Mutex, oneshot}; +use tokio::task::JoinHandle; +use eventsource_client::{Client as EventSourceClient, SSE}; + +use super::{Error, Transport, TransportMessage}; +use mcp_core::protocol::JsonRpcMessage; +/// A transport implementation that uses Server-Sent Events (SSE) for receiving messages +/// and HTTP POST for sending messages. pub struct SseTransport { - connection_url: Url, - endpoint: Arc>>, - http_client: Client, - event_source: Arc>>, - message_rx: Arc>>>, - message_tx: mpsc::Sender, + sse_url: String, + http_client: HttpClient, + post_endpoint: Arc>>, + sse_handle: Arc>>>, + pending_requests: Arc>>>>, } impl SseTransport { - pub fn new(url: &str) -> Result { - let (message_tx, message_rx) = mpsc::channel(100); - - Ok(Self { - connection_url: Url::parse(url).map_err(|_| Error::InvalidUrl)?, - endpoint: Arc::new(Mutex::new(None)), - http_client: Client::new(), - event_source: Arc::new(Mutex::new(None)), - message_rx: Arc::new(Mutex::new(Some(message_rx))), - message_tx, - }) + /// Create a new SSE transport with the given SSE endpoint URL + pub fn new>(sse_url: S) -> Self { + Self { + sse_url: sse_url.into(), + http_client: HttpClient::new(), + post_endpoint: Arc::new(Mutex::new(None)), + sse_handle: Arc::new(Mutex::new(None)), + pending_requests: Arc::new(Mutex::new(std::collections::HashMap::new())), + } } -} -/// Constructs the endpoint URL by removing "/sse" from the connection URL -/// and appending the given suffix. -fn construct_endpoint_url(base_url: &Url, url_suffix: &str) -> Result { - let trimmed_base = base_url.as_str().trim_end_matches("/sse"); - let trimmed_base = trimmed_base.trim_end_matches('/'); - let trimmed_suffix = url_suffix.trim_start_matches('/'); - let full_url = format!("{}/{}", trimmed_base, trimmed_suffix); - Url::parse(&full_url) -} - -#[async_trait] -impl Transport for SseTransport { - async fn start(&self) -> Result<(), Error> { - if self.event_source.lock().await.is_some() { - return Ok(()); + async fn handle_message( + message: JsonRpcMessage, + pending_requests: Arc>>>>, + ) { + if let JsonRpcMessage::Response(response) = &message { + if let Some(tx) = pending_requests.lock().await.remove(&response.id) { + let _ = tx.send(Ok(message)); + } } + } - let event_source = EventSource::get(self.connection_url.as_str()); - let message_tx = self.message_tx.clone(); - let endpoint = self.endpoint.clone(); - - // Store event source - *self.event_source.lock().await = Some(event_source); + async fn process_messages( + mut message_rx: mpsc::Receiver, + http_client: HttpClient, + post_endpoint: Arc>>, + sse_url: String, + pending_requests: Arc>>>>, + ) { + // Set up SSE client + let client = match EventSourceClient::new(&sse_url) { + Ok(client) => client, + Err(e) => { + eprintln!("Failed to create SSE client: {}", e); + return; + } + }; - // Create a new event source for the task - let mut stream = EventSource::get(self.connection_url.as_str()); + let mut stream = client.stream(); - let connection_url = self.connection_url.clone(); - let cloned_connection_url = connection_url.clone(); + // Wait for endpoint event to get POST URL + while let Some(event) = stream.next().await { + match event { + Ok(SSE::Event(event)) if event.event_type == "endpoint" => { + if let Some(data) = event.data { + *post_endpoint.lock().await = Some(data); + break; + } + } + Ok(_) => continue, + Err(e) => { + eprintln!("SSE connection error: {}", e); + return; + } + } + } - // Spawn a task to handle incoming events - tokio::spawn(async move { + // Spawn SSE message handler + let pending_clone = pending_requests.clone(); + let sse_handle = tokio::spawn(async move { while let Some(event) = stream.next().await { match event { - Ok(Event::Open) => { - // Connection established - info!("\nSSE connection opened"); - } - Ok(Event::Message(message)) => { - debug!("Received SSE event: {} - {}", message.event, message.data); - // Check if this is an endpoint event - if message.event == "endpoint" { - let url_suffix = &message.data; - debug!("Received endpoint URL suffix: {}", url_suffix); - match construct_endpoint_url(&cloned_connection_url, url_suffix) { - Ok(url) => { - info!("Endpoint URL: {}", url); - let mut endpoint_guard = endpoint.lock().await; - *endpoint_guard = Some(url); - } - Err(e) => { - error!("Failed to construct endpoint URL: {}", e); - // Optionally, handle the error (e.g., retry, notify, etc.) - } - } - } else { - // Regular message - // Assuming message.data is the message payload - if let Err(e) = message_tx.send(message.data).await { - error!("Failed to send message: {}", e); + Ok(SSE::Event(event)) if event.event_type == "message" => { + if let Some(data) = event.data { + if let Ok(message) = serde_json::from_str::(&data) { + Self::handle_message(message, pending_clone.clone()).await; } } } + Ok(_) => continue, Err(e) => { - error!("EventSource error: {}", e); + eprintln!("SSE message error: {}", e); break; } } } }); - // Wait for endpoint URL: every 100ms, check if the endpoint is set upto 30s timeout - let timeout = tokio::time::sleep(std::time::Duration::from_secs(30)); - tokio::pin!(timeout); + // Process outgoing messages + while let Some(transport_msg) = message_rx.recv().await { + let post_url = match post_endpoint.lock().await.as_ref() { + Some(url) => url.clone(), + None => { + eprintln!("No POST endpoint available"); + continue; + } + }; - loop { - tokio::select! { - _ = &mut timeout => { - return Err(Error::Timeout); + // Store response channel if this is a request + if let Some(response_tx) = transport_msg.response_tx { + if let JsonRpcMessage::Request(request) = &transport_msg.message { + pending_requests.lock().await.insert(request.id.clone(), response_tx); } - _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { - let endpoint_guard = self.endpoint.lock().await; - if endpoint_guard.is_some() { - break; - } + } + + // Send message via HTTP POST + let message_str = match serde_json::to_string(&transport_msg.message) { + Ok(s) => s, + Err(e) => { + eprintln!("Failed to serialize message: {}", e); + continue; } + }; + + if let Err(e) = http_client + .post(&post_url) + .header("Content-Type", "application/json") + .body(message_str) + .send() + .await + { + eprintln!("Failed to send message: {}", e); } } - Ok(()) + // Clean up + sse_handle.abort(); } +} - async fn send(&self, msg: String) -> Result<(), Error> { - let endpoint = { - let endpoint_guard = self.endpoint.lock().await; - endpoint_guard.as_ref().ok_or(Error::NotConnected)?.clone() - }; - - self.http_client - .post(endpoint) - .header("Content-Type", "application/json") - .body(msg) - .send() - .await - .map_err(|_| Error::SendFailed)?; - - Ok(()) +#[async_trait] +impl Transport for SseTransport { + async fn start(&self) -> Result, Error> { + let (message_tx, message_rx) = mpsc::channel(32); + + let http_client = self.http_client.clone(); + let post_endpoint = self.post_endpoint.clone(); + let sse_url = self.sse_url.clone(); + let pending_requests = self.pending_requests.clone(); + + let handle = tokio::spawn(Self::process_messages( + message_rx, + http_client, + post_endpoint, + sse_url, + pending_requests, + )); + + *self.sse_handle.lock().await = Some(handle); + + Ok(message_tx) } - async fn receive(&self) -> Result { - let mut rx_guard = self.message_rx.lock().await; - let rx = rx_guard.as_mut().ok_or(Error::NotConnected)?; + async fn close(&self) -> Result<(), Error> { + // Abort the SSE handler task + if let Some(handle) = self.sse_handle.lock().await.take() { + handle.abort(); + } - rx.recv().await.ok_or(Error::NotConnected) - } + // Clear any pending requests + self.pending_requests.lock().await.clear(); - async fn close(&self) -> Result<(), Error> { - *self.event_source.lock().await = None; - *self.endpoint.lock().await = None; Ok(()) } } + +impl Drop for SseTransport { + fn drop(&mut self) { + // Create a new runtime for cleanup if needed + let rt = tokio::runtime::Runtime::new().unwrap(); + let _ = rt.block_on(self.close()); + } +} \ No newline at end of file diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index cac9275745fb..ef8bf86d6ba4 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -1,101 +1,161 @@ -use super::{Error, Transport}; +use std::io::{BufRead, BufReader, Write}; +use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio}; +use std::sync::Arc; + use async_trait::async_trait; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::process::{Child, ChildStdin, ChildStdout, Command}; -use tokio::sync::Mutex; +use mcp_core::protocol::JsonRpcMessage; +use tokio::io::AsyncWriteExt; +use tokio::sync::{mpsc, Mutex}; +use tokio::task::JoinHandle; + +use super::{Error, Transport, TransportMessage}; -/// A `StdioTransport` uses a child process’s stdin/stdout as a communication channel. +/// A `StdioTransport` uses a child process's stdin/stdout as a communication channel. /// -/// It starts the specified command with arguments and uses its stdin/stdout to send/receive -/// JSON-RPC messages line by line. This is useful for running MCP servers as subprocesses. +/// It uses channels for message passing and handles responses asynchronously through a background task. pub struct StdioTransport { command: String, args: Vec, - child: Mutex>, - stdin: Mutex>, - stdout: Mutex>>, + process: Arc>>, + stdin: Arc>>, + reader_handle: Arc>>>, + pending_requests: Arc>>>>, } impl StdioTransport { /// Create a new `StdioTransport` configured to run the given command with arguments. - /// - /// The transport will not start until `start()` is called. - pub fn new(command: S, args: I) -> Self - where - S: Into, - I: IntoIterator, - { + pub fn new>(command: S, args: Vec) -> Self { Self { command: command.into(), - args: args.into_iter().map(Into::into).collect(), - child: Mutex::new(None), - stdin: Mutex::new(None), - stdout: Mutex::new(None), + args, + process: Arc::new(Mutex::new(None)), + stdin: Arc::new(Mutex::new(None)), + reader_handle: Arc::new(Mutex::new(None)), + pending_requests: Arc::new(Mutex::new(std::collections::HashMap::new())), } } -} -#[async_trait] -impl Transport for StdioTransport { - async fn start(&self) -> Result<(), Error> { - if self.child.lock().await.is_some() { - return Ok(()); // Already started + async fn spawn_process(&self) -> Result<(ChildStdin, ChildStdout), Error> { + let mut child = Command::new(&self.command) + .args(&self.args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) + .spawn()?; + + let stdin = child.stdin.take().ok_or(Error::Other("Failed to get stdin".into()))?; + let stdout = child.stdout.take().ok_or(Error::Other("Failed to get stdout".into()))?; + + *self.process.lock().await = Some(child); + + Ok((stdin, stdout)) + } + + async fn handle_message( + message: JsonRpcMessage, + pending_requests: Arc>>>>, + ) { + if let JsonRpcMessage::Response(response) = &message { + if let Some(tx) = pending_requests.lock().await.remove(&response.id) { + let _ = tx.send(Ok(message)); + } } + } - let mut cmd = Command::new(&self.command); - cmd.args(&self.args) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()); + async fn process_messages( + mut message_rx: mpsc::Receiver, + mut stdin: ChildStdin, + stdout: ChildStdout, + pending_requests: Arc>>>>, + ) { + // Spawn stdout reader task + let pending_clone = pending_requests.clone(); + let reader_handle = tokio::spawn(async move { + let mut reader = BufReader::new(stdout); + let mut line = String::new(); - let mut child = cmd.spawn()?; + loop { + line.clear(); + match reader.read_line(&mut line) { + Ok(0) => break, // EOF + Ok(_) => { + if let Ok(message) = serde_json::from_str::(&line) { + Self::handle_message(message, pending_clone.clone()).await; + } + } + Err(_) => break, + } + } + }); - let stdin = child.stdin.take().ok_or(Error::NotConnected)?; - let stdout = child.stdout.take().ok_or(Error::NotConnected)?; + // Process incoming messages + while let Some(transport_msg) = message_rx.recv().await { + let message_str = serde_json::to_string(&transport_msg.message) + .map_err(|e| Error::Other(format!("Serialization error: {}", e))) + .unwrap_or_default(); - *self.stdin.lock().await = Some(stdin); - *self.stdout.lock().await = Some(BufReader::new(stdout)); - *self.child.lock().await = Some(child); + // Store response channel if this is a request + if let Some(response_tx) = transport_msg.response_tx { + if let JsonRpcMessage::Request(request) = &transport_msg.message { + pending_requests.lock().await.insert(request.id.clone(), response_tx); + } + } - Ok(()) - } + // Write message to stdin + if let Err(e) = stdin.write_all(format!("{}\n", message_str).as_bytes()).await { + eprintln!("Failed to write to stdin: {}", e); + break; + } + } - async fn send(&self, msg: String) -> Result<(), Error> { - let mut stdin = self.stdin.lock().await; - let stdin = stdin.as_mut().ok_or(Error::NotConnected)?; - // Write the message followed by a newline - stdin.write_all(msg.as_bytes()).await?; - stdin.write_all(b"\n").await?; - stdin.flush().await?; - Ok(()) + // Clean up + reader_handle.abort(); } +} - async fn receive(&self) -> Result { - let mut stdout = self.stdout.lock().await; - let stdout = stdout.as_mut().ok_or(Error::NotConnected)?; - let mut line = String::new(); - let n = stdout.read_line(&mut line).await?; - if n == 0 { - // End of stream - return Err(Error::NotConnected); - } - Ok(line) +#[async_trait] +impl Transport for StdioTransport { + async fn start(&self) -> Result, Error> { + let (stdin, stdout) = self.spawn_process().await?; + *self.stdin.lock().await = Some(stdin.try_clone()?); + + let (message_tx, message_rx) = mpsc::channel(32); + + let pending_requests = self.pending_requests.clone(); + let handle = tokio::spawn(Self::process_messages( + message_rx, + stdin, + stdout, + pending_requests, + )); + + *self.reader_handle.lock().await = Some(handle); + + Ok(message_tx) } async fn close(&self) -> Result<(), Error> { - let mut child = self.child.lock().await; - let mut stdin = self.stdin.lock().await; - let mut stdout = self.stdout.lock().await; - - // Drop stdin to signal EOF - *stdin = None; - *stdout = None; + // Kill the process + if let Some(mut process) = self.process.lock().await.take() { + let _ = process.kill(); + } - if let Some(mut c) = child.take() { - // Wait for child to exit - let _status = c.wait().await?; + // Abort the reader task + if let Some(handle) = self.reader_handle.lock().await.take() { + handle.abort(); } + // Clear any pending requests + self.pending_requests.lock().await.clear(); + Ok(()) } } + +impl Drop for StdioTransport { + fn drop(&mut self) { + // Create a new runtime for cleanup if needed + let rt = tokio::runtime::Runtime::new().unwrap(); + let _ = rt.block_on(self.close()); + } +} \ No newline at end of file From 689f6fba97855606de3616510dbf0a7d77cfded2 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 13:26:09 -0500 Subject: [PATCH 13/21] working but panicking: commented out sse; stdio example working but panics at the end --- crates/mcp-client/Cargo.toml | 5 +- crates/mcp-client/examples/clients.rs | 179 +++++------ crates/mcp-client/examples/sse.rs | 133 ++++---- crates/mcp-client/examples/stdio.rs | 37 +-- crates/mcp-client/src/service.rs | 109 ++++--- crates/mcp-client/src/transport/mod.rs | 6 +- crates/mcp-client/src/transport/sse.rs | 368 +++++++++++------------ crates/mcp-client/src/transport/stdio.rs | 90 ++++-- 8 files changed, 494 insertions(+), 433 deletions(-) diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index a666e6522b53..9a330c6e729b 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" mcp-core = { path = "../mcp-core" } tokio = { version = "1", features = ["full"] } reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "rustls-tls"] } -reqwest-eventsource = "0.5.0" +eventsource-client = "0.12.0" futures-util = "0.3" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -21,7 +21,8 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } tokio-retry = "0.3" tower = { version = "0.4", features = ["timeout", "util"] } tower-service = "0.3" +tokio-util = { version = "0.7", features = ["io-util", "io"] } [dev-dependencies] warp = "0.3" -async-stream = "0.3" +async-stream = "0.3" \ No newline at end of file diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index df8fac9c526e..c3bec5fc2a54 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -1,101 +1,106 @@ -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::Mutex; +// TODO: Remove this +fn main() { + println!("Hello World!"); +} -use mcp_client::{ - client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}, - service::{ServiceError, TransportService}, - transport::{SseTransport, StdioTransport}, -}; -use tower::timeout::TimeoutLayer; -use tower::{ServiceBuilder, ServiceExt}; -use tracing_subscriber::EnvFilter; +// use std::sync::Arc; +// use std::time::Duration; +// use tokio::sync::Mutex; -#[tokio::main] -async fn main() -> Result<(), Box> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("reqwest_eventsource=debug".parse().unwrap()), - ) - .init(); +// use mcp_client::{ +// client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}, +// service::{ServiceError, TransportService}, +// transport::{SseTransport, StdioTransport}, +// }; +// use tower::timeout::TimeoutLayer; +// use tower::{ServiceBuilder, ServiceExt}; +// use tracing_subscriber::EnvFilter; - // Create two separate clients with stdio transport - let client1 = create_client("client1", "1.0.0")?; - let client2 = create_client("client2", "1.0.0")?; - let client3 = create_sse_client("client3", "1.0.0")?; +// #[tokio::main] +// async fn main() -> Result<(), Box> { +// // Initialize logging +// tracing_subscriber::fmt() +// .with_env_filter( +// EnvFilter::from_default_env() +// .add_directive("mcp_client=debug".parse().unwrap()) +// .add_directive("reqwest_eventsource=debug".parse().unwrap()), +// ) +// .init(); - // Initialize both clients - let mut clients: Vec> = Vec::new(); - clients.push(client1); - clients.push(client2); - clients.push(client3); +// // Create two separate clients with stdio transport +// let client1 = create_client("client1", "1.0.0")?; +// let client2 = create_client("client2", "1.0.0")?; +// let client3 = create_sse_client("client3", "1.0.0")?; - // Initialize all clients - for (i, client) in clients.iter_mut().enumerate() { - let info = ClientInfo { - name: format!("example-client-{}", i + 1), - version: "1.0.0".to_string(), - }; - let capabilities = ClientCapabilities::default(); +// // Initialize both clients +// let mut clients: Vec> = Vec::new(); +// clients.push(client1); +// clients.push(client2); +// clients.push(client3); - println!("\nInitializing client {}", i + 1); - let init_result = client.initialize(info, capabilities).await?; - println!("Client {} initialized: {:?}", i + 1, init_result); - } +// // Initialize all clients +// for (i, client) in clients.iter_mut().enumerate() { +// let info = ClientInfo { +// name: format!("example-client-{}", i + 1), +// version: "1.0.0".to_string(), +// }; +// let capabilities = ClientCapabilities::default(); - // List tools for each client - for (i, client) in clients.iter_mut().enumerate() { - let tools = client.list_tools().await?; - println!("\nClient {} tools: {:?}", i + 1, tools); - } +// println!("\nInitializing client {}", i + 1); +// let init_result = client.initialize(info, capabilities).await?; +// println!("Client {} initialized: {:?}", i + 1, init_result); +// } - Ok(()) -} +// // List tools for each client +// for (i, client) in clients.iter_mut().enumerate() { +// let tools = client.list_tools().await?; +// println!("\nClient {} tools: {:?}", i + 1, tools); +// } -fn create_client( - _name: &str, - _version: &str, -) -> Result, Box> { - // Create the transport - let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); +// Ok(()) +// } - // Build service with middleware including timeout - let service = ServiceBuilder::new() - .layer(TimeoutLayer::new(Duration::from_secs(30))) - .service(TransportService::new(Arc::clone(&transport))) - .map_err(|e: Box| { - if e.is::() { - ServiceError::Timeout(tower::timeout::error::Elapsed::new()) - } else { - ServiceError::Other(e.to_string()) - } - }); +// fn create_client( +// _name: &str, +// _version: &str, +// ) -> Result, Box> { +// // Create the transport +// let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); - Ok(Box::new(McpClientImpl::new(service))) -} +// // Build service with middleware including timeout +// let service = ServiceBuilder::new() +// .layer(TimeoutLayer::new(Duration::from_secs(30))) +// .service(TransportService::new(Arc::clone(&transport))) +// .map_err(|e: Box| { +// if e.is::() { +// ServiceError::Timeout(tower::timeout::error::Elapsed::new()) +// } else { +// ServiceError::Other(e.to_string()) +// } +// }); -fn create_sse_client( - _name: &str, - _version: &str, -) -> Result, Box> { - let transport = Arc::new(Mutex::new( - SseTransport::new("http://localhost:8000/sse").unwrap(), - )); +// Ok(Box::new(McpClientImpl::new(service))) +// } - // Build service with middleware including timeout - let service = ServiceBuilder::new() - .layer(TimeoutLayer::new(Duration::from_secs(30))) - .service(TransportService::new(Arc::clone(&transport))) - .map_err(|e: Box| { - if e.is::() { - ServiceError::Timeout(tower::timeout::error::Elapsed::new()) - } else { - ServiceError::Other(e.to_string()) - } - }); +// fn create_sse_client( +// _name: &str, +// _version: &str, +// ) -> Result, Box> { +// let transport = Arc::new(Mutex::new( +// SseTransport::new("http://localhost:8000/sse").unwrap(), +// )); - Ok(Box::new(McpClientImpl::new(service))) -} +// // Build service with middleware including timeout +// let service = ServiceBuilder::new() +// .layer(TimeoutLayer::new(Duration::from_secs(30))) +// .service(TransportService::new(Arc::clone(&transport))) +// .map_err(|e: Box| { +// if e.is::() { +// ServiceError::Timeout(tower::timeout::error::Elapsed::new()) +// } else { +// ServiceError::Other(e.to_string()) +// } +// }); + +// Ok(Box::new(McpClientImpl::new(service))) +// } diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index 3d7e570d4415..5b6487a7a2fd 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -1,73 +1,78 @@ -use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}; -use mcp_client::{ - service::{ServiceError, TransportService}, - transport::SseTransport, -}; -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::Mutex; -use tower::timeout::TimeoutLayer; -use tower::{ServiceBuilder, ServiceExt}; -use tracing_subscriber::EnvFilter; +// TODO: Remove this +fn main() { + println!("Hello World!"); +} -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("reqwest_eventsource=debug".parse().unwrap()), - ) - .init(); +// use anyhow::Result; +// use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}; +// use mcp_client::{ +// service::{ServiceError, TransportService}, +// transport::SseTransport, +// }; +// use std::sync::Arc; +// use std::time::Duration; +// use tokio::sync::Mutex; +// use tower::timeout::TimeoutLayer; +// use tower::{ServiceBuilder, ServiceExt}; +// use tracing_subscriber::EnvFilter; - // Create the base transport as Arc> - let transport = Arc::new(Mutex::new(SseTransport::new("http://localhost:8000/sse")?)); +// #[tokio::main] +// async fn main() -> Result<()> { +// // Initialize logging +// tracing_subscriber::fmt() +// .with_env_filter( +// EnvFilter::from_default_env() +// .add_directive("mcp_client=debug".parse().unwrap()) +// .add_directive("reqwest_eventsource=debug".parse().unwrap()), +// ) +// .init(); - // Build service with middleware including timeout - let service = ServiceBuilder::new() - .layer(TimeoutLayer::new(Duration::from_secs(30))) - .service(TransportService::new(Arc::clone(&transport))) - .map_err(|e: Box| { - if e.is::() { - ServiceError::Timeout(tower::timeout::error::Elapsed::new()) - } else { - ServiceError::Other(e.to_string()) - } - }); +// // Create the base transport as Arc> +// let transport = Arc::new(Mutex::new(SseTransport::new("http://localhost:8000/sse")?)); - // Create client - let mut client = McpClientImpl::new(service); - println!("Client created\n"); +// // Build service with middleware including timeout +// let service = ServiceBuilder::new() +// .layer(TimeoutLayer::new(Duration::from_secs(30))) +// .service(TransportService::new(Arc::clone(&transport))) +// .map_err(|e: Box| { +// if e.is::() { +// ServiceError::Timeout(tower::timeout::error::Elapsed::new()) +// } else { +// ServiceError::Other(e.to_string()) +// } +// }); - // Initialize - let server_info = client - .initialize( - ClientInfo { - name: "test-client".into(), - version: "1.0.0".into(), - }, - ClientCapabilities::default(), - ) - .await?; - println!("Connected to server: {server_info:?}\n"); +// // Create client +// let mut client = McpClientImpl::new(service); +// println!("Client created\n"); - // Sleep for 100ms to allow the server to start - surprisingly this is required! - tokio::time::sleep(Duration::from_millis(100)).await; +// // Initialize +// let server_info = client +// .initialize( +// ClientInfo { +// name: "test-client".into(), +// version: "1.0.0".into(), +// }, +// ClientCapabilities::default(), +// ) +// .await?; +// println!("Connected to server: {server_info:?}\n"); - // List tools - let tools = client.list_tools().await?; - println!("Available tools: {tools:?}\n"); +// // Sleep for 100ms to allow the server to start - surprisingly this is required! +// tokio::time::sleep(Duration::from_millis(100)).await; - // Call tool - let tool_result = client - .call_tool( - "echo_tool", - serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), - ) - .await?; - println!("Tool result: {tool_result:?}"); +// // List tools +// let tools = client.list_tools().await?; +// println!("Available tools: {tools:?}\n"); - Ok(()) -} +// // Call tool +// let tool_result = client +// .call_tool( +// "echo_tool", +// serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), +// ) +// .await?; +// println!("Tool result: {tool_result:?}"); + +// Ok(()) +// } diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 7300512ddafb..57356ff0301d 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -1,16 +1,8 @@ use anyhow::Result; -use mcp_client::client::{ - ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientImpl, -}; -use mcp_client::{ - service::{ServiceError, TransportService}, - transport::StdioTransport, -}; -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::Mutex; -use tower::timeout::TimeoutLayer; -use tower::{ServiceBuilder, ServiceExt}; +use mcp_client::client::McpClient; +use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClientImpl}; +use mcp_client::{service::TransportService, transport::StdioTransport}; +use tower::ServiceBuilder; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -24,20 +16,11 @@ async fn main() -> Result<(), ClientError> { ) .init(); - // Create the base transport as Arc> - let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); + // Create the transport + let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()]); // Build service with middleware including timeout - let service = ServiceBuilder::new() - .layer(TimeoutLayer::new(Duration::from_secs(30))) - .service(TransportService::new(Arc::clone(&transport))) - .map_err(|e: Box| { - if e.is::() { - ServiceError::Timeout(tower::timeout::error::Elapsed::new()) - } else { - ServiceError::Other(e.to_string()) - } - }); + let service = ServiceBuilder::new().service(TransportService::new(transport)); // Create client let mut client = McpClientImpl::new(service); @@ -58,11 +41,13 @@ async fn main() -> Result<(), ClientError> { let tools = client.list_tools().await?; println!("Available tools: {tools:?}\n"); - // Call tool 'git_status' wtih arguments = {"repo_path": "."} + // Call tool 'git_status' with arguments = {"repo_path": "."} let tool_result = client .call_tool("git_status", serde_json::json!({ "repo_path": "." })) .await?; - println!("Tool result: {tool_result:?}"); + println!("Tool result: {tool_result:?}\n"); + + println!("Finishing up, will cleanup resources as we go out of scope\n"); Ok(()) } diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index cac16e3e17d0..d78f2d980192 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -3,10 +3,10 @@ use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::sync::{mpsc, Mutex, oneshot}; +use tokio::sync::{mpsc, Mutex}; use tower::Service; -use crate::transport::{Error as TransportError, Transport, TransportMessage, MessageRouter}; +use crate::transport::{Error as TransportError, MessageRouter, Transport}; use mcp_core::protocol::JsonRpcMessage; #[derive(Debug, thiserror::Error)] @@ -20,6 +20,12 @@ pub enum ServiceError { #[error("Request timed out")] Timeout(#[from] tower::timeout::error::Elapsed), + #[error("Transport not initialized")] + NotInitialized, + + #[error("Transport already initialized")] + AlreadyInitialized, + #[error("Other error: {0}")] Other(String), @@ -27,43 +33,70 @@ pub enum ServiceError { UnexpectedResponse, } -/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcMessages. -pub struct TransportService { +struct TransportServiceInner { transport: Arc, - router: Arc>>, + router: Mutex>, initialized: AtomicBool, } -impl TransportService { - pub fn new(transport: T) -> Self { - Self { - transport: Arc::new(transport), - router: Arc::new(Mutex::new(None)), - initialized: AtomicBool::new(false), - } - } - - async fn ensure_initialized(&self) -> Result, ServiceError> { +impl TransportServiceInner { + async fn ensure_initialized(&self) -> Result { if !self.initialized.load(Ordering::SeqCst) { let mut router_guard = self.router.lock().await; - + // Double-check after acquiring lock if !self.initialized.load(Ordering::SeqCst) { // Start the transport - let transport_tx = self.transport.start().await?; - + let transport_tx = self + .transport + .start() + .await + .map_err(ServiceError::Transport)?; + // Create shutdown channel let (shutdown_tx, _shutdown_rx) = mpsc::channel(1); - + // Create and store the router let router = MessageRouter::new(transport_tx, shutdown_tx); *router_guard = Some(router); - + self.initialized.store(true, Ordering::SeqCst); } } - - Ok(Arc::new(self.router.lock().await.as_ref().unwrap().clone())) + + // Get a clone of the router + Ok(self + .router + .lock() + .await + .as_ref() + .ok_or(ServiceError::NotInitialized)? + .clone()) + } +} + +/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcMessages. +pub struct TransportService { + inner: Arc>, +} + +impl TransportService { + pub fn new(transport: T) -> Self { + Self { + inner: Arc::new(TransportServiceInner { + transport: Arc::new(transport), + router: Mutex::new(None), + initialized: AtomicBool::new(false), + }), + } + } +} + +impl Clone for TransportService { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + } } } @@ -73,48 +106,48 @@ impl Service for TransportService { type Future = Pin> + Send>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - // Always ready. We do on-demand initialization in call(). + // Always ready since we do lazy initialization in call() Poll::Ready(Ok(())) } fn call(&mut self, message: JsonRpcMessage) -> Self::Future { - let transport = Arc::clone(&self.transport); - let router = Arc::clone(&self.router); + let inner = Arc::clone(&self.inner); Box::pin(async move { - let router = match router.lock().await.as_ref() { - Some(router) => router.clone(), - None => return Err(ServiceError::Other("Transport not initialized".to_string())), - }; + // Ensure transport is initialized + let router = inner.ensure_initialized().await?; match message { JsonRpcMessage::Notification(notification) => { - router.send_notification(JsonRpcMessage::Notification(notification)).await?; + router + .send_notification(JsonRpcMessage::Notification(notification)) + .await + .map_err(ServiceError::Transport)?; Ok(JsonRpcMessage::Nil) } - JsonRpcMessage::Request(request) => { - router.send_request(JsonRpcMessage::Request(request)).await - .map_err(|e| ServiceError::Transport(e)) - } + JsonRpcMessage::Request(request) => router + .send_request(JsonRpcMessage::Request(request)) + .await + .map_err(ServiceError::Transport), _ => Err(ServiceError::Other("Invalid message type".to_string())), } }) } } -impl Drop for TransportService { +impl Drop for TransportServiceInner { fn drop(&mut self) { if self.initialized.load(Ordering::SeqCst) { // Create a new runtime for cleanup if needed let rt = tokio::runtime::Runtime::new().unwrap(); - + // Request shutdown through the router if it exists if let Some(router) = rt.block_on(self.router.lock()).as_ref() { let _ = rt.block_on(router.shutdown()); } - + // Close the transport let _ = rt.block_on(self.transport.close()); } } -} \ No newline at end of file +} diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 422329b6f5e2..ea045e7f1672 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,6 +1,5 @@ use async_trait::async_trait; use mcp_core::protocol::JsonRpcMessage; -use std::sync::Arc; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; @@ -52,10 +51,11 @@ pub trait Transport: Send + Sync + 'static { pub mod stdio; pub use stdio::StdioTransport; -pub mod sse; -pub use sse::SseTransport; +// pub mod sse; +// pub use sse::SseTransport; /// A router that handles message distribution for a transport +#[derive(Clone)] pub struct MessageRouter { transport_tx: mpsc::Sender, shutdown_tx: mpsc::Sender<()>, diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 072927600caf..f39daab52ad1 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -1,185 +1,183 @@ -use std::sync::Arc; -use async_trait::async_trait; -use futures_util::StreamExt; -use reqwest::{Client as HttpClient, Url}; -use tokio::sync::{mpsc, Mutex, oneshot}; -use tokio::task::JoinHandle; -use eventsource_client::{Client as EventSourceClient, SSE}; - -use super::{Error, Transport, TransportMessage}; -use mcp_core::protocol::JsonRpcMessage; - -/// A transport implementation that uses Server-Sent Events (SSE) for receiving messages -/// and HTTP POST for sending messages. -pub struct SseTransport { - sse_url: String, - http_client: HttpClient, - post_endpoint: Arc>>, - sse_handle: Arc>>>, - pending_requests: Arc>>>>, -} - -impl SseTransport { - /// Create a new SSE transport with the given SSE endpoint URL - pub fn new>(sse_url: S) -> Self { - Self { - sse_url: sse_url.into(), - http_client: HttpClient::new(), - post_endpoint: Arc::new(Mutex::new(None)), - sse_handle: Arc::new(Mutex::new(None)), - pending_requests: Arc::new(Mutex::new(std::collections::HashMap::new())), - } - } - - async fn handle_message( - message: JsonRpcMessage, - pending_requests: Arc>>>>, - ) { - if let JsonRpcMessage::Response(response) = &message { - if let Some(tx) = pending_requests.lock().await.remove(&response.id) { - let _ = tx.send(Ok(message)); - } - } - } - - async fn process_messages( - mut message_rx: mpsc::Receiver, - http_client: HttpClient, - post_endpoint: Arc>>, - sse_url: String, - pending_requests: Arc>>>>, - ) { - // Set up SSE client - let client = match EventSourceClient::new(&sse_url) { - Ok(client) => client, - Err(e) => { - eprintln!("Failed to create SSE client: {}", e); - return; - } - }; - - let mut stream = client.stream(); - - // Wait for endpoint event to get POST URL - while let Some(event) = stream.next().await { - match event { - Ok(SSE::Event(event)) if event.event_type == "endpoint" => { - if let Some(data) = event.data { - *post_endpoint.lock().await = Some(data); - break; - } - } - Ok(_) => continue, - Err(e) => { - eprintln!("SSE connection error: {}", e); - return; - } - } - } - - // Spawn SSE message handler - let pending_clone = pending_requests.clone(); - let sse_handle = tokio::spawn(async move { - while let Some(event) = stream.next().await { - match event { - Ok(SSE::Event(event)) if event.event_type == "message" => { - if let Some(data) = event.data { - if let Ok(message) = serde_json::from_str::(&data) { - Self::handle_message(message, pending_clone.clone()).await; - } - } - } - Ok(_) => continue, - Err(e) => { - eprintln!("SSE message error: {}", e); - break; - } - } - } - }); - - // Process outgoing messages - while let Some(transport_msg) = message_rx.recv().await { - let post_url = match post_endpoint.lock().await.as_ref() { - Some(url) => url.clone(), - None => { - eprintln!("No POST endpoint available"); - continue; - } - }; - - // Store response channel if this is a request - if let Some(response_tx) = transport_msg.response_tx { - if let JsonRpcMessage::Request(request) = &transport_msg.message { - pending_requests.lock().await.insert(request.id.clone(), response_tx); - } - } - - // Send message via HTTP POST - let message_str = match serde_json::to_string(&transport_msg.message) { - Ok(s) => s, - Err(e) => { - eprintln!("Failed to serialize message: {}", e); - continue; - } - }; - - if let Err(e) = http_client - .post(&post_url) - .header("Content-Type", "application/json") - .body(message_str) - .send() - .await - { - eprintln!("Failed to send message: {}", e); - } - } - - // Clean up - sse_handle.abort(); - } -} - -#[async_trait] -impl Transport for SseTransport { - async fn start(&self) -> Result, Error> { - let (message_tx, message_rx) = mpsc::channel(32); - - let http_client = self.http_client.clone(); - let post_endpoint = self.post_endpoint.clone(); - let sse_url = self.sse_url.clone(); - let pending_requests = self.pending_requests.clone(); - - let handle = tokio::spawn(Self::process_messages( - message_rx, - http_client, - post_endpoint, - sse_url, - pending_requests, - )); - - *self.sse_handle.lock().await = Some(handle); - - Ok(message_tx) - } - - async fn close(&self) -> Result<(), Error> { - // Abort the SSE handler task - if let Some(handle) = self.sse_handle.lock().await.take() { - handle.abort(); - } - - // Clear any pending requests - self.pending_requests.lock().await.clear(); - - Ok(()) - } -} - -impl Drop for SseTransport { - fn drop(&mut self) { - // Create a new runtime for cleanup if needed - let rt = tokio::runtime::Runtime::new().unwrap(); - let _ = rt.block_on(self.close()); - } -} \ No newline at end of file +// use std::sync::Arc; +// use async_trait::async_trait; +// use reqwest::Client as HttpClient; +// use tokio::sync::{mpsc, Mutex, oneshot}; +// use tokio::task::JoinHandle; +// use eventsource_client::SSE; + +// use super::{Error, Transport, TransportMessage}; +// use mcp_core::protocol::JsonRpcMessage; + +// /// A transport implementation that uses Server-Sent Events (SSE) for receiving messages +// /// and HTTP POST for sending messages. +// pub struct SseTransport { +// sse_url: String, +// http_client: HttpClient, +// post_endpoint: Arc>>, +// sse_handle: Arc>>>, +// pending_requests: Arc>>>>, +// } + +// impl SseTransport { +// /// Create a new SSE transport with the given SSE endpoint URL +// pub fn new>(sse_url: S) -> Self { +// Self { +// sse_url: sse_url.into(), +// http_client: HttpClient::new(), +// post_endpoint: Arc::new(Mutex::new(None)), +// sse_handle: Arc::new(Mutex::new(None)), +// pending_requests: Arc::new(Mutex::new(std::collections::HashMap::new())), +// } +// } + +// async fn handle_message( +// message: JsonRpcMessage, +// pending_requests: Arc>>>>, +// ) { +// if let JsonRpcMessage::Response(response) = &message { +// if let Some(id) = &response.id { +// if let Some(tx) = pending_requests.lock().await.remove(&id.to_string()) { +// let _ = tx.send(Ok(message)); +// } +// } +// } +// } + +// async fn process_messages( +// mut message_rx: mpsc::Receiver, +// http_client: HttpClient, +// post_endpoint: Arc>>, +// sse_url: String, +// pending_requests: Arc>>>>, +// ) { +// // Set up SSE client +// let mut client = match eventsource_client::ClientBuilder::for_url(&sse_url) { +// Ok(builder) => builder.build(), +// Err(e) => { +// eprintln!("Failed to create SSE client: {}", e); +// return; +// } +// }; + +// // Wait for endpoint event to get POST URL +// while let Some(event) = client.next().await { +// match event { +// Ok(SSE::Event(event)) if event.event_type == "endpoint" => { +// *post_endpoint.lock().await = Some(event.data); +// break; +// } +// Ok(_) => continue, +// Err(e) => { +// eprintln!("SSE connection error: {}", e); +// return; +// } +// } +// } + +// // Spawn SSE message handler +// let pending_clone = pending_requests.clone(); +// let mut client_clone = client.clone(); +// let sse_handle = tokio::spawn(async move { +// while let Some(event) = client_clone.next().await { +// match event { +// Ok(SSE::Event(event)) if event.event_type == "message" => { +// if let Ok(message) = serde_json::from_str::(&event.data) { +// Self::handle_message(message, pending_clone.clone()).await; +// } +// } +// Ok(_) => continue, +// Err(e) => { +// eprintln!("SSE message error: {}", e); +// break; +// } +// } +// } +// }); + +// // Process outgoing messages +// while let Some(transport_msg) = message_rx.recv().await { +// let post_url = match post_endpoint.lock().await.as_ref() { +// Some(url) => url.clone(), +// None => { +// eprintln!("No POST endpoint available"); +// continue; +// } +// }; + +// // Store response channel if this is a request +// if let Some(response_tx) = transport_msg.response_tx { +// if let JsonRpcMessage::Request(request) = &transport_msg.message { +// if let Some(id) = &request.id { +// pending_requests.lock().await.insert(id.to_string(), response_tx); +// } +// } +// } + +// // Send message via HTTP POST +// let message_str = match serde_json::to_string(&transport_msg.message) { +// Ok(s) => s, +// Err(e) => { +// eprintln!("Failed to serialize message: {}", e); +// continue; +// } +// }; + +// if let Err(e) = http_client +// .post(&post_url) +// .header("Content-Type", "application/json") +// .body(message_str) +// .send() +// .await +// { +// eprintln!("Failed to send message: {}", e); +// } +// } + +// // Clean up +// sse_handle.abort(); +// } +// } + +// #[async_trait] +// impl Transport for SseTransport { +// async fn start(&self) -> Result, Error> { +// let (message_tx, message_rx) = mpsc::channel(32); + +// let http_client = self.http_client.clone(); +// let post_endpoint = self.post_endpoint.clone(); +// let sse_url = self.sse_url.clone(); +// let pending_requests = self.pending_requests.clone(); + +// let handle = tokio::spawn(Self::process_messages( +// message_rx, +// http_client, +// post_endpoint, +// sse_url, +// pending_requests, +// )); + +// *self.sse_handle.lock().await = Some(handle); + +// Ok(message_tx) +// } + +// async fn close(&self) -> Result<(), Error> { +// // Abort the SSE handler task +// if let Some(handle) = self.sse_handle.lock().await.take() { +// handle.abort(); +// } + +// // Clear any pending requests +// self.pending_requests.lock().await.clear(); + +// Ok(()) +// } +// } + +// impl Drop for SseTransport { +// fn drop(&mut self) { +// // Create a new runtime for cleanup if needed +// let rt = tokio::runtime::Runtime::new().unwrap(); +// let _ = rt.block_on(self.close()); +// } +// } diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index ef8bf86d6ba4..90f576fce7fd 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -1,11 +1,10 @@ -use std::io::{BufRead, BufReader, Write}; -use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio}; use std::sync::Arc; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; use async_trait::async_trait; use mcp_core::protocol::JsonRpcMessage; -use tokio::io::AsyncWriteExt; -use tokio::sync::{mpsc, Mutex}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::sync::{mpsc, oneshot, Mutex}; use tokio::task::JoinHandle; use super::{Error, Transport, TransportMessage}; @@ -17,9 +16,10 @@ pub struct StdioTransport { command: String, args: Vec, process: Arc>>, - stdin: Arc>>, reader_handle: Arc>>>, - pending_requests: Arc>>>>, + pending_requests: Arc< + Mutex>>>, + >, } impl StdioTransport { @@ -29,7 +29,6 @@ impl StdioTransport { command: command.into(), args, process: Arc::new(Mutex::new(None)), - stdin: Arc::new(Mutex::new(None)), reader_handle: Arc::new(Mutex::new(None)), pending_requests: Arc::new(Mutex::new(std::collections::HashMap::new())), } @@ -38,13 +37,19 @@ impl StdioTransport { async fn spawn_process(&self) -> Result<(ChildStdin, ChildStdout), Error> { let mut child = Command::new(&self.command) .args(&self.args) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::inherit()) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) .spawn()?; - let stdin = child.stdin.take().ok_or(Error::Other("Failed to get stdin".into()))?; - let stdout = child.stdout.take().ok_or(Error::Other("Failed to get stdout".into()))?; + let stdin = child + .stdin + .take() + .ok_or(Error::Other("Failed to get stdin".into()))?; + let stdout = child + .stdout + .take() + .ok_or(Error::Other("Failed to get stdout".into()))?; *self.process.lock().await = Some(child); @@ -53,11 +58,17 @@ impl StdioTransport { async fn handle_message( message: JsonRpcMessage, - pending_requests: Arc>>>>, + pending_requests: Arc< + Mutex< + std::collections::HashMap>>, + >, + >, ) { if let JsonRpcMessage::Response(response) = &message { - if let Some(tx) = pending_requests.lock().await.remove(&response.id) { - let _ = tx.send(Ok(message)); + if let Some(id) = &response.id { + if let Some(tx) = pending_requests.lock().await.remove(&id.to_string()) { + let _ = tx.send(Ok(message)); + } } } } @@ -66,46 +77,70 @@ impl StdioTransport { mut message_rx: mpsc::Receiver, mut stdin: ChildStdin, stdout: ChildStdout, - pending_requests: Arc>>>>, + pending_requests: Arc< + Mutex< + std::collections::HashMap>>, + >, + >, ) { + // Set up async reader for stdout + let mut reader = BufReader::new(stdout); + // Spawn stdout reader task let pending_clone = pending_requests.clone(); let reader_handle = tokio::spawn(async move { - let mut reader = BufReader::new(stdout); let mut line = String::new(); - loop { line.clear(); - match reader.read_line(&mut line) { + match reader.read_line(&mut line).await { Ok(0) => break, // EOF Ok(_) => { if let Ok(message) = serde_json::from_str::(&line) { Self::handle_message(message, pending_clone.clone()).await; } } - Err(_) => break, + Err(e) => { + eprintln!("Error reading line: {}", e); + break; + } } } }); // Process incoming messages while let Some(transport_msg) = message_rx.recv().await { - let message_str = serde_json::to_string(&transport_msg.message) - .map_err(|e| Error::Other(format!("Serialization error: {}", e))) - .unwrap_or_default(); + let message_str = match serde_json::to_string(&transport_msg.message) { + Ok(s) => s, + Err(e) => { + eprintln!("Failed to serialize message: {}", e); + continue; + } + }; // Store response channel if this is a request if let Some(response_tx) = transport_msg.response_tx { if let JsonRpcMessage::Request(request) = &transport_msg.message { - pending_requests.lock().await.insert(request.id.clone(), response_tx); + if let Some(id) = &request.id { + pending_requests + .lock() + .await + .insert(id.to_string(), response_tx); + } } } // Write message to stdin - if let Err(e) = stdin.write_all(format!("{}\n", message_str).as_bytes()).await { + if let Err(e) = stdin + .write_all(format!("{}\n", message_str).as_bytes()) + .await + { eprintln!("Failed to write to stdin: {}", e); break; } + if let Err(e) = stdin.flush().await { + eprintln!("Failed to flush stdin: {}", e); + break; + } } // Clean up @@ -117,10 +152,9 @@ impl StdioTransport { impl Transport for StdioTransport { async fn start(&self) -> Result, Error> { let (stdin, stdout) = self.spawn_process().await?; - *self.stdin.lock().await = Some(stdin.try_clone()?); let (message_tx, message_rx) = mpsc::channel(32); - + let pending_requests = self.pending_requests.clone(); let handle = tokio::spawn(Self::process_messages( message_rx, @@ -158,4 +192,4 @@ impl Drop for StdioTransport { let rt = tokio::runtime::Runtime::new().unwrap(); let _ = rt.block_on(self.close()); } -} \ No newline at end of file +} From e5268ae049ce1d38ab3212af173f5c11e214d3a7 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 18:54:20 -0500 Subject: [PATCH 14/21] stdio transport working with async requests --- crates/mcp-client/examples/stdio.rs | 7 +++--- crates/mcp-client/src/client.rs | 14 +++++------ crates/mcp-client/src/service.rs | 30 ++++++++++-------------- crates/mcp-client/src/transport/mod.rs | 22 +++++------------ crates/mcp-client/src/transport/sse.rs | 8 ------- crates/mcp-client/src/transport/stdio.rs | 11 +++------ 6 files changed, 31 insertions(+), 61 deletions(-) diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 57356ff0301d..f79082f48678 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -1,6 +1,7 @@ use anyhow::Result; -use mcp_client::client::McpClient; -use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClientImpl}; +use mcp_client::client::{ + ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientImpl, +}; use mcp_client::{service::TransportService, transport::StdioTransport}; use tower::ServiceBuilder; use tracing_subscriber::EnvFilter; @@ -47,7 +48,5 @@ async fn main() -> Result<(), ClientError> { .await?; println!("Tool result: {tool_result:?}\n"); - println!("Finishing up, will cleanup resources as we go out of scope\n"); - Ok(()) } diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 270385f4e4d6..348977052a2f 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -92,8 +92,8 @@ where } } - /// Send a JSON-RPC request and wait for a response. - async fn send_message(&mut self, method: &str, params: Value) -> Result + /// Send a JSON-RPC request and check we don't get an error response. + async fn send_request(&mut self, method: &str, params: Value) -> Result where R: for<'de> Deserialize<'de>, { @@ -182,7 +182,7 @@ where capabilities, }; let result: InitializeResult = self - .send_message("initialize", serde_json::to_value(params)?) + .send_request("initialize", serde_json::to_value(params)?) .await?; self.send_notification("notifications/initialized", serde_json::json!({})) @@ -192,21 +192,21 @@ where } async fn list_resources(&mut self) -> Result { - self.send_message("resources/list", serde_json::json!({})) + self.send_request("resources/list", serde_json::json!({})) .await } async fn read_resource(&mut self, uri: &str) -> Result { let params = serde_json::json!({ "uri": uri }); - self.send_message("resources/read", params).await + self.send_request("resources/read", params).await } async fn list_tools(&mut self) -> Result { - self.send_message("tools/list", serde_json::json!({})).await + self.send_request("tools/list", serde_json::json!({})).await } async fn call_tool(&mut self, name: &str, arguments: Value) -> Result { let params = serde_json::json!({ "name": name, "arguments": arguments }); - self.send_message("tools/call", params).await + self.send_request("tools/call", params).await } } diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index d78f2d980192..7ebe40ee25bf 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -120,13 +120,13 @@ impl Service for TransportService { match message { JsonRpcMessage::Notification(notification) => { router - .send_notification(JsonRpcMessage::Notification(notification)) + .send_notification(notification) .await .map_err(ServiceError::Transport)?; Ok(JsonRpcMessage::Nil) } JsonRpcMessage::Request(request) => router - .send_request(JsonRpcMessage::Request(request)) + .send_request(request) .await .map_err(ServiceError::Transport), _ => Err(ServiceError::Other("Invalid message type".to_string())), @@ -135,19 +135,13 @@ impl Service for TransportService { } } -impl Drop for TransportServiceInner { - fn drop(&mut self) { - if self.initialized.load(Ordering::SeqCst) { - // Create a new runtime for cleanup if needed - let rt = tokio::runtime::Runtime::new().unwrap(); - - // Request shutdown through the router if it exists - if let Some(router) = rt.block_on(self.router.lock()).as_ref() { - let _ = rt.block_on(router.shutdown()); - } - - // Close the transport - let _ = rt.block_on(self.transport.close()); - } - } -} +// https://spec.modelcontextprotocol.io/specification/basic/lifecycle/#shutdown +// impl Drop for TransportServiceInner { +// fn drop(&mut self) { +// if self.initialized.load(Ordering::SeqCst) { +// // Best effort cleanup in sync context +// // We can't create a new runtime here, so we'll just log a warning +// tracing::warn!("TransportService dropped while initialized - resources may leak"); +// } +// } +// } diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index ea045e7f1672..2b4923dea3b5 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use mcp_core::protocol::JsonRpcMessage; +use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification, JsonRpcRequest}; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; @@ -51,13 +51,11 @@ pub trait Transport: Send + Sync + 'static { pub mod stdio; pub use stdio::StdioTransport; -// pub mod sse; -// pub use sse::SseTransport; - /// A router that handles message distribution for a transport #[derive(Clone)] pub struct MessageRouter { transport_tx: mpsc::Sender, + #[allow(dead_code)] shutdown_tx: mpsc::Sender<()>, } @@ -73,12 +71,12 @@ impl MessageRouter { } /// Send a message and wait for a response - pub async fn send_request(&self, request: JsonRpcMessage) -> Result { + pub async fn send_request(&self, request: JsonRpcRequest) -> Result { let (response_tx, response_rx) = oneshot::channel(); self.transport_tx .send(TransportMessage { - message: request, + message: JsonRpcMessage::Request(request), response_tx: Some(response_tx), }) .await @@ -88,21 +86,13 @@ impl MessageRouter { } /// Send a notification (no response expected) - pub async fn send_notification(&self, notification: JsonRpcMessage) -> Result<(), Error> { + pub async fn send_notification(&self, notification: JsonRpcNotification) -> Result<(), Error> { self.transport_tx .send(TransportMessage { - message: notification, + message: JsonRpcMessage::Notification(notification), response_tx: None, }) .await .map_err(|_| Error::ChannelClosed) } - - /// Request shutdown of the transport - pub async fn shutdown(&self) -> Result<(), Error> { - self.shutdown_tx - .send(()) - .await - .map_err(|_| Error::ChannelClosed) - } } diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index f39daab52ad1..1a9264dc1cd6 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -173,11 +173,3 @@ // Ok(()) // } // } - -// impl Drop for SseTransport { -// fn drop(&mut self) { -// // Create a new runtime for cleanup if needed -// let rt = tokio::runtime::Runtime::new().unwrap(); -// let _ = rt.block_on(self.close()); -// } -// } diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 90f576fce7fd..749a952cce1e 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -171,12 +171,13 @@ impl Transport for StdioTransport { async fn close(&self) -> Result<(), Error> { // Kill the process if let Some(mut process) = self.process.lock().await.take() { - let _ = process.kill(); + let _ = process.kill().await; } // Abort the reader task if let Some(handle) = self.reader_handle.lock().await.take() { handle.abort(); + let _ = handle.await; } // Clear any pending requests @@ -186,10 +187,4 @@ impl Transport for StdioTransport { } } -impl Drop for StdioTransport { - fn drop(&mut self) { - // Create a new runtime for cleanup if needed - let rt = tokio::runtime::Runtime::new().unwrap(); - let _ = rt.block_on(self.close()); - } -} +// No Drop implementation needed - we'll handle cleanup in the TransportService From 0db4f80aca5cb119a52ab43a3f266738e2a0817b Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 19:19:22 -0500 Subject: [PATCH 15/21] examples are working, added TODO for timeout middleware service --- crates/mcp-client/Cargo.toml | 3 +- crates/mcp-client/examples/clients.rs | 159 +++++------ crates/mcp-client/examples/sse.rs | 135 ++++----- crates/mcp-client/examples/stdio.rs | 3 +- crates/mcp-client/src/transport/mod.rs | 3 + crates/mcp-client/src/transport/sse.rs | 364 +++++++++++++------------ 6 files changed, 316 insertions(+), 351 deletions(-) diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 9a330c6e729b..221b80e2d844 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -8,6 +8,7 @@ mcp-core = { path = "../mcp-core" } tokio = { version = "1", features = ["full"] } reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "rustls-tls"] } eventsource-client = "0.12.0" +futures = "0.3" futures-util = "0.3" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -25,4 +26,4 @@ tokio-util = { version = "0.7", features = ["io-util", "io"] } [dev-dependencies] warp = "0.3" -async-stream = "0.3" \ No newline at end of file +async-stream = "0.3" diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index c3bec5fc2a54..c91fcdfaae68 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -1,106 +1,71 @@ -// TODO: Remove this -fn main() { - println!("Hello World!"); -} - -// use std::sync::Arc; -// use std::time::Duration; -// use tokio::sync::Mutex; - -// use mcp_client::{ -// client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}, -// service::{ServiceError, TransportService}, -// transport::{SseTransport, StdioTransport}, -// }; -// use tower::timeout::TimeoutLayer; -// use tower::{ServiceBuilder, ServiceExt}; -// use tracing_subscriber::EnvFilter; - -// #[tokio::main] -// async fn main() -> Result<(), Box> { -// // Initialize logging -// tracing_subscriber::fmt() -// .with_env_filter( -// EnvFilter::from_default_env() -// .add_directive("mcp_client=debug".parse().unwrap()) -// .add_directive("reqwest_eventsource=debug".parse().unwrap()), -// ) -// .init(); - -// // Create two separate clients with stdio transport -// let client1 = create_client("client1", "1.0.0")?; -// let client2 = create_client("client2", "1.0.0")?; -// let client3 = create_sse_client("client3", "1.0.0")?; +use mcp_client::{ + client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}, + service::TransportService, + transport::{SseTransport, StdioTransport}, +}; +use tower::ServiceBuilder; +use tracing_subscriber::EnvFilter; -// // Initialize both clients -// let mut clients: Vec> = Vec::new(); -// clients.push(client1); -// clients.push(client2); -// clients.push(client3); +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("reqwest_eventsource=debug".parse().unwrap()), + ) + .init(); -// // Initialize all clients -// for (i, client) in clients.iter_mut().enumerate() { -// let info = ClientInfo { -// name: format!("example-client-{}", i + 1), -// version: "1.0.0".to_string(), -// }; -// let capabilities = ClientCapabilities::default(); + // Create two separate clients with stdio transport + let client1 = create_stdio_client("client1", "1.0.0")?; + let client2 = create_stdio_client("client2", "1.0.0")?; + let client3 = create_sse_client("client3", "1.0.0")?; -// println!("\nInitializing client {}", i + 1); -// let init_result = client.initialize(info, capabilities).await?; -// println!("Client {} initialized: {:?}", i + 1, init_result); -// } + // Initialize both clients + let mut clients: Vec> = Vec::new(); + clients.push(client1); + clients.push(client2); + clients.push(client3); -// // List tools for each client -// for (i, client) in clients.iter_mut().enumerate() { -// let tools = client.list_tools().await?; -// println!("\nClient {} tools: {:?}", i + 1, tools); -// } + // Initialize all clients + for (i, client) in clients.iter_mut().enumerate() { + let info = ClientInfo { + name: format!("example-client-{}", i + 1), + version: "1.0.0".to_string(), + }; + let capabilities = ClientCapabilities::default(); -// Ok(()) -// } + println!("\nInitializing client {}", i + 1); + let init_result = client.initialize(info, capabilities).await?; + println!("Client {} initialized: {:?}", i + 1, init_result); + } -// fn create_client( -// _name: &str, -// _version: &str, -// ) -> Result, Box> { -// // Create the transport -// let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); + // List tools for each client + for (i, client) in clients.iter_mut().enumerate() { + let tools = client.list_tools().await?; + println!("\nClient {} tools: {:?}", i + 1, tools); + } -// // Build service with middleware including timeout -// let service = ServiceBuilder::new() -// .layer(TimeoutLayer::new(Duration::from_secs(30))) -// .service(TransportService::new(Arc::clone(&transport))) -// .map_err(|e: Box| { -// if e.is::() { -// ServiceError::Timeout(tower::timeout::error::Elapsed::new()) -// } else { -// ServiceError::Other(e.to_string()) -// } -// }); - -// Ok(Box::new(McpClientImpl::new(service))) -// } - -// fn create_sse_client( -// _name: &str, -// _version: &str, -// ) -> Result, Box> { -// let transport = Arc::new(Mutex::new( -// SseTransport::new("http://localhost:8000/sse").unwrap(), -// )); + Ok(()) +} -// // Build service with middleware including timeout -// let service = ServiceBuilder::new() -// .layer(TimeoutLayer::new(Duration::from_secs(30))) -// .service(TransportService::new(Arc::clone(&transport))) -// .map_err(|e: Box| { -// if e.is::() { -// ServiceError::Timeout(tower::timeout::error::Elapsed::new()) -// } else { -// ServiceError::Other(e.to_string()) -// } -// }); +fn create_stdio_client( + _name: &str, + _version: &str, +) -> Result, Box> { + let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()]); + // TODO: Add timeout middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); + Ok(Box::new(McpClientImpl::new(service))) +} -// Ok(Box::new(McpClientImpl::new(service))) -// } +fn create_sse_client( + _name: &str, + _version: &str, +) -> Result, Box> { + let transport = SseTransport::new("http://localhost:8000/sse"); + // TODO: Add timeout middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); + Ok(Box::new(McpClientImpl::new(service))) +} diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index 5b6487a7a2fd..4b72946936f3 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -1,78 +1,59 @@ -// TODO: Remove this -fn main() { - println!("Hello World!"); +use anyhow::Result; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}; +use mcp_client::{service::TransportService, transport::SseTransport}; +use std::time::Duration; +use tower::ServiceBuilder; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("reqwest_eventsource=debug".parse().unwrap()), + ) + .init(); + + // Create the base transport + let transport = SseTransport::new("http://localhost:8000/sse"); + + // Build service + // TODO: Add timeout middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); + + // Create client + let mut client = McpClientImpl::new(service); + println!("Client created\n"); + + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}\n"); + + // Sleep for 100ms to allow the server to start - surprisingly this is required! + tokio::time::sleep(Duration::from_millis(100)).await; + + // List tools + let tools = client.list_tools().await?; + println!("Available tools: {tools:?}\n"); + + // Call tool + let tool_result = client + .call_tool( + "echo_tool", + serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), + ) + .await?; + println!("Tool result: {tool_result:?}"); + + Ok(()) } - -// use anyhow::Result; -// use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}; -// use mcp_client::{ -// service::{ServiceError, TransportService}, -// transport::SseTransport, -// }; -// use std::sync::Arc; -// use std::time::Duration; -// use tokio::sync::Mutex; -// use tower::timeout::TimeoutLayer; -// use tower::{ServiceBuilder, ServiceExt}; -// use tracing_subscriber::EnvFilter; - -// #[tokio::main] -// async fn main() -> Result<()> { -// // Initialize logging -// tracing_subscriber::fmt() -// .with_env_filter( -// EnvFilter::from_default_env() -// .add_directive("mcp_client=debug".parse().unwrap()) -// .add_directive("reqwest_eventsource=debug".parse().unwrap()), -// ) -// .init(); - -// // Create the base transport as Arc> -// let transport = Arc::new(Mutex::new(SseTransport::new("http://localhost:8000/sse")?)); - -// // Build service with middleware including timeout -// let service = ServiceBuilder::new() -// .layer(TimeoutLayer::new(Duration::from_secs(30))) -// .service(TransportService::new(Arc::clone(&transport))) -// .map_err(|e: Box| { -// if e.is::() { -// ServiceError::Timeout(tower::timeout::error::Elapsed::new()) -// } else { -// ServiceError::Other(e.to_string()) -// } -// }); - -// // Create client -// let mut client = McpClientImpl::new(service); -// println!("Client created\n"); - -// // Initialize -// let server_info = client -// .initialize( -// ClientInfo { -// name: "test-client".into(), -// version: "1.0.0".into(), -// }, -// ClientCapabilities::default(), -// ) -// .await?; -// println!("Connected to server: {server_info:?}\n"); - -// // Sleep for 100ms to allow the server to start - surprisingly this is required! -// tokio::time::sleep(Duration::from_millis(100)).await; - -// // List tools -// let tools = client.list_tools().await?; -// println!("Available tools: {tools:?}\n"); - -// // Call tool -// let tool_result = client -// .call_tool( -// "echo_tool", -// serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), -// ) -// .await?; -// println!("Tool result: {tool_result:?}"); - -// Ok(()) -// } diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index f79082f48678..850ee733f115 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -20,7 +20,8 @@ async fn main() -> Result<(), ClientError> { // Create the transport let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()]); - // Build service with middleware including timeout + // Build service + // TODO: Add timeout middleware let service = ServiceBuilder::new().service(TransportService::new(transport)); // Create client diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 2b4923dea3b5..2754d65671bb 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -51,6 +51,9 @@ pub trait Transport: Send + Sync + 'static { pub mod stdio; pub use stdio::StdioTransport; +pub mod sse; +pub use sse::SseTransport; + /// A router that handles message distribution for a transport #[derive(Clone)] pub struct MessageRouter { diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 1a9264dc1cd6..fac0998ff633 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -1,175 +1,189 @@ -// use std::sync::Arc; -// use async_trait::async_trait; -// use reqwest::Client as HttpClient; -// use tokio::sync::{mpsc, Mutex, oneshot}; -// use tokio::task::JoinHandle; -// use eventsource_client::SSE; - -// use super::{Error, Transport, TransportMessage}; -// use mcp_core::protocol::JsonRpcMessage; - -// /// A transport implementation that uses Server-Sent Events (SSE) for receiving messages -// /// and HTTP POST for sending messages. -// pub struct SseTransport { -// sse_url: String, -// http_client: HttpClient, -// post_endpoint: Arc>>, -// sse_handle: Arc>>>, -// pending_requests: Arc>>>>, -// } - -// impl SseTransport { -// /// Create a new SSE transport with the given SSE endpoint URL -// pub fn new>(sse_url: S) -> Self { -// Self { -// sse_url: sse_url.into(), -// http_client: HttpClient::new(), -// post_endpoint: Arc::new(Mutex::new(None)), -// sse_handle: Arc::new(Mutex::new(None)), -// pending_requests: Arc::new(Mutex::new(std::collections::HashMap::new())), -// } -// } - -// async fn handle_message( -// message: JsonRpcMessage, -// pending_requests: Arc>>>>, -// ) { -// if let JsonRpcMessage::Response(response) = &message { -// if let Some(id) = &response.id { -// if let Some(tx) = pending_requests.lock().await.remove(&id.to_string()) { -// let _ = tx.send(Ok(message)); -// } -// } -// } -// } - -// async fn process_messages( -// mut message_rx: mpsc::Receiver, -// http_client: HttpClient, -// post_endpoint: Arc>>, -// sse_url: String, -// pending_requests: Arc>>>>, -// ) { -// // Set up SSE client -// let mut client = match eventsource_client::ClientBuilder::for_url(&sse_url) { -// Ok(builder) => builder.build(), -// Err(e) => { -// eprintln!("Failed to create SSE client: {}", e); -// return; -// } -// }; - -// // Wait for endpoint event to get POST URL -// while let Some(event) = client.next().await { -// match event { -// Ok(SSE::Event(event)) if event.event_type == "endpoint" => { -// *post_endpoint.lock().await = Some(event.data); -// break; -// } -// Ok(_) => continue, -// Err(e) => { -// eprintln!("SSE connection error: {}", e); -// return; -// } -// } -// } - -// // Spawn SSE message handler -// let pending_clone = pending_requests.clone(); -// let mut client_clone = client.clone(); -// let sse_handle = tokio::spawn(async move { -// while let Some(event) = client_clone.next().await { -// match event { -// Ok(SSE::Event(event)) if event.event_type == "message" => { -// if let Ok(message) = serde_json::from_str::(&event.data) { -// Self::handle_message(message, pending_clone.clone()).await; -// } -// } -// Ok(_) => continue, -// Err(e) => { -// eprintln!("SSE message error: {}", e); -// break; -// } -// } -// } -// }); - -// // Process outgoing messages -// while let Some(transport_msg) = message_rx.recv().await { -// let post_url = match post_endpoint.lock().await.as_ref() { -// Some(url) => url.clone(), -// None => { -// eprintln!("No POST endpoint available"); -// continue; -// } -// }; - -// // Store response channel if this is a request -// if let Some(response_tx) = transport_msg.response_tx { -// if let JsonRpcMessage::Request(request) = &transport_msg.message { -// if let Some(id) = &request.id { -// pending_requests.lock().await.insert(id.to_string(), response_tx); -// } -// } -// } - -// // Send message via HTTP POST -// let message_str = match serde_json::to_string(&transport_msg.message) { -// Ok(s) => s, -// Err(e) => { -// eprintln!("Failed to serialize message: {}", e); -// continue; -// } -// }; - -// if let Err(e) = http_client -// .post(&post_url) -// .header("Content-Type", "application/json") -// .body(message_str) -// .send() -// .await -// { -// eprintln!("Failed to send message: {}", e); -// } -// } - -// // Clean up -// sse_handle.abort(); -// } -// } - -// #[async_trait] -// impl Transport for SseTransport { -// async fn start(&self) -> Result, Error> { -// let (message_tx, message_rx) = mpsc::channel(32); - -// let http_client = self.http_client.clone(); -// let post_endpoint = self.post_endpoint.clone(); -// let sse_url = self.sse_url.clone(); -// let pending_requests = self.pending_requests.clone(); - -// let handle = tokio::spawn(Self::process_messages( -// message_rx, -// http_client, -// post_endpoint, -// sse_url, -// pending_requests, -// )); - -// *self.sse_handle.lock().await = Some(handle); - -// Ok(message_tx) -// } - -// async fn close(&self) -> Result<(), Error> { -// // Abort the SSE handler task -// if let Some(handle) = self.sse_handle.lock().await.take() { -// handle.abort(); -// } - -// // Clear any pending requests -// self.pending_requests.lock().await.clear(); - -// Ok(()) -// } -// } +use async_trait::async_trait; +use eventsource_client::{Client, SSE}; +use futures::TryStreamExt; +use reqwest::Client as HttpClient; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::task::JoinHandle; + +use super::{Error, Transport, TransportMessage}; +use mcp_core::protocol::JsonRpcMessage; + +/// A transport implementation that uses Server-Sent Events (SSE) for receiving messages +/// and HTTP POST for sending messages. +pub struct SseTransport { + sse_url: String, + http_client: HttpClient, + post_endpoint: Arc>>, + sse_handle: Arc>>>, + pending_requests: Arc< + Mutex>>>, + >, +} + +impl SseTransport { + /// Create a new SSE transport with the given SSE endpoint URL + pub fn new>(sse_url: S) -> Self { + Self { + sse_url: sse_url.into(), + http_client: HttpClient::new(), + post_endpoint: Arc::new(Mutex::new(None)), + sse_handle: Arc::new(Mutex::new(None)), + pending_requests: Arc::new(Mutex::new(std::collections::HashMap::new())), + } + } + + async fn handle_message( + message: JsonRpcMessage, + pending_requests: Arc< + Mutex< + std::collections::HashMap>>, + >, + >, + ) { + if let JsonRpcMessage::Response(response) = &message { + if let Some(id) = &response.id { + if let Some(tx) = pending_requests.lock().await.remove(&id.to_string()) { + let _ = tx.send(Ok(message)); + } + } + } + } + + async fn process_messages( + mut message_rx: mpsc::Receiver, + http_client: HttpClient, + post_endpoint: Arc>>, + sse_url: String, + pending_requests: Arc< + Mutex< + std::collections::HashMap>>, + >, + >, + ) { + // Set up SSE client + let client = match eventsource_client::ClientBuilder::for_url(&sse_url) { + Ok(builder) => builder.build(), + Err(e) => { + eprintln!("Failed to create SSE client: {}", e); + return; + } + }; + + let mut stream = client.stream(); + + // First, wait for the endpoint event + while let Ok(Some(event)) = stream.try_next().await { + match event { + SSE::Event(event) if event.event_type == "endpoint" => { + let base_url = sse_url.trim_end_matches('/').trim_end_matches("sse"); + let endpoint_path = event.data.trim_start_matches('/'); + let post_url = format!("{}{}", base_url, endpoint_path); + println!("Endpoint for POST requests: {}", post_url); + *post_endpoint.lock().await = Some(post_url); + break; + } + _ => continue, + } + } + + // Now handle all subsequent messages + let message_handler = tokio::spawn({ + let pending_requests = pending_requests.clone(); + async move { + while let Ok(Some(event)) = stream.try_next().await { + match event { + SSE::Event(event) if event.event_type == "message" => { + if let Ok(message) = serde_json::from_str::(&event.data) + { + Self::handle_message(message, pending_requests.clone()).await; + } + } + _ => continue, + } + } + } + }); + + // Process outgoing messages + while let Some(transport_msg) = message_rx.recv().await { + let post_url = match post_endpoint.lock().await.as_ref() { + Some(url) => url.clone(), + None => { + eprintln!("No POST endpoint available"); + continue; + } + }; + + // Store response channel if this is a request + if let Some(response_tx) = transport_msg.response_tx { + if let JsonRpcMessage::Request(request) = &transport_msg.message { + if let Some(id) = &request.id { + pending_requests + .lock() + .await + .insert(id.to_string(), response_tx); + } + } + } + + // Send message via HTTP POST + let message_str = match serde_json::to_string(&transport_msg.message) { + Ok(s) => s, + Err(e) => { + eprintln!("Failed to serialize message: {}", e); + continue; + } + }; + + if let Err(e) = http_client + .post(&post_url) + .header("Content-Type", "application/json") + .body(message_str) + .send() + .await + { + eprintln!("Failed to send message: {}", e); + } + } + + // Clean up + message_handler.abort(); + } +} + +#[async_trait] +impl Transport for SseTransport { + async fn start(&self) -> Result, Error> { + let (message_tx, message_rx) = mpsc::channel(32); + + let http_client = self.http_client.clone(); + let post_endpoint = self.post_endpoint.clone(); + let sse_url = self.sse_url.clone(); + let pending_requests = self.pending_requests.clone(); + + let handle = tokio::spawn(Self::process_messages( + message_rx, + http_client, + post_endpoint, + sse_url, + pending_requests, + )); + + *self.sse_handle.lock().await = Some(handle); + + Ok(message_tx) + } + + async fn close(&self) -> Result<(), Error> { + // Abort the SSE handler task + if let Some(handle) = self.sse_handle.lock().await.take() { + handle.abort(); + } + + // Clear any pending requests + self.pending_requests.lock().await.clear(); + + Ok(()) + } +} From 4e23c40b8277bf6d25ee612da18f65d9dda1c0f2 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 20:01:04 -0500 Subject: [PATCH 16/21] trying to test parallel requests from client -> servers --- crates/mcp-client/Cargo.toml | 1 + crates/mcp-client/examples/clients.rs | 79 ++++++++++++++++++++++++-- crates/mcp-client/src/client.rs | 4 +- crates/mcp-client/src/transport/mod.rs | 1 + 4 files changed, 79 insertions(+), 6 deletions(-) diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 221b80e2d844..f64d885650e5 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -23,6 +23,7 @@ tokio-retry = "0.3" tower = { version = "0.4", features = ["timeout", "util"] } tower-service = "0.3" tokio-util = { version = "0.7", features = ["io-util", "io"] } +rand = "0.8" [dev-dependencies] warp = "0.3" diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index c91fcdfaae68..e0ad9f0e3913 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -3,6 +3,10 @@ use mcp_client::{ service::TransportService, transport::{SseTransport, StdioTransport}, }; +use rand::Rng; +use rand::SeedableRng; +use std::sync::Arc; +use std::time::Duration; use tower::ServiceBuilder; use tracing_subscriber::EnvFilter; @@ -23,7 +27,7 @@ async fn main() -> Result<(), Box> { let client3 = create_sse_client("client3", "1.0.0")?; // Initialize both clients - let mut clients: Vec> = Vec::new(); + let mut clients: Vec> = Vec::new(); clients.push(client1); clients.push(client2); clients.push(client3); @@ -41,19 +45,86 @@ async fn main() -> Result<(), Box> { println!("Client {} initialized: {:?}", i + 1, init_result); } - // List tools for each client + // List tools for all clients for (i, client) in clients.iter_mut().enumerate() { let tools = client.list_tools().await?; println!("\nClient {} tools: {:?}", i + 1, tools); } + println!("\n\n----------------------------------\n\n"); + + // // Wrap clients in Arc before spawning tasks + // let clients = Arc::new(clients); + // let mut handles = vec![]; + + // for i in 0..50 { + // let clients = Arc::clone(&clients); + // let handle = tokio::spawn(async move { + // // let mut rng = rand::thread_rng(); + // let mut rng = rand::rngs::StdRng::from_entropy(); + // tokio::time::sleep(Duration::from_millis(rng.gen_range(5..50))).await; + + // // Randomly select an operation + // match rng.gen_range(0..4) { + // 0 => { + // println!("{i}: Listing tools for client 1 (stdio)"); + // match clients[0].list_tools().await { + // Ok(tools) => println!(" -> Got {:?}", tools), + // Err(e) => println!(" -> Error: {}", e), + // } + // } + // 1 => { + // println!("{i}: Listing tools for client 3 (sse)"); + // match clients[2].list_tools().await { + // Ok(tools) => println!(" -> Got {:?}", tools), + // Err(e) => println!(" -> Error: {}", e), + // } + // } + // 2 => { + // println!("{i}: Calling tool for client 2 (stdio)"); + // match clients[1] + // .call_tool( + // "echo", + // serde_json::json!({ "message": "Hello from client 2" }), + // ) + // .await + // { + // Ok(result) => println!(" → Tool execution result: {:?}", result), + // Err(e) => println!(" → Error: {}", e), + // } + // } + // 3 => { + // println!("{i}: Calling tool for client 3 (sse)"); + // match clients[2] + // .call_tool( + // "echo_tool", + // serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), + // ) + // .await + // { + // Ok(result) => println!(" → Tool execution result: {:?}", result), + // Err(e) => println!(" -> Error: {}", e), + // } + // } + // _ => unreachable!(), + // } + // Ok::<(), Box>(()) + // }); + // handles.push(handle); + // } + + // // Wait for all tasks to complete + // for handle in handles { + // handle.await.unwrap().unwrap(); + // } + Ok(()) } fn create_stdio_client( _name: &str, _version: &str, -) -> Result, Box> { +) -> Result, Box> { let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()]); // TODO: Add timeout middleware let service = ServiceBuilder::new().service(TransportService::new(transport)); @@ -63,7 +134,7 @@ fn create_stdio_client( fn create_sse_client( _name: &str, _version: &str, -) -> Result, Box> { +) -> Result, Box> { let transport = SseTransport::new("http://localhost:8000/sse"); // TODO: Add timeout middleware let service = ServiceBuilder::new().service(TransportService::new(transport)); diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 348977052a2f..2adbc785863b 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -49,7 +49,7 @@ pub struct InitializeParams { /// The MCP client trait defining the interface for MCP operations. #[async_trait::async_trait] -pub trait McpClient { +pub trait McpClient: Send + Sync + 'static { /// Initialize the connection with the server. async fn initialize( &mut self, @@ -168,7 +168,7 @@ where Response = JsonRpcMessage, Error = super::service::ServiceError, > + Send - + Sync, + + Sync + 'static, S::Future: Send, { async fn initialize( diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 2754d65671bb..75167d6e5378 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -58,6 +58,7 @@ pub use sse::SseTransport; #[derive(Clone)] pub struct MessageRouter { transport_tx: mpsc::Sender, + // shutdown_tx is unused, but we'll probably need it for shutdown #[allow(dead_code)] shutdown_tx: mpsc::Sender<()>, } From aa77d3fbc8215b86f143ce55efa9605e1bb41827 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 20:25:32 -0500 Subject: [PATCH 17/21] simulate async requests in example --- crates/mcp-client/examples/clients.rs | 119 +++++++++++++------------- crates/mcp-client/examples/sse.rs | 2 +- crates/mcp-client/examples/stdio.rs | 2 +- crates/mcp-client/src/client.rs | 64 +++++++------- 4 files changed, 94 insertions(+), 93 deletions(-) diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index e0ad9f0e3913..3f03ddfcfa62 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -53,70 +53,67 @@ async fn main() -> Result<(), Box> { println!("\n\n----------------------------------\n\n"); - // // Wrap clients in Arc before spawning tasks - // let clients = Arc::new(clients); - // let mut handles = vec![]; + // Wrap clients in Arc before spawning tasks + let clients = Arc::new(clients); + let mut handles = vec![]; - // for i in 0..50 { - // let clients = Arc::clone(&clients); - // let handle = tokio::spawn(async move { - // // let mut rng = rand::thread_rng(); - // let mut rng = rand::rngs::StdRng::from_entropy(); - // tokio::time::sleep(Duration::from_millis(rng.gen_range(5..50))).await; + for i in 0..10 { + let clients = Arc::clone(&clients); + let handle = tokio::spawn(async move { + // let mut rng = rand::thread_rng(); + let mut rng = rand::rngs::StdRng::from_entropy(); + tokio::time::sleep(Duration::from_millis(rng.gen_range(5..50))).await; - // // Randomly select an operation - // match rng.gen_range(0..4) { - // 0 => { - // println!("{i}: Listing tools for client 1 (stdio)"); - // match clients[0].list_tools().await { - // Ok(tools) => println!(" -> Got {:?}", tools), - // Err(e) => println!(" -> Error: {}", e), - // } - // } - // 1 => { - // println!("{i}: Listing tools for client 3 (sse)"); - // match clients[2].list_tools().await { - // Ok(tools) => println!(" -> Got {:?}", tools), - // Err(e) => println!(" -> Error: {}", e), - // } - // } - // 2 => { - // println!("{i}: Calling tool for client 2 (stdio)"); - // match clients[1] - // .call_tool( - // "echo", - // serde_json::json!({ "message": "Hello from client 2" }), - // ) - // .await - // { - // Ok(result) => println!(" → Tool execution result: {:?}", result), - // Err(e) => println!(" → Error: {}", e), - // } - // } - // 3 => { - // println!("{i}: Calling tool for client 3 (sse)"); - // match clients[2] - // .call_tool( - // "echo_tool", - // serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), - // ) - // .await - // { - // Ok(result) => println!(" → Tool execution result: {:?}", result), - // Err(e) => println!(" -> Error: {}", e), - // } - // } - // _ => unreachable!(), - // } - // Ok::<(), Box>(()) - // }); - // handles.push(handle); - // } + // Randomly select an operation + match rng.gen_range(0..4) { + 0 => { + println!("\n{i}: Listing tools for client 1 (stdio)"); + match clients[0].list_tools().await { + Ok(tools) => println!(" {i}: -> Got {:?}", tools), + Err(e) => println!(" {i}: -> Error: {}", e), + } + } + 1 => { + println!("\n{i}: Listing tools for client 3 (sse)"); + match clients[2].list_tools().await { + Ok(tools) => println!(" {i}: -> Got {:?}", tools), + Err(e) => println!(" {i}: -> Error: {}", e), + } + } + 2 => { + println!("\n{i}: Calling tool for client 2 (stdio)"); + match clients[1] + .call_tool("git_status", serde_json::json!({ "repo_path": "." })) + .await + { + Ok(result) => println!(" {i}: → Tool execution result: {:?}", result), + Err(e) => println!(" {i}: → Error: {}", e), + } + } + 3 => { + println!("\n{i}: Calling tool for client 3 (sse)"); + match clients[2] + .call_tool( + "echo_tool", + serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), + ) + .await + { + Ok(result) => println!(" {i}: → Tool execution result: {:?}", result), + Err(e) => println!(" {i}: → Error: {}", e), + } + } + _ => unreachable!(), + } + Ok::<(), Box>(()) + }); + handles.push(handle); + } - // // Wait for all tasks to complete - // for handle in handles { - // handle.await.unwrap().unwrap(); - // } + // Wait for all tasks to complete + for handle in handles { + handle.await.unwrap().unwrap(); + } Ok(()) } diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index 4b72946936f3..8d4dad3f4853 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -24,7 +24,7 @@ async fn main() -> Result<()> { let service = ServiceBuilder::new().service(TransportService::new(transport)); // Create client - let mut client = McpClientImpl::new(service); + let client = McpClientImpl::new(service); println!("Client created\n"); // Initialize diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 850ee733f115..ead83b2a512b 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -25,7 +25,7 @@ async fn main() -> Result<(), ClientError> { let service = ServiceBuilder::new().service(TransportService::new(transport)); // Create client - let mut client = McpClientImpl::new(service); + let client = McpClientImpl::new(service); // Initialize let server_info = client diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 2adbc785863b..2aeb69b5e1c0 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,12 +1,14 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use thiserror::Error; -use tower::ServiceExt; // for Service::ready() +use std::sync::atomic::{AtomicU64, Ordering}; use mcp_core::protocol::{ CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, }; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use thiserror::Error; +use tokio::sync::Mutex; +use tower::ServiceExt; // for Service::ready() /// Error type for MCP client operations. #[derive(Debug, Error)] @@ -52,28 +54,28 @@ pub struct InitializeParams { pub trait McpClient: Send + Sync + 'static { /// Initialize the connection with the server. async fn initialize( - &mut self, + &self, info: ClientInfo, capabilities: ClientCapabilities, ) -> Result; /// List available resources. - async fn list_resources(&mut self) -> Result; + async fn list_resources(&self) -> Result; /// Read a resource's content. - async fn read_resource(&mut self, uri: &str) -> Result; + async fn read_resource(&self, uri: &str) -> Result; /// List available tools. - async fn list_tools(&mut self) -> Result; + async fn list_tools(&self) -> Result; /// Call a specific tool with arguments. - async fn call_tool(&mut self, name: &str, arguments: Value) -> Result; + async fn call_tool(&self, name: &str, arguments: Value) -> Result; } /// Standard implementation of the MCP client that sends requests via the provided service. pub struct McpClientImpl { - service: S, - next_id: u64, + service: Mutex, + next_id: AtomicU64, } impl McpClientImpl @@ -87,35 +89,35 @@ where { pub fn new(service: S) -> Self { Self { - service, - next_id: 1, + service: Mutex::new(service), + next_id: AtomicU64::new(1), } } /// Send a JSON-RPC request and check we don't get an error response. - async fn send_request(&mut self, method: &str, params: Value) -> Result + async fn send_request(&self, method: &str, params: Value) -> Result where R: for<'de> Deserialize<'de>, { - self.service.ready().await.map_err(|_| Error::NotReady)?; + let mut service = self.service.lock().await; + service.ready().await.map_err(|_| Error::NotReady)?; + let id = self.next_id.fetch_add(1, Ordering::SeqCst); let request = JsonRpcMessage::Request(JsonRpcRequest { jsonrpc: "2.0".to_string(), - id: Some(self.next_id), + id: Some(id), method: method.to_string(), params: Some(params), }); - self.next_id += 1; - - let response_msg = self.service.call(request).await?; + let response_msg = service.call(request).await?; match response_msg { JsonRpcMessage::Response(JsonRpcResponse { id, result, error, .. }) => { // Verify id matches - if id != Some(self.next_id - 1) { + if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { return Err(Error::UnexpectedResponse); } if let Some(err) = error { @@ -130,7 +132,7 @@ where } } JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => { - if id != Some(self.next_id - 1) { + if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { return Err(Error::UnexpectedResponse); } Err(Error::RpcError { @@ -146,8 +148,9 @@ where } /// Send a JSON-RPC notification. - async fn send_notification(&mut self, method: &str, params: Value) -> Result<(), Error> { - self.service.ready().await.map_err(|_| Error::NotReady)?; + async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { + let mut service = self.service.lock().await; + service.ready().await.map_err(|_| Error::NotReady)?; let notification = JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), @@ -155,7 +158,7 @@ where params: Some(params), }); - self.service.call(notification).await?; + service.call(notification).await?; Ok(()) } } @@ -168,11 +171,12 @@ where Response = JsonRpcMessage, Error = super::service::ServiceError, > + Send - + Sync + 'static, + + Sync + + 'static, S::Future: Send, { async fn initialize( - &mut self, + &self, info: ClientInfo, capabilities: ClientCapabilities, ) -> Result { @@ -191,21 +195,21 @@ where Ok(result) } - async fn list_resources(&mut self) -> Result { + async fn list_resources(&self) -> Result { self.send_request("resources/list", serde_json::json!({})) .await } - async fn read_resource(&mut self, uri: &str) -> Result { + async fn read_resource(&self, uri: &str) -> Result { let params = serde_json::json!({ "uri": uri }); self.send_request("resources/read", params).await } - async fn list_tools(&mut self) -> Result { + async fn list_tools(&self) -> Result { self.send_request("tools/list", serde_json::json!({})).await } - async fn call_tool(&mut self, name: &str, arguments: Value) -> Result { + async fn call_tool(&self, name: &str, arguments: Value) -> Result { let params = serde_json::json!({ "name": name, "arguments": arguments }); self.send_request("tools/call", params).await } From 03be424ae9ed56445790624b0374510123e08b48 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Fri, 20 Dec 2024 11:27:24 -0500 Subject: [PATCH 18/21] add an example that uses our mcp-client and mcp-server --- .../mcp-client/examples/stdio_integration.rs | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 crates/mcp-client/examples/stdio_integration.rs diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs new file mode 100644 index 000000000000..58226711382f --- /dev/null +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -0,0 +1,73 @@ +// This example shows how to use the mcp-client crate to interact with a server that has a simple counter tool. +// The server is started by running `cargo run -p mcp-server` in the root of the mcp-server crate. +use anyhow::Result; +use mcp_client::client::{ + ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientImpl, +}; +use mcp_client::{service::TransportService, transport::StdioTransport}; +use tower::ServiceBuilder; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<(), ClientError> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("eventsource_client=debug".parse().unwrap()), + ) + .init(); + + // Create the transport + let transport = StdioTransport::new( + "cargo", + vec!["run", "-p", "mcp-server"] + .into_iter() + .map(|s| s.to_string()) + .collect(), + ); + + // Build service + // TODO: Add timeout middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); + + // Create client + let client = McpClientImpl::new(service); + + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}\n"); + + // List tools + let tools = client.list_tools().await?; + println!("Available tools: {tools:?}\n"); + + // Call tool 'increment' tool 3 times + for _ in 0..3 { + let increment_result = client.call_tool("increment", serde_json::json!({})).await?; + println!("Tool result for 'increment': {increment_result:?}\n"); + } + + // Call tool 'get_value' + let get_value_result = client.call_tool("get_value", serde_json::json!({})).await?; + println!("Tool result for 'get_value': {get_value_result:?}\n"); + + // Call tool 'decrement' once + let decrement_result = client.call_tool("decrement", serde_json::json!({})).await?; + println!("Tool result for 'decrement': {decrement_result:?}\n"); + + // Call tool 'get_value' + let get_value_result = client.call_tool("get_value", serde_json::json!({})).await?; + println!("Tool result for 'get_value': {get_value_result:?}\n"); + + Ok(()) +} From 64b8ea6d6016f03c337e1a91b55def0926155eeb Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Fri, 20 Dec 2024 11:32:50 -0500 Subject: [PATCH 19/21] add kill_on_drop() for child process spawned in stdio transport --- crates/mcp-client/src/transport/stdio.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 749a952cce1e..208bea32a895 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -40,6 +40,7 @@ impl StdioTransport { .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) .spawn()?; let stdin = child From 1aacd48cabbe9ea35820bcf0e8debdb4123742ac Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Fri, 20 Dec 2024 11:38:25 -0500 Subject: [PATCH 20/21] make JsonRpcRaw struct fields private --- crates/mcp-core/src/protocol.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index e19cb80da5c6..a98d8d79480f 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -41,7 +41,7 @@ pub struct JsonRpcError { } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(untagged)] +#[serde(untagged, try_from = "JsonRpcRaw")] pub enum JsonRpcMessage { Request(JsonRpcRequest), Response(JsonRpcResponse), @@ -52,17 +52,17 @@ pub enum JsonRpcMessage { #[derive(Debug, Serialize, Deserialize)] struct JsonRpcRaw { - pub jsonrpc: String, + jsonrpc: String, #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, + id: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub method: Option, + method: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, + params: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, + result: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, + error: Option, } impl TryFrom for JsonRpcMessage { From 691fc69aae18a3270360a459e629ba27e8d8d899 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Fri, 20 Dec 2024 14:06:09 -0500 Subject: [PATCH 21/21] send errors back through the channel to the user --- crates/mcp-client/src/transport/mod.rs | 12 ++++++ crates/mcp-client/src/transport/sse.rs | 49 ++++++++++++++++++------ crates/mcp-client/src/transport/stdio.rs | 20 +++++++--- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 75167d6e5378..aa4030690dd9 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -24,6 +24,18 @@ pub enum Error { #[error("Channel closed")] ChannelClosed, + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + #[error("HTTP error: {status} - {message}")] + HttpError { status: u16, message: String }, + + #[error("SSE connection error: {0}")] + SseConnection(String), + + #[error("Connection closed by server")] + ConnectionClosed, + #[error("Unexpected transport error: {0}")] Other(String), } diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index fac0998ff633..6bfd857a4c88 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -5,6 +5,7 @@ use reqwest::Client as HttpClient; use std::sync::Arc; use tokio::sync::{mpsc, oneshot, Mutex}; use tokio::task::JoinHandle; +use tracing::warn; use super::{Error, Transport, TransportMessage}; use mcp_core::protocol::JsonRpcMessage; @@ -65,7 +66,11 @@ impl SseTransport { let client = match eventsource_client::ClientBuilder::for_url(&sse_url) { Ok(builder) => builder.build(), Err(e) => { - eprintln!("Failed to create SSE client: {}", e); + // Properly handle initial connection error + let mut pending = pending_requests.lock().await; + for (_, tx) in pending.drain() { + let _ = tx.send(Err(Error::SseConnection(e.to_string()))); + } return; } }; @@ -110,7 +115,20 @@ impl SseTransport { let post_url = match post_endpoint.lock().await.as_ref() { Some(url) => url.clone(), None => { - eprintln!("No POST endpoint available"); + if let Some(response_tx) = transport_msg.response_tx { + let _ = response_tx.send(Err(Error::NotConnected)); + } + continue; + } + }; + + // Serialize message first + let message_str = match serde_json::to_string(&transport_msg.message) { + Ok(s) => s, + Err(e) => { + if let Some(response_tx) = transport_msg.response_tx { + let _ = response_tx.send(Err(Error::Serialization(e))); + } continue; } }; @@ -128,22 +146,29 @@ impl SseTransport { } // Send message via HTTP POST - let message_str = match serde_json::to_string(&transport_msg.message) { - Ok(s) => s, - Err(e) => { - eprintln!("Failed to serialize message: {}", e); - continue; - } - }; - - if let Err(e) = http_client + match http_client .post(&post_url) .header("Content-Type", "application/json") .body(message_str) .send() .await { - eprintln!("Failed to send message: {}", e); + Ok(response) => { + if !response.status().is_success() { + let error = Error::HttpError { + status: response.status().as_u16(), + message: response.status().to_string(), + }; + // We don't handle the error directly as it will come through SSE, + // but we log it for debugging purposes + warn!("HTTP request failed with error: {}", error); + } + } + Err(e) => { + let error = Error::Other(format!("HTTP request failed: {}", e)); + // Transport errors will also be communicated through the SSE channel + warn!("HTTP request failed with error: {}", error); + } } } diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 208bea32a895..70db75a6d6db 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -113,7 +113,9 @@ impl StdioTransport { let message_str = match serde_json::to_string(&transport_msg.message) { Ok(s) => s, Err(e) => { - eprintln!("Failed to serialize message: {}", e); + if let Some(tx) = transport_msg.response_tx { + let _ = tx.send(Err(Error::Serialization(e))); + } continue; } }; @@ -131,15 +133,23 @@ impl StdioTransport { } // Write message to stdin - if let Err(e) = stdin + if let Err(_) = stdin .write_all(format!("{}\n", message_str).as_bytes()) .await { - eprintln!("Failed to write to stdin: {}", e); + // Break with a specific error indicating write failure + let mut pending = pending_requests.lock().await; + for (_, tx) in pending.drain() { + let _ = tx.send(Err(Error::SendFailed)); + } break; } - if let Err(e) = stdin.flush().await { - eprintln!("Failed to flush stdin: {}", e); + if let Err(_) = stdin.flush().await { + // Break with a specific error indicating connection issues + let mut pending = pending_requests.lock().await; + for (_, tx) in pending.drain() { + let _ = tx.send(Err(Error::ConnectionClosed)); + } break; } }