diff --git a/Cargo.lock b/Cargo.lock index d51b2d6..9d4b91d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,9 +61,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", @@ -118,7 +118,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "itoa", "matchit", @@ -170,7 +170,7 @@ dependencies = [ "fs-err", "http 1.3.1", "http-body 1.0.1", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "pin-project-lite", "rustls", @@ -239,9 +239,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.1" +version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +checksum = "6a65b545ab31d687cff52899d4890855fec459eb6afe0da6417b8a18da87aa29" [[package]] name = "bumpalo" @@ -257,9 +257,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.32" +version = "1.2.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2352e5597e9c544d5e6d9c95190d5d27738ade584fa8db0a16e130e5c2b5296e" +checksum = "3ee0f8803222ba5a7e2777dd72ca451868909b1ac410621b676adf07280e9b5f" dependencies = [ "jobserver", "libc", @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.1" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" [[package]] name = "cfg_aliases" @@ -870,13 +870,14 @@ dependencies = [ [[package]] name = "hyper" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "h2 0.4.12", "http 1.3.1", "http-body 1.0.1", @@ -884,6 +885,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", "want", @@ -896,7 +898,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ "http 1.3.1", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "rustls", "rustls-pki-types", @@ -919,7 +921,7 @@ dependencies = [ "futures-util", "http 1.3.1", "http-body 1.0.1", - "hyper 1.6.0", + "hyper 1.7.0", "ipnet", "libc", "percent-encoding", @@ -1385,9 +1387,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.36" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", "syn", @@ -1395,9 +1397,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.97" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61789d7719defeb74ea5fe81f2fdfdbd28a803847077cecce2ff14e1472f6f1" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -1432,7 +1434,7 @@ dependencies = [ "rustc-hash 2.1.1", "rustls", "socket2 0.5.10", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tracing", "web-time", @@ -1453,7 +1455,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.14", + "thiserror 2.0.15", "tinyvec", "tracing", "web-time", @@ -1626,7 +1628,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-rustls", "hyper-util", "js-sys", @@ -1705,14 +1707,14 @@ dependencies = [ "axum", "axum-server", "futures", - "hyper 1.6.0", + "hyper 1.7.0", "reqwest", "rust-mcp-macros", "rust-mcp-schema", "rust-mcp-transport", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tokio-stream", "tracing", @@ -1731,7 +1733,7 @@ dependencies = [ "rust-mcp-schema", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tokio-stream", "tracing", @@ -1855,9 +1857,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.142" +version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "030fedb782600dcbd6f02d479bf0d817ac3bb40d644745b769d6a96bc3afc5a7" +checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" dependencies = [ "itoa", "memchr", @@ -1932,7 +1934,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", ] @@ -1946,7 +1948,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", ] @@ -1960,7 +1962,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tracing", "tracing-subscriber", @@ -1976,7 +1978,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tracing", "tracing-subscriber", @@ -2028,9 +2030,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.104" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -2068,11 +2070,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.14" +version = "2.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b0949c3a6c842cbde3f1686d6eea5a010516deb7085f79db747562d4102f41e" +checksum = "80d76d3f064b981389ecb4b6b7f45a0bf9fdac1d5b9204c7bd6714fecc302850" dependencies = [ - "thiserror-impl 2.0.14", + "thiserror-impl 2.0.15", ] [[package]] @@ -2088,9 +2090,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.14" +version = "2.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc5b44b4ab9c2fdd0e0512e6bece8388e214c0749f5862b114cc5b7a25daf227" +checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0" dependencies = [ "proc-macro2", "quote", @@ -2149,9 +2151,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" dependencies = [ "tinyvec_macros", ] diff --git a/README.md b/README.md index ef5b4ed..1581d1d 100644 --- a/README.md +++ b/README.md @@ -526,6 +526,7 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | +| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index ef5b4ed..1581d1d 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -526,6 +526,7 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | +| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 2feab67..3de8d98 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -41,6 +41,3 @@ impl McpSdkError { None } } - -#[deprecated(since = "0.2.0", note = "Use `McpSdkError` instead.")] -pub type MCPSdkError = McpSdkError; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs index f8ee1a0..c6fb208 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs @@ -148,7 +148,7 @@ pub trait ClientHandler: Send + Sync + 'static { //********************// async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs index 3bbe5c9..a0afdf1 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs @@ -38,7 +38,7 @@ pub trait ClientHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP server. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError>; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index bf3fe17..89aebf5 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -319,7 +319,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index fffe2fc..e7b0e6d 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -45,7 +45,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP client. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpServer, ) -> std::result::Result<(), RpcError>; async fn on_server_started(&self, runtime: &dyn McpServer) { diff --git a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs index 3bd2735..a5b75d5 100644 --- a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs +++ b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs @@ -57,15 +57,6 @@ macro_rules! tool_box { )* ] } - - #[deprecated(since = "0.2.0", note = "Use `tools()` instead.")] - pub fn get_tools() -> Vec { - vec![ - $( - $tool::tool(), - )* - ] - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 8d113c3..7ee0815 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -1,20 +1,26 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; -use crate::schema::{ - schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, - ServerMessages, +use crate::{ + mcp_traits::{RequestIdGen, RequestIdGenNumeric}, + schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, + ServerMessages, + }, + InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, + RequestId, RpcError, ServerResult, }, - InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, - RpcError, ServerResult, }; use async_trait::async_trait; use futures::future::{join_all, try_join_all}; use futures::StreamExt; use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; -use std::sync::{Arc, RwLock}; +use std::{ + sync::{Arc, RwLock}, + time::Duration, +}; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::sync::Mutex; @@ -41,6 +47,7 @@ pub struct ClientRuntime { // Details about the connected server server_details: Arc>>, handlers: Mutex>>>, + request_id_gen: Box, } impl ClientRuntime { @@ -61,6 +68,7 @@ impl ClientRuntime { client_details, server_details: Arc::new(RwLock::new(None)), handlers: Mutex::new(vec![]), + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } @@ -123,7 +131,19 @@ impl ClientRuntime { None } ServerMessage::Error(jsonrpc_error) => { - self.handler.handle_error(jsonrpc_error.error, self).await?; + self.handler + .handle_error(&jsonrpc_error.error, self) + .await?; + if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { + tx_response + .send(ServerMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request: {:?}", + &jsonrpc_error.id + ); + } None } ServerMessage::Response(response) => { @@ -133,7 +153,7 @@ impl ClientRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received response or error without a matching request: {:?}", + "Received a response with no corresponding request: {:?}", &response.id ); } @@ -284,6 +304,33 @@ impl McpClient for ClientRuntime { } } + async fn send( + &self, + message: MessageFromClient, + request_id: Option, + timeout: Option, + ) -> SdkResult> { + let sender = self.sender(); + let sender = sender.read().await; + let sender = sender + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + + let outgoing_request_id = self + .request_id_gen + .request_id_for_message(&message, request_id); + + let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + + let response = sender + .send_message(ClientMessages::Single(mcp_message), timeout) + .await? + .map(|res| res.as_single()) + .transpose()?; + + Ok(response) + } + async fn is_shut_down(&self) -> bool { self.transport.is_shut_down().await } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 9ccd4d9..7925f07 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -113,7 +113,7 @@ impl McpClientHandler for ClientInternalHandler> { /// Handles errors received from the server by passing the request to self.handler async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index 3bdc318..8cb8cff 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -83,7 +83,7 @@ impl McpClientHandler for ClientCoreInternalHandler> async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index d787a10..49b5c3c 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -3,6 +3,7 @@ pub mod mcp_server_runtime_core; use crate::error::SdkResult; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::mcp_traits::mcp_server::McpServer; +use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric}; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage, @@ -45,6 +46,7 @@ pub struct ServerRuntime { #[cfg(feature = "hyper-server")] session_id: Option, transport_map: tokio::sync::RwLock>, + request_id_gen: Box, client_details_tx: watch::Sender>, client_details_rx: watch::Receiver>, } @@ -79,7 +81,7 @@ impl McpServer for ServerRuntime { message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult> { + ) -> SdkResult> { let transport_map = self.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() @@ -87,14 +89,18 @@ impl McpServer for ServerRuntime { )?; let outgoing_request_id = self - .request_id_for_message(transport, &message, request_id) - .await; + .request_id_gen + .request_id_for_message(&message, request_id); let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?; - transport + + let response = transport .send_message(ServerMessages::Single(mcp_message), request_timeout) - .map_err(|err| err.into()) - .await + .await? + .map(|res| res.as_single()) + .transpose()?; + + Ok(response) } async fn send_batch( @@ -211,40 +217,6 @@ impl ServerRuntime { Ok(()) } - /// Determines the request ID for an outgoing MCP message. - /// - /// For requests, generates a new ID using the internal counter. For responses or errors, - /// uses the provided `request_id`. Notifications receive no ID. - /// - /// # Arguments - /// * `message` - The MCP message to evaluate. - /// * `request_id` - An optional existing request ID (required for responses/errors). - /// - /// # Returns - /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - pub(crate) async fn request_id_for_message( - &self, - transport: &Arc< - dyn TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, - >, - >, - message: &MessageFromServer, - request_id: Option, - ) -> Option { - let message_sender = transport.message_sender(); - let guard = message_sender.read().await; - if let Some(dispatcher) = guard.as_ref() { - dispatcher.request_id_for_message(message, request_id) - } else { - None - } - } - pub(crate) async fn handle_message( &self, message: ClientMessage, @@ -290,7 +262,19 @@ impl ServerRuntime { None } ClientMessage::Error(jsonrpc_error) => { - self.handler.handle_error(jsonrpc_error.error, self).await?; + self.handler + .handle_error(&jsonrpc_error.error, self) + .await?; + if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { + tx_response + .send(ClientMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request {:?}", + &jsonrpc_error.id + ); + } None } // The response is the result of a request, it is processed at the transport level. @@ -301,7 +285,7 @@ impl ServerRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received response or error without a matching request: {:?}", + "Received a response with no corresponding request: {:?}", &response.id ); } @@ -471,6 +455,7 @@ impl ServerRuntime { transport_map: tokio::sync::RwLock::new(HashMap::new()), client_details_tx, client_details_rx, + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } @@ -497,6 +482,7 @@ impl ServerRuntime { transport_map: tokio::sync::RwLock::new(map), client_details_tx, client_details_rx, + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index 26f37e1..ea19e19 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -177,7 +177,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpServer, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index 154b4bc..e0e7108 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -87,7 +87,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> } async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpServer, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_traits.rs b/crates/rust-mcp-sdk/src/mcp_traits.rs index 511731c..2b155fa 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits.rs @@ -3,3 +3,6 @@ pub mod mcp_client; pub mod mcp_handler; #[cfg(feature = "server")] pub mod mcp_server; +mod request_id_gen; + +pub use request_id_gen::*; diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 8e72c26..1883581 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -10,7 +10,7 @@ use crate::schema::{ InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams, ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest, ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams, - LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, + LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId, RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, @@ -35,16 +35,6 @@ pub trait McpClient: Sync + Send { fn client_info(&self) -> &InitializeRequestParams; fn server_info(&self) -> Option; - #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] - fn get_client_info(&self) -> &InitializeRequestParams { - self.client_info() - } - - #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] - fn get_server_info(&self) -> Option { - self.server_info() - } - /// Checks whether the server has been initialized with client fn is_initialized(&self) -> bool { self.server_info().is_some() @@ -57,23 +47,12 @@ pub trait McpClient: Sync + Send { .map(|server_details| server_details.server_info) } - #[deprecated(since = "0.2.0", note = "Use `server_version()` instead.")] - fn get_server_version(&self) -> Option { - self.server_info() - .map(|server_details| server_details.server_info) - } - /// Returns the server's capabilities. /// After initialization has completed, this will be populated with the server's reported capabilities. fn server_capabilities(&self) -> Option { self.server_info().map(|item| item.capabilities) } - #[deprecated(since = "0.2.0", note = "Use `server_capabilities()` instead.")] - fn get_server_capabilities(&self) -> Option { - self.server_info().map(|item| item.capabilities) - } - /// Checks if the server has tools available. /// /// This function retrieves the server information and checks if the @@ -156,10 +135,6 @@ pub trait McpClient: Sync + Send { self.server_info() .map(|server_details| server_details.capabilities.logging.is_some()) } - #[deprecated(since = "0.2.0", note = "Use `instructions()` instead.")] - fn get_instructions(&self) -> Option { - self.server_info()?.instructions - } fn instructions(&self) -> Option { self.server_info()?.instructions @@ -175,27 +150,15 @@ pub trait McpClient: Sync + Send { request: RequestFromClient, timeout: Option, ) -> SdkResult { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let request_id = sender.next_request_id(); - - let mcp_message = - ClientMessage::from_message(MessageFromClient::from(request), Some(request_id))?; - let response = sender - .send_message(ClientMessages::Single(mcp_message), timeout) + let response = self + .send(MessageFromClient::RequestFromClient(request), None, timeout) .await?; let server_message = response.ok_or_else(|| { RpcError::internal_error() - .with_message("An empty response was received from the server.".to_string()) + .with_message("An empty response was received from the client.".to_string()) })?; - let server_message = server_message.as_single()?; - if server_message.is_error() { return Err(server_message.as_error()?.error.into()); } @@ -205,27 +168,10 @@ pub trait McpClient: Sync + Send { async fn send( &self, - message: ClientMessage, + message: MessageFromClient, + request_id: Option, timeout: Option, - ) -> SdkResult> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let response = sender - .send_message(ClientMessages::Single(message), timeout) - .await?; - - match response { - Some(res) => { - let server_results = res.as_single()?; - Ok(Some(server_results)) - } - None => Ok(None), - } - } + ) -> SdkResult>; async fn send_batch( &self, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index c86a623..2974bfc 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -24,8 +24,11 @@ pub trait McpServerHandler: Send + Sync { client_jsonrpc_request: RequestFromClient, runtime: &dyn McpServer, ) -> std::result::Result; - async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpServer) - -> SdkResult<()>; + async fn handle_error( + &self, + jsonrpc_error: &RpcError, + runtime: &dyn McpServer, + ) -> SdkResult<()>; async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, @@ -41,8 +44,11 @@ pub trait McpClientHandler: Send + Sync { server_jsonrpc_request: RequestFromServer, runtime: &dyn McpClient, ) -> std::result::Result; - async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpClient) - -> SdkResult<()>; + async fn handle_error( + &self, + jsonrpc_error: &RpcError, + runtime: &dyn McpClient, + ) -> SdkResult<()>; async fn handle_notification( &self, server_jsonrpc_notification: NotificationFromServer, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index a1d501d..0130c33 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -2,8 +2,8 @@ use std::time::Duration; use crate::schema::{ schema_utils::{ - ClientMessage, ClientMessages, McpMessage, MessageFromServer, NotificationFromServer, - RequestFromServer, ResultFromClient, ServerMessage, + ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer, + ResultFromClient, ServerMessage, }, CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult, @@ -29,22 +29,12 @@ pub trait McpServer: Sync + Send { async fn wait_for_initialization(&self); - #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] - fn get_client_info(&self) -> Option { - self.client_info() - } - - #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] - fn get_server_info(&self) -> &InitializeResult { - self.server_info() - } - async fn send( &self, message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult>; + ) -> SdkResult>; async fn send_batch( &self, @@ -84,13 +74,11 @@ pub trait McpServer: Sync + Send { .send(MessageFromServer::RequestFromServer(request), None, timeout) .await?; - let client_messages = response.ok_or_else(|| { + let client_message = response.ok_or_else(|| { RpcError::internal_error() .with_message("An empty response was received from the client.".to_string()) })?; - let client_message = client_messages.as_single()?; - if client_message.is_error() { return Err(client_message.as_error()?.error.into()); } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs new file mode 100644 index 0000000..2372ae9 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs @@ -0,0 +1,101 @@ +use std::sync::atomic::AtomicI64; + +use crate::schema::{schema_utils::McpMessage, RequestId}; +use async_trait::async_trait; + +/// A trait for generating and managing request IDs in a thread-safe manner. +/// +/// Implementors provide functionality to generate unique request IDs, retrieve the last +/// generated ID, and reset the ID counter. +#[async_trait] +pub trait RequestIdGen: Send + Sync { + fn next_request_id(&self) -> RequestId; + #[allow(unused)] + fn last_request_id(&self) -> Option; + #[allow(unused)] + fn reset_to(&self, id: u64); + + /// Determines the request ID for an outgoing MCP message. + /// + /// For requests, generates a new ID using the internal counter. For responses or errors, + /// uses the provided `request_id`. Notifications receive no ID. + /// + /// # Arguments + /// * `message` - The MCP message to evaluate. + /// * `request_id` - An optional existing request ID (required for responses/errors). + /// + /// # Returns + /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. + fn request_id_for_message( + &self, + message: &dyn McpMessage, + request_id: Option, + ) -> Option { + // we need to produce next request_id for requests + if message.is_request() { + // request_id should be None for requests + assert!(request_id.is_none()); + Some(self.next_request_id()) + } else if !message.is_notification() { + // `request_id` must not be `None` for errors, notifications and responses + assert!(request_id.is_some()); + request_id + } else { + None + } + } +} + +pub struct RequestIdGenNumeric { + message_id_counter: AtomicI64, + last_message_id: AtomicI64, +} + +impl RequestIdGenNumeric { + pub fn new(initial_id: Option) -> Self { + Self { + message_id_counter: AtomicI64::new(initial_id.unwrap_or(0) as i64), + last_message_id: AtomicI64::new(-1), + } + } +} + +impl RequestIdGen for RequestIdGenNumeric { + /// Generates the next unique request ID as an integer. + /// + /// Increments the internal counter atomically and updates the last generated ID. + /// Uses `Relaxed` ordering for performance, as the counter only needs to ensure unique IDs. + fn next_request_id(&self) -> RequestId { + let id = self + .message_id_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + // Store the new ID as the last generated ID + self.last_message_id + .store(id, std::sync::atomic::Ordering::Relaxed); + RequestId::Integer(id) + } + + /// Returns the last generated request ID, if any. + /// + /// Returns `None` if no ID has been generated (indicated by a sentinel value of -1). + /// Uses `Relaxed` ordering since the read operation doesn’t require synchronization + /// with other memory operations beyond atomicity. + fn last_request_id(&self) -> Option { + let last_id = self + .last_message_id + .load(std::sync::atomic::Ordering::Relaxed); + if last_id == -1 { + None + } else { + Some(RequestId::Integer(last_id)) + } + } + + /// Resets the internal counter to the specified ID. + /// + /// The provided `id` (u64) is converted to i64 and stored atomically. + fn reset_to(&self, id: u64) { + self.message_id_counter + .store(id as i64, std::sync::atomic::Ordering::Relaxed); + } +} diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 57a3ea8..564db0d 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -128,8 +128,10 @@ use futures::stream::Stream; // stream: &mut impl Stream>, pub async fn read_sse_event_from_stream( stream: &mut (impl Stream> + Unpin), -) -> Option { + event_count: usize, +) -> Option> { let mut buffer = String::new(); + let mut events = vec![]; while let Some(item) = stream.next().await { match item { @@ -158,7 +160,10 @@ pub async fn read_sse_event_from_stream( // Return if data was found if let Some(data) = data { - return Some(data); + events.push(data); + if events.len().eq(&event_count) { + return Some(events); + } } } } @@ -171,9 +176,9 @@ pub async fn read_sse_event_from_stream( None } -pub async fn read_sse_event(response: Response) -> Option { +pub async fn read_sse_event(response: Response, event_count: usize) -> Option> { let mut stream = response.bytes_stream(); - read_sse_event_from_stream(&mut stream).await + read_sse_event_from_stream(&mut stream, event_count).await } pub fn test_client_info() -> InitializeRequestParams { diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http.rs index 5eb5e47..23ca27f 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http.rs @@ -169,8 +169,8 @@ async fn should_handle_post_requests_via_sse_response_correctly() { assert_eq!(response.status(), StatusCode::OK); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -220,8 +220,8 @@ async fn should_call_a_tool_and_return_the_result() { assert_eq!(response.status(), StatusCode::OK); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -345,8 +345,8 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { .await .unwrap(); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0]).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -365,7 +365,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { server.hyper_runtime.await_server().await.unwrap() } -// should establish standalone SSE stream and receive server-initiated messages +// should establish standalone SSE stream and receive server-initiated requests #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_requests() { let (server, session_id) = initialize_server(None).await.unwrap(); @@ -394,48 +394,59 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { ); let hyper_server = Arc::new(server.hyper_runtime); - let hyper_server_clone = hyper_server.clone(); - let session_id_clone = session_id.to_string(); - - tokio::spawn(async move { - // Send a server-initiated notification that should appear on SSE stream with a valid request_id - hyper_server_clone - .list_roots(&session_id_clone, None) - .await - .unwrap(); - }); - - tokio::time::sleep(Duration::from_millis(2250)).await; - - let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( - RequestId::Integer(0), - ListRootsResult { - meta: None, - roots: vec![], - } - .into(), - ); - send_post_request( - &server.streamable_url, - &serde_json::to_string(&json_rpc_message).unwrap(), - Some(&session_id), - None, - ) - .await - .expect("Request failed"); + // Send two server-initiated request that should appear on SSE stream with a valid request_id + for _ in 0..2 { + let hyper_server_clone = hyper_server.clone(); + let session_id_clone = session_id.to_string(); + tokio::spawn(async move { + hyper_server_clone + .list_roots(&session_id_clone, None) + .await + .unwrap(); + }); + } + + for i in 0..2 { + // send responses back to the server for two server initiated requests + let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( + RequestId::Integer(i), + ListRootsResult { + meta: None, + roots: vec![], + } + .into(), + ); + send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + } - let event = read_sse_event(response).await.unwrap(); + // read two events from the sse stream + let events = read_sse_event(response, 2).await.unwrap(); - let message: ServerJsonrpcRequest = serde_json::from_str(&event).unwrap(); + let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0]).unwrap(); - println!(">>> message {:?} ", message); + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request + else { + panic!("invalid message received!"); + }; - let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message.request + let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1]).unwrap(); + + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request else { panic!("invalid message received!"); }; + // ensure request_ids are unique + assert!(message2.id != message1.id); + hyper_server.graceful_shutdown(ONE_MILLISECOND); } @@ -461,7 +472,7 @@ async fn should_not_close_get_sse_stream() { .unwrap(); let mut stream = response.bytes_stream(); - let event = read_sse_event_from_stream(&mut stream).await.unwrap(); + let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( @@ -490,7 +501,7 @@ async fn should_not_close_get_sse_stream() { .await .unwrap(); - let event = read_sse_event_from_stream(&mut stream).await.unwrap(); + let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( @@ -702,8 +713,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert_eq!(response_1.status(), StatusCode::OK); assert_eq!(response_2.status(), StatusCode::OK); - let event = read_sse_event(response_2).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response_2, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -718,8 +729,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() "Hello, Ali!" ); - let event = read_sse_event(response_1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response_1, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -1069,8 +1080,8 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { "text/event-stream" ); - let event = read_sse_event(response).await.unwrap(); - let message: ServerMessages = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerMessages = serde_json::from_str(&events[0]).unwrap(); let ServerMessages::Batch(mut messages) = message else { panic!("Invalid message type"); diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 2d2a377..08bdc21 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -5,12 +5,7 @@ use crate::{ utils::CancellationToken, IoStream, }; -use std::{ - collections::HashMap, - pin::Pin, - sync::{atomic::AtomicI64, Arc}, - time::Duration, -}; +use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; use tokio::task::JoinHandle; use tokio::{ io::{AsyncBufReadExt, BufReader}, @@ -57,12 +52,7 @@ impl MCPStream { // rpc message stream that receives incoming messages - let sender = MessageDispatcher::new( - pending_requests, - writable, - Arc::new(AtomicI64::new(0)), - request_timeout, - ); + let sender = MessageDispatcher::new(pending_requests, writable, request_timeout); (stream, sender, error_io) } diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index 22d0b58..ea1eb04 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -10,7 +10,6 @@ use futures::future::join_all; use std::collections::HashMap; use std::pin::Pin; -use std::sync::atomic::AtomicI64; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; @@ -31,7 +30,6 @@ use crate::McpDispatch; pub struct MessageDispatcher { pending_requests: Arc>>>, writable_std: Mutex>>, - message_id_counter: Arc, request_timeout: Duration, } @@ -49,53 +47,15 @@ impl MessageDispatcher { pub fn new( pending_requests: Arc>>>, writable_std: Mutex>>, - message_id_counter: Arc, request_timeout: Duration, ) -> Self { Self { pending_requests, writable_std, - message_id_counter, request_timeout, } } - /// Determines the request ID for an outgoing MCP message. - /// - /// For requests, generates a new ID using the internal counter. For responses or errors, - /// uses the provided `request_id`. Notifications receive no ID. - /// - /// # Arguments - /// * `message` - The MCP message to evaluate. - /// * `request_id` - An optional existing request ID (required for responses/errors). - /// - /// # Returns - /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - pub fn request_id_for_message( - &self, - message: &impl McpMessage, - request_id: Option, - ) -> Option { - // we need to produce next request_id for requests - if message.is_request() { - // request_id should be None for requests - assert!(request_id.is_none()); - Some(self.next_request_id()) - } else if !message.is_notification() { - // `request_id` must not be `None` for errors, notifications and responses - assert!(request_id.is_some()); - request_id - } else { - None - } - } - pub fn next_request_id(&self) -> RequestId { - RequestId::Integer( - self.message_id_counter - .fetch_add(1, std::sync::atomic::Ordering::Relaxed), - ) - } - async fn store_pending_request( &self, request_id: RequestId, diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-core/src/handler.rs index fcde15e..f0bdefe 100644 --- a/examples/hello-world-mcp-server-core/src/handler.rs +++ b/examples/hello-world-mcp-server-core/src/handler.rs @@ -98,7 +98,7 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: RpcError, + error: &RpcError, _: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-core-streamable-http/src/handler.rs index 53f884c..1c69e8c 100644 --- a/examples/hello-world-server-core-streamable-http/src/handler.rs +++ b/examples/hello-world-server-core-streamable-http/src/handler.rs @@ -103,7 +103,7 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: RpcError, + error: &RpcError, _: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-core-sse/src/handler.rs index a1a95e4..ab86e9e 100644 --- a/examples/simple-mcp-client-core-sse/src/handler.rs +++ b/examples/simple-mcp-client-core-sse/src/handler.rs @@ -41,16 +41,30 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_notification( &self, - _notification: NotificationFromServer, + notification: NotificationFromServer, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { - Err(RpcError::internal_error() - .with_message("handle_notification() Not implemented".to_string())) + if let NotificationFromServer::ServerNotification( + schema::ServerNotification::LoggingMessageNotification(logging_message_notification), + ) = notification + { + println!( + "Notification from server: {}", + logging_message_notification.params.data + ); + } else { + println!( + "A {} notification received from the server", + notification.method() + ); + }; + + Ok(()) } async fn handle_error( &self, - _error: RpcError, + _error: &RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) diff --git a/examples/simple-mcp-client-core/src/handler.rs b/examples/simple-mcp-client-core/src/handler.rs index a1a95e4..bd5e4fe 100644 --- a/examples/simple-mcp-client-core/src/handler.rs +++ b/examples/simple-mcp-client-core/src/handler.rs @@ -50,7 +50,7 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_error( &self, - _error: RpcError, + _error: &RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string()))