From 51af21a683e475b27c064a40894516430db13ddb Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 13 Sep 2025 12:12:34 -0300 Subject: [PATCH 1/4] Add Streamable HTTP Client and multiple refactoring and improvements --- .release-manifest.json | 20 +- Cargo.lock | 94 +- Cargo.toml | 17 +- README.md | 27 +- crates/rust-mcp-sdk/Cargo.toml | 41 +- crates/rust-mcp-sdk/README.md | 28 +- crates/rust-mcp-sdk/src/error.rs | 37 +- .../src/hyper_servers/app_state.rs | 11 +- .../src/hyper_servers/routes/hyper_utils.rs | 100 ++- .../src/hyper_servers/routes/sse_routes.rs | 22 +- .../routes/streamable_http_routes.rs | 7 +- .../rust-mcp-sdk/src/hyper_servers/server.rs | 13 +- .../src/hyper_servers/session_store.rs | 24 - crates/rust-mcp-sdk/src/id_generator.rs | 5 + .../src/id_generator/fast_id_generator.rs | 53 ++ .../src/id_generator/uuid_generator.rs | 18 + crates/rust-mcp-sdk/src/lib.rs | 9 +- .../src/mcp_handlers/mcp_server_handler.rs | 51 +- .../mcp_handlers/mcp_server_handler_core.rs | 17 +- .../src/mcp_runtimes/client_runtime.rs | 531 ++++++++--- .../client_runtime/mcp_client_runtime.rs | 24 +- .../client_runtime/mcp_client_runtime_core.rs | 33 +- .../src/mcp_runtimes/server_runtime.rs | 283 ++++-- .../server_runtime/mcp_server_runtime.rs | 21 +- .../server_runtime/mcp_server_runtime_core.rs | 15 +- crates/rust-mcp-sdk/src/mcp_traits.rs | 2 + .../src/mcp_traits/id_generator.rs | 12 + .../rust-mcp-sdk/src/mcp_traits/mcp_client.rs | 49 +- .../src/mcp_traits/mcp_handler.rs | 11 +- .../rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 7 +- crates/rust-mcp-sdk/src/utils.rs | 43 +- crates/rust-mcp-sdk/tests/check_imports.rs | 5 +- crates/rust-mcp-sdk/tests/common/common.rs | 57 +- .../rust-mcp-sdk/tests/common/mock_server.rs | 528 +++++++++++ .../rust-mcp-sdk/tests/common/test_client.rs | 163 ++++ .../rust-mcp-sdk/tests/common/test_server.rs | 31 +- .../tests/test_protocol_compatibility.rs | 2 +- .../tests/test_streamable_http_client.rs | 823 ++++++++++++++++++ ...http.rs => test_streamable_http_server.rs} | 7 +- crates/rust-mcp-transport/Cargo.toml | 4 +- crates/rust-mcp-transport/README.md | 4 +- crates/rust-mcp-transport/src/client_sse.rs | 101 ++- .../src/client_streamable_http.rs | 515 +++++++++++ crates/rust-mcp-transport/src/constants.rs | 3 + crates/rust-mcp-transport/src/error.rs | 71 +- crates/rust-mcp-transport/src/lib.rs | 17 +- crates/rust-mcp-transport/src/mcp_stream.rs | 37 + .../src/message_dispatcher.rs | 82 +- crates/rust-mcp-transport/src/sse.rs | 4 +- crates/rust-mcp-transport/src/stdio.rs | 74 +- crates/rust-mcp-transport/src/transport.rs | 35 +- crates/rust-mcp-transport/src/utils.rs | 28 +- .../src/utils/http_utils.rs | 125 ++- .../src/utils/sse_parser.rs | 320 +++++++ .../src/utils/streamable_http_stream.rs | 374 ++++++++ .../rust-mcp-transport/tests/check_imports.rs | 5 +- development.md | 6 +- doc/getting-started-mcp-server.md | 4 +- .../.gitignore | 0 .../Cargo.toml | 5 +- .../README.md | 8 +- .../src/handler.rs | 8 +- .../src/main.rs | 0 .../src/tools.rs | 0 .../Cargo.toml | 7 +- .../README.md | 8 +- .../src/handler.rs | 5 +- .../src/main.rs | 5 +- .../src/tools.rs | 0 .../.gitignore | 0 .../Cargo.toml | 5 +- .../README.md | 4 +- .../src/handler.rs | 8 +- .../src/main.rs | 0 .../src/tools.rs | 0 .../Cargo.toml | 1 + .../README.md | 2 +- .../src/handler.rs | 11 +- .../Cargo.toml | 5 +- .../README.md | 2 +- .../src/handler.rs | 0 .../src/inquiry_utils.rs | 0 .../src/main.rs | 1 + examples/simple-mcp-client-sse/Cargo.toml | 2 + examples/simple-mcp-client-sse/src/main.rs | 13 +- .../Cargo.toml | 5 +- .../README.md | 2 +- .../src/handler.rs | 0 .../src/inquiry_utils.rs | 0 .../src/main.rs | 0 .../Cargo.toml | 5 +- .../README.md | 2 +- .../src/handler.rs | 0 .../src/inquiry_utils.rs | 0 .../src/main.rs | 0 .../Cargo.toml | 29 + .../README.md | 40 + .../src/handler.rs | 72 ++ .../src/inquiry_utils.rs | 222 +++++ .../src/main.rs | 95 ++ .../Cargo.toml | 29 + .../README.md | 40 + .../src/handler.rs | 10 + .../src/inquiry_utils.rs | 222 +++++ .../src/main.rs | 99 +++ 105 files changed, 5330 insertions(+), 692 deletions(-) create mode 100644 crates/rust-mcp-sdk/src/id_generator.rs create mode 100644 crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs create mode 100644 crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs create mode 100644 crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs create mode 100644 crates/rust-mcp-sdk/tests/common/mock_server.rs create mode 100644 crates/rust-mcp-sdk/tests/common/test_client.rs create mode 100644 crates/rust-mcp-sdk/tests/test_streamable_http_client.rs rename crates/rust-mcp-sdk/tests/{test_streamable_http.rs => test_streamable_http_server.rs} (99%) create mode 100644 crates/rust-mcp-transport/src/client_streamable_http.rs create mode 100644 crates/rust-mcp-transport/src/constants.rs create mode 100644 crates/rust-mcp-transport/src/utils/sse_parser.rs create mode 100644 crates/rust-mcp-transport/src/utils/streamable_http_stream.rs rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/.gitignore (100%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/Cargo.toml (83%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/README.md (81%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/src/handler.rs (97%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/src/main.rs (100%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/src/tools.rs (100%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/Cargo.toml (85%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/README.md (84%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/src/handler.rs (94%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/src/main.rs (92%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/src/tools.rs (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/.gitignore (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/Cargo.toml (84%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/README.md (95%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/src/handler.rs (97%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/src/main.rs (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/src/tools.rs (100%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/Cargo.toml (88%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/README.md (97%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/src/handler.rs (100%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/src/inquiry_utils.rs (100%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/src/main.rs (99%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/Cargo.toml (86%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/README.md (97%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/src/handler.rs (100%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/src/inquiry_utils.rs (100%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/src/main.rs (100%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/Cargo.toml (87%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/README.md (97%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/src/handler.rs (100%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/src/inquiry_utils.rs (100%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/src/main.rs (100%) create mode 100644 examples/simple-mcp-client-streamable-http-core/Cargo.toml create mode 100644 examples/simple-mcp-client-streamable-http-core/README.md create mode 100644 examples/simple-mcp-client-streamable-http-core/src/handler.rs create mode 100644 examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs create mode 100644 examples/simple-mcp-client-streamable-http-core/src/main.rs create mode 100644 examples/simple-mcp-client-streamable-http/Cargo.toml create mode 100644 examples/simple-mcp-client-streamable-http/README.md create mode 100644 examples/simple-mcp-client-streamable-http/src/handler.rs create mode 100644 examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs create mode 100644 examples/simple-mcp-client-streamable-http/src/main.rs diff --git a/.release-manifest.json b/.release-manifest.json index 97a0f63..a645da6 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,15 @@ { "crates/rust-mcp-sdk": "0.6.3", "crates/rust-mcp-macros": "0.5.1", - "crates/rust-mcp-transport": "0.5.1", - "examples/hello-world-mcp-server": "0.1.31", - "examples/hello-world-mcp-server-core": "0.1.22", - "examples/simple-mcp-client": "0.1.31", - "examples/simple-mcp-client-core": "0.1.31", - "examples/hello-world-server-core-streamable-http": "0.1.22", - "examples/hello-world-server-streamable-http": "0.1.31", - "examples/simple-mcp-client-core-sse": "0.1.22", - "examples/simple-mcp-client-sse": "0.1.22" + "crates/rust-mcp-transport": "0.5.0", + "examples/hello-world-mcp-server-stdio": "0.1.28", + "examples/hello-world-mcp-server-stdio-core": "0.1.19", + "examples/simple-mcp-client-stdio": "0.1.28", + "examples/simple-mcp-client-stdio-core": "0.1.28", + "examples/hello-world-server-streamable-http-core": "0.1.19", + "examples/hello-world-server-streamable-http": "0.1.28", + "examples/simple-mcp-client-sse-core": "0.1.19", + "examples/simple-mcp-client-sse": "0.1.19", + "examples/simple-mcp-client-streamable-http": "0.1.0", + "examples/simple-mcp-client-streamable-http-core": "0.1.0" } diff --git a/Cargo.lock b/Cargo.lock index c10e354..c3c4462 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -257,10 +257,11 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.34" +version = "1.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42bc4aea80032b7bf409b0bc7ccad88853858911b7713a8062fdc0623867bedc" +checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3" dependencies = [ + "find-msvc-tools", "jobserver", "libc", "shlex", @@ -381,9 +382,9 @@ checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" [[package]] name = "deranged" -version = "0.4.0" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" dependencies = [ "powerfmt", ] @@ -451,6 +452,12 @@ dependencies = [ "instant", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e178e4fba8a2726903f6ba98a6d221e76f9c12c650d5dc0e6afdc50677b49650" + [[package]] name = "fnv" version = "1.0.7" @@ -687,8 +694,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] -name = "hello-world-mcp-server" -version = "0.1.31" +name = "hello-world-mcp-server-stdio" +version = "0.1.28" dependencies = [ "async-trait", "futures", @@ -701,8 +708,8 @@ dependencies = [ ] [[package]] -name = "hello-world-mcp-server-core" -version = "0.1.22" +name = "hello-world-mcp-server-stdio-core" +version = "0.1.19" dependencies = [ "async-trait", "futures", @@ -713,8 +720,8 @@ dependencies = [ ] [[package]] -name = "hello-world-server-core-streamable-http" -version = "0.1.22" +name = "hello-world-server-streamable-http" +version = "0.1.31" dependencies = [ "async-trait", "futures", @@ -727,8 +734,8 @@ dependencies = [ ] [[package]] -name = "hello-world-server-streamable-http" -version = "0.1.31" +name = "hello-world-server-streamable-http-core" +version = "0.1.19" dependencies = [ "async-trait", "futures", @@ -1684,6 +1691,7 @@ dependencies = [ "async-trait", "axum", "axum-server", + "base64 0.22.1", "futures", "hyper 1.7.0", "reqwest", @@ -1698,6 +1706,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "wiremock", ] [[package]] @@ -1903,8 +1912,8 @@ dependencies = [ ] [[package]] -name = "simple-mcp-client" -version = "0.1.31" +name = "simple-mcp-client-sse" +version = "0.1.22" dependencies = [ "async-trait", "colored", @@ -1914,11 +1923,13 @@ dependencies = [ "serde_json", "thiserror 2.0.16", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] -name = "simple-mcp-client-core" -version = "0.1.31" +name = "simple-mcp-client-sse-core" +version = "0.1.19" dependencies = [ "async-trait", "colored", @@ -1928,11 +1939,41 @@ dependencies = [ "serde_json", "thiserror 2.0.16", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] -name = "simple-mcp-client-core-sse" -version = "0.1.22" +name = "simple-mcp-client-stdio" +version = "0.1.28" +dependencies = [ + "async-trait", + "colored", + "futures", + "rust-mcp-sdk", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", +] + +[[package]] +name = "simple-mcp-client-stdio-core" +version = "0.1.28" +dependencies = [ + "async-trait", + "colored", + "futures", + "rust-mcp-sdk", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", +] + +[[package]] +name = "simple-mcp-client-streamable-http" +version = "0.1.0" dependencies = [ "async-trait", "colored", @@ -1947,8 +1988,8 @@ dependencies = [ ] [[package]] -name = "simple-mcp-client-sse" -version = "0.1.22" +name = "simple-mcp-client-streamable-http-core" +version = "0.1.0" dependencies = [ "async-trait", "colored", @@ -2088,12 +2129,11 @@ dependencies = [ [[package]] name = "time" -version = "0.3.41" +version = "0.3.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +checksum = "8ca967379f9d8eb8058d86ed467d81d03e81acd45757e4ca341c24affbe8e8e3" dependencies = [ "deranged", - "itoa", "num-conv", "powerfmt", "serde", @@ -2103,15 +2143,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" +checksum = "a9108bb380861b07264b950ded55a44a14a4adc68b9f5efd85aafc3aa4d40a68" [[package]] name = "time-macros" -version = "0.2.22" +version = "0.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +checksum = "7182799245a7264ce590b349d90338f1c1affad93d2639aed5f8f69c090b334c" dependencies = [ "num-conv", "time-core", diff --git a/Cargo.toml b/Cargo.toml index b4f7cca..711204d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,14 +4,17 @@ members = [ "crates/rust-mcp-macros", "crates/rust-mcp-sdk", "crates/rust-mcp-transport", - "examples/simple-mcp-client", - "examples/simple-mcp-client-core", - "examples/hello-world-mcp-server", - "examples/hello-world-mcp-server-core", + "examples/simple-mcp-client-stdio", + "examples/simple-mcp-client-stdio-core", + "examples/hello-world-mcp-server-stdio", + "examples/hello-world-mcp-server-stdio-core", "examples/hello-world-server-streamable-http", - "examples/hello-world-server-core-streamable-http", + "examples/hello-world-server-streamable-http-core", "examples/simple-mcp-client-sse", - "examples/simple-mcp-client-core-sse", + "examples/simple-mcp-client-sse-core", + "examples/simple-mcp-client-streamable-http", + "examples/simple-mcp-client-streamable-http-core", + ] [workspace.dependencies] @@ -39,7 +42,7 @@ tracing-subscriber = { version = "0.3", features = [ "std", "fmt", ] } - +base64 = "0.22" axum = "0.8" rustls = "0.23" tokio-rustls = "0.26" diff --git a/README.md b/README.md index 1581d1d..b1af670 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [build status ](https://github.com/rust-mcp-stack/rust-mcp-sdk/actions/workflows/ci.yml) [Hello World MCP Server -](examples/hello-world-mcp-server) +](examples/hello-world-mcp-server-stdio) A high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! @@ -42,6 +42,17 @@ This project supports following transports: - ⬜ Resumability - ⬜ Authentication / Oauth + + +**MCP Streamable HTTP Support** +- [x] Streamable HTTP Support for MCP Servers +- [x] DNS Rebinding Protection +- [x] Batch Messages +- [x] Streaming & non-streaming JSON response +- [ ] Streamable HTTP Support for MCP Clients +- [ ] Resumability +- [ ] Authentication / Oauth + **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents @@ -110,7 +121,7 @@ async fn main() -> SdkResult<()> { } ``` -See hello-world-mcp-server example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) @@ -180,7 +191,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +202,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc, ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) @@ -205,7 +216,7 @@ impl ServerHandler for MyServerHandler { --- -👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** +👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : @@ -477,10 +488,10 @@ Learn when to use the `mcp_*_handler` traits versus the lower-level `mcp_*_hand [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) provides two type of handler traits that you can chose from: - **ServerHandler**: This is the recommended trait for your MCP project, offering a default implementation for all types of MCP messages. It includes predefined implementations within the trait, such as handling initialization or responding to ping requests, so you only need to override and customize the handler functions relevant to your specific needs. - Refer to [examples/hello-world-mcp-server/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio/src/handler.rs) for an example. - **ServerHandlerCore**: If you need more control over MCP messages, consider using `ServerHandlerCore`. It offers three primary methods to manage the three MCP message types: `request`, `notification`, and `error`. While still providing type-safe objects in these methods, it allows you to determine how to handle each message based on its type and parameters. - Refer to [examples/hello-world-mcp-server-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core/src/handler.rs) for an example. --- @@ -509,7 +520,7 @@ Both functions create an MCP client instance. -Check out the corresponding examples at: [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) and [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core). +Check out the corresponding examples at: [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) and [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core). ## Projects using Rust MCP SDK diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 48ea665..3fd9ec2 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -24,15 +24,17 @@ futures = { workspace = true } thiserror = { workspace = true } axum = { workspace = true, optional = true } -uuid = { workspace = true, features = ["v4"], optional = true } +uuid = { workspace = true, features = ["v4"] } tokio-stream = { workspace = true, optional = true } axum-server = { version = "0.7", features = [], optional = true } tracing.workspace = true +base64.workspace = true # rustls = { workspace = true, optional = true } hyper = { version = "1.6.0", optional = true } [dev-dependencies] +wiremock = "0.5" reqwest = { workspace = true, default-features = false, features = [ "stream", "rustls-tls", @@ -51,47 +53,54 @@ default = [ "client", "server", "macros", + "stdio", + "sse", + "streamable-http", "hyper-server", "ssl", "2025_06_18", ] # All features enabled by default -server = ["rust-mcp-transport/stdio"] # Server feature -client = ["rust-mcp-transport/stdio", "rust-mcp-transport/sse"] # Client feature -hyper-server = [ - "axum", - "axum-server", - "hyper", - "server", - "uuid", - "tokio-stream", - "rust-mcp-transport/sse", -] + +sse = ["rust-mcp-transport/sse"] +streamable-http = ["rust-mcp-transport/streamable-http"] +stdio = ["rust-mcp-transport/stdio"] + +server = [] # Server feature +client = [] # Client feature +hyper-server = ["axum", "axum-server", "hyper", "server", "tokio-stream"] ssl = ["axum-server/tls-rustls"] macros = ["rust-mcp-macros/sdk"] -# enables mcp protocol version 2025_06_18 -2025_06_18 = [ +# enables mcp protocol version 2025-06-18 +2025-06-18 = [ "rust-mcp-schema/2025_06_18", "rust-mcp-macros/2025_06_18", "rust-mcp-transport/2025_06_18", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2025_06_18 = ["2025-06-18"] # enables mcp protocol version 2025_03_26 -2025_03_26 = [ +2025-03-26 = [ "rust-mcp-schema/2025_03_26", "rust-mcp-macros/2025_03_26", "rust-mcp-transport/2025_03_26", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2025_03_26 = ["2025-03-26"] + # enables mcp protocol version 2024_11_05 -2024_11_05 = [ +2024-11-05 = [ "rust-mcp-schema/2024_11_05", "rust-mcp-macros/2024_11_05", "rust-mcp-transport/2024_11_05", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2024_11_05 = ["2024-11-05"] [lints] workspace = true diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 1581d1d..9df027d 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -9,7 +9,7 @@ [build status ](https://github.com/rust-mcp-stack/rust-mcp-sdk/actions/workflows/ci.yml) [Hello World MCP Server -](examples/hello-world-mcp-server) +](examples/hello-world-mcp-server-stdio) A high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! @@ -32,7 +32,6 @@ This project supports following transports: 🚀 The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - **MCP Streamable HTTP Support** - ✅ Streamable HTTP Support for MCP Servers - ✅ DNS Rebinding Protection @@ -42,6 +41,17 @@ This project supports following transports: - ⬜ Resumability - ⬜ Authentication / Oauth + + +**MCP Streamable HTTP Support** +- [x] Streamable HTTP Support for MCP Servers +- [x] DNS Rebinding Protection +- [x] Batch Messages +- [x] Streaming & non-streaming JSON response +- [ ] Streamable HTTP Support for MCP Clients +- [ ] Resumability +- [ ] Authentication / Oauth + **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents @@ -110,7 +120,7 @@ async fn main() -> SdkResult<()> { } ``` -See hello-world-mcp-server example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) @@ -180,7 +190,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +201,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) @@ -205,7 +215,7 @@ impl ServerHandler for MyServerHandler { --- -👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** +👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : @@ -477,10 +487,10 @@ Learn when to use the `mcp_*_handler` traits versus the lower-level `mcp_*_hand [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) provides two type of handler traits that you can chose from: - **ServerHandler**: This is the recommended trait for your MCP project, offering a default implementation for all types of MCP messages. It includes predefined implementations within the trait, such as handling initialization or responding to ping requests, so you only need to override and customize the handler functions relevant to your specific needs. - Refer to [examples/hello-world-mcp-server/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio/src/handler.rs) for an example. - **ServerHandlerCore**: If you need more control over MCP messages, consider using `ServerHandlerCore`. It offers three primary methods to manage the three MCP message types: `request`, `notification`, and `error`. While still providing type-safe objects in these methods, it allows you to determine how to handle each message based on its type and parameters. - Refer to [examples/hello-world-mcp-server-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core/src/handler.rs) for an example. --- @@ -509,7 +519,7 @@ Both functions create an MCP client instance. -Check out the corresponding examples at: [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) and [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core). +Check out the corresponding examples at: [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) and [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core). ## Projects using Rust MCP SDK diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 3de8d98..3879526 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -11,25 +11,36 @@ pub type SdkResult = core::result::Result; #[derive(Debug, Error)] pub enum McpSdkError { + #[error("Transport error: {0}")] + Transport(#[from] TransportError), + + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("{0}")] RpcError(#[from] RpcError), + #[error("{0}")] - IoError(#[from] std::io::Error), - #[error("{0}")] - TransportError(#[from] TransportError), - #[error("{0}")] - JoinError(#[from] JoinError), - #[error("{0}")] - AnyError(Box<(dyn std::error::Error + Send + Sync)>), - #[error("{0}")] - SdkError(#[from] crate::schema::schema_utils::SdkError), + Join(#[from] JoinError), + #[cfg(feature = "hyper-server")] #[error("{0}")] - TransportServerError(#[from] TransportServerError), - #[error("Incompatible mcp protocol version: requested:{0} current:{1}")] - IncompatibleProtocolVersion(String, String), + HyperServer(#[from] TransportServerError), + #[error("{0}")] - ParseProtocolVersionError(#[from] ParseProtocolVersionError), + SdkError(#[from] crate::schema::schema_utils::SdkError), + + #[error("Protocol error: {kind}")] + Protocol { kind: ProtocolErrorKind }, +} + +// Sub-enum for protocol-related errors +#[derive(Debug, Error)] +pub enum ProtocolErrorKind { + #[error("Incompatible protocol version: requested {requested}, current {current}")] + IncompatibleVersion { requested: String, current: String }, + #[error("Failed to parse protocol version: {0}")] + ParseError(#[from] ParseProtocolVersionError), } impl McpSdkError { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index 0c1dcf3..ff6d5b2 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -1,11 +1,9 @@ use std::{sync::Arc, time::Duration}; -use crate::schema::InitializeResult; -use rust_mcp_transport::TransportOptions; - +use super::session_store::SessionStore; use crate::mcp_traits::mcp_handler::McpServerHandler; - -use super::{session_store::SessionStore, IdGenerator}; +use crate::{id_generator::FastIdGenerator, mcp_traits::IdGenerator, schema::InitializeResult}; +use rust_mcp_transport::{SessionId, TransportOptions}; /// Application state struct for the Hyper server /// @@ -14,7 +12,8 @@ use super::{session_store::SessionStore, IdGenerator}; #[derive(Clone)] pub struct AppState { pub session_store: Arc, - pub id_generator: Arc, + pub id_generator: Arc>, + pub stream_id_gen: Arc, pub server_details: Arc, pub handler: Arc, pub ping_interval: Duration, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index 0a77913..da69c67 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -6,7 +6,7 @@ use crate::{ }, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, - mcp_traits::mcp_handler::McpServerHandler, + mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, utils::validate_mcp_protocol_version, }; @@ -22,13 +22,12 @@ use axum::{ }; use futures::stream; use hyper::{header, HeaderMap, StatusCode}; -use rust_mcp_transport::{SessionId, SseTransport}; +use rust_mcp_transport::{ + SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, +}; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; -pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; -pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; - const DUPLEX_BUFFER_SIZE: usize = 8192; async fn create_sse_stream( @@ -41,11 +40,11 @@ async fn create_sse_stream( let payload_string = payload.map(|p| p.to_string()); // TODO: this logic should be moved out after refactoing the mcp_stream.rs - let result = payload_string + let payload_contains_request = payload_string .as_ref() .map(|json_str| contains_request(json_str)) .unwrap_or(Ok(false)); - let Ok(payload_contains_request) = result else { + let Ok(payload_contains_request) = payload_contains_request else { return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); }; @@ -54,18 +53,20 @@ async fn create_sse_stream( // writable stream to deliver message to the client let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - let transport = SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + let transport = Arc::new( + SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?, + ); - let stream_id = if standalone { + let stream_id: StreamId = if standalone { DEFAULT_STREAM_ID.to_string() } else { - state.id_generator.generate() + state.stream_id_gen.generate() }; let ping_interval = state.ping_interval; let runtime_clone = Arc::clone(&runtime); @@ -85,6 +86,7 @@ async fn create_sse_stream( // Construct SSE stream let reader = BufReader::new(write_rx); + // outgoing messages from server to the client let message_stream = stream::unfold(reader, |mut reader| async move { let mut line = String::new(); @@ -117,12 +119,12 @@ async fn create_sse_stream( // TODO: this function will be removed after refactoring the readable stream of the transports // so we would deserialize the string syncronousely and have more control over the flow -// this function could potentially add a 20-250 ns overhead which could be avoided +// this function may incur a slight runtime cost which could be avoided after refactoring fn contains_request(json_str: &str) -> Result { let value: serde_json::Value = serde_json::from_str(json_str)?; match value { serde_json::Value::Object(obj) => Ok(obj.contains_key("id") && obj.contains_key("method")), - serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + serde_json::Value::Array(arr) => Ok(arr.iter().any(|item| { item.as_object() .map(|obj| obj.contains_key("id") && obj.contains_key("method")) .unwrap_or(false) @@ -131,6 +133,19 @@ fn contains_request(json_str: &str) -> Result { } } +fn is_result(json_str: &str) -> Result { + let value: serde_json::Value = serde_json::from_str(json_str)?; + match value { + serde_json::Value::Object(obj) => Ok(obj.contains_key("result")), + serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + item.as_object() + .map(|obj| obj.contains_key("result")) + .unwrap_or(false) + })), + _ => Ok(false), + } +} + pub async fn create_standalone_stream( session_id: SessionId, state: Arc, @@ -166,11 +181,11 @@ pub async fn start_new_session( let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let runtime: Arc = Arc::new(server_runtime::create_server_instance( + let runtime: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); + ); tracing::info!("a new client joined : {}", &session_id); @@ -224,7 +239,12 @@ async fn single_shot_stream( tokio::spawn(async move { match runtime_clone - .start_stream(transport, &stream_id, ping_interval, payload_string) + .start_stream( + Arc::new(transport), + &stream_id, + ping_interval, + payload_string, + ) .await { Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id), @@ -233,7 +253,6 @@ async fn single_shot_stream( let _ = runtime.remove_transport(&stream_id).await; }); - // Construct SSE stream let mut reader = BufReader::new(write_rx); let mut line = String::new(); let response = match reader.read_line(&mut line).await { @@ -310,15 +329,34 @@ pub async fn process_incoming_message( match state.session_store.get(&session_id).await { Some(runtime) => { let runtime = runtime.lock().await.to_owned(); - - create_sse_stream( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - ) - .await + // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport + // it should be processed by the same transport , therefore no need to call create_sse_stream + let Ok(is_result) = is_result(payload) else { + return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); + }; + + if is_result { + match runtime + .consume_payload_string(DEFAULT_STREAM_ID, payload) + .await + { + Ok(()) => Ok((StatusCode::ACCEPTED, Json(())).into_response()), + Err(err) => Ok(( + StatusCode::BAD_REQUEST, + Json(SdkError::internal_error().with_message(err.to_string().as_ref())), + ) + .into_response()), + } + } else { + create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + ) + .await + } } None => { let error = SdkError::session_not_found(); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index e1c00f8..27a16b2 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,3 +1,4 @@ +use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::ClientMessage; use crate::{ hyper_servers::{ @@ -90,20 +91,24 @@ pub async fn handle_sse( let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); // create a transport for sending/receiving messages - let transport = SseTransport::new( + let Ok(transport) = SseTransport::new( read_rx, write_tx, read_tx, Arc::clone(&state.transport_options), - ) - .unwrap(); + ) else { + return Err(TransportServerError::TransportError( + "Failed to create SSE transport".to_string(), + )); + }; + let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let server: Arc = Arc::new(server_runtime::create_server_instance( + let server: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); + ); state .session_store @@ -115,7 +120,12 @@ pub async fn handle_sse( // Start the server tokio::spawn(async move { match server - .start_stream(transport, DEFAULT_STREAM_ID, state.ping_interval, None) + .start_stream( + Arc::new(transport), + DEFAULT_STREAM_ID, + state.ping_interval, + None, + ) .await { Ok(_) => tracing::info!("server {} exited gracefully.", session_id.to_owned()), diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 83cc372..00d46c0 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,4 +1,4 @@ -use super::hyper_utils::{start_new_session, MCP_SESSION_ID_HEADER}; +use super::hyper_utils::start_new_session; use crate::schema::schema_utils::SdkError; use crate::{ error::McpSdkError, @@ -14,6 +14,7 @@ use crate::{ }, utils::valid_initialize_method, }; +use axum::routing::get; use axum::{ extract::{Query, State}, middleware, @@ -22,11 +23,9 @@ use axum::{ Json, Router, }; use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::SessionId; +use rust_mcp_transport::{SessionId, MCP_SESSION_ID_HEADER}; use std::{collections::HashMap, sync::Arc}; -use axum::routing::get; - pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { Router::new() .route(streamable_http_endpoint, get(handle_streamable_http_get)) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index f093da3..1c3b3cf 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,6 +1,8 @@ use crate::{ - error::SdkResult, mcp_server::hyper_runtime::HyperRuntime, - mcp_traits::mcp_handler::McpServerHandler, + error::SdkResult, + id_generator::{FastIdGenerator, UuidGenerator}, + mcp_server::hyper_runtime::HyperRuntime, + mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, }; #[cfg(feature = "ssl")] use axum_server::tls_rustls::RustlsConfig; @@ -17,11 +19,11 @@ use super::{ app_state::AppState, error::{TransportServerError, TransportServerResult}, routes::app_routes, - IdGenerator, InMemorySessionStore, UuidGenerator, + InMemorySessionStore, }; use crate::schema::InitializeResult; use axum::Router; -use rust_mcp_transport::TransportOptions; +use rust_mcp_transport::{SessionId, TransportOptions}; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); @@ -43,7 +45,7 @@ pub struct HyperServerOptions { pub port: u16, /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, + pub session_id_generator: Option>>, /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, @@ -258,6 +260,7 @@ impl HyperServer { .session_id_generator .take() .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)), + stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))), server_details: Arc::new(server_details), handler, ping_interval: server_options.ping_interval, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs index 95b2158..4384b1a 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs @@ -5,7 +5,6 @@ use async_trait::async_trait; pub use in_memory::*; use rust_mcp_transport::SessionId; use tokio::sync::Mutex; -use uuid::Uuid; use crate::mcp_server::ServerRuntime; @@ -46,26 +45,3 @@ pub trait SessionStore: Send + Sync { async fn has(&self, session: &SessionId) -> bool; } - -/// Trait for generating session identifiers -/// -/// Implementors must be Send and Sync to support concurrent access. -pub trait IdGenerator: Send + Sync { - fn generate(&self) -> SessionId; -} - -/// Struct implementing the IdGenerator trait using UUID v4 -/// -/// This is a simple wrapper around the uuid crate's Uuid::new_v4 function -/// to generate unique session identifiers. -pub struct UuidGenerator {} - -impl IdGenerator for UuidGenerator { - /// Generates a new UUID v4-based session identifier - /// - /// # Returns - /// * `SessionId` - A new UUID-based session identifier as a String - fn generate(&self) -> SessionId { - Uuid::new_v4().to_string() - } -} diff --git a/crates/rust-mcp-sdk/src/id_generator.rs b/crates/rust-mcp-sdk/src/id_generator.rs new file mode 100644 index 0000000..54f0e72 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator.rs @@ -0,0 +1,5 @@ +mod fast_id_generator; +mod uuid_generator; +pub use crate::mcp_traits::IdGenerator; +pub use fast_id_generator::*; +pub use uuid_generator::*; diff --git a/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs b/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs new file mode 100644 index 0000000..fc2e976 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs @@ -0,0 +1,53 @@ +use crate::mcp_traits::IdGenerator; +use base64::Engine; +use std::sync::atomic::{AtomicU64, Ordering}; + +/// An [`IdGenerator`] implementation optimized for lightweight, locally-scoped identifiers. +/// +/// This generator produces short, incrementing identifiers that are Base64-encoded. +/// This makes it well-suited for cases such as `StreamId` generation, where: +/// - IDs only need to be unique within a single process or session +/// - Predictability is acceptable +/// - Shorter, more human-readable identifiers are desirable +/// +pub struct FastIdGenerator { + counter: AtomicU64, + ///Optional prefix for readability + prefix: &'static str, +} + +impl FastIdGenerator { + /// Creates a new ID generator with an optional prefix. + /// + /// # Arguments + /// * `prefix` - A static string to prepend to IDs (e.g., "sid_"). + pub fn new(prefix: Option<&'static str>) -> Self { + FastIdGenerator { + counter: AtomicU64::new(0), + prefix: prefix.unwrap_or_default(), + } + } +} + +impl IdGenerator for FastIdGenerator +where + T: From, +{ + /// Generates a new session ID as a short Base64-encoded string. + /// + /// Increments an internal counter atomically and encodes it in Base64 URL-safe format. + /// The resulting ID is prefixed (if provided) and typically 8–12 characters long. + /// + /// # Returns + /// * `SessionId` - A short, unique session ID (e.g., "sid_BBBB" or "BBBB"). + fn generate(&self) -> T { + let id = self.counter.fetch_add(1, Ordering::Relaxed); + let bytes = id.to_le_bytes(); + let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + if self.prefix.is_empty() { + T::from(encoded) + } else { + T::from(format!("{}{}", self.prefix, encoded)) + } + } +} diff --git a/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs b/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs new file mode 100644 index 0000000..2f0dc21 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs @@ -0,0 +1,18 @@ +use crate::mcp_traits::IdGenerator; +use uuid::Uuid; + +/// An [`IdGenerator`] implementation that uses UUID v4 to create unique identifiers. +/// +/// This generator produces random UUIDs (version 4), which are highly unlikely +/// to collide and difficult to predict. It is therefore well-suited for +/// generating identifiers such as `SessionId` or other values where uniqueness is important. +pub struct UuidGenerator; + +impl IdGenerator for UuidGenerator +where + T: From, +{ + fn generate(&self) -> T { + T::from(Uuid::new_v4().to_string()) + } +} diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index 1ea23df..a33f889 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -21,7 +21,7 @@ pub mod mcp_client { //! responding to ping requests, so you only need to override and customize the handler //! functions relevant to your specific needs. //! - //! Refer to [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) for an example. + //! Refer to [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) for an example. //! //! //! - **client_runtime_core**: If you need more control over MCP messages, consider using @@ -30,7 +30,7 @@ pub mod mcp_client { //! While still providing type-safe objects in these methods, it allows you to determine how to //! handle each message based on its type and parameters. //! - //! Refer to [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core) for an example. + //! Refer to [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core) for an example. pub use super::mcp_handlers::mcp_client_handler::ClientHandler; pub use super::mcp_handlers::mcp_client_handler_core::ClientHandlerCore; pub use super::mcp_runtimes::client_runtime::mcp_client_runtime as client_runtime; @@ -53,7 +53,7 @@ pub mod mcp_server { //! responding to ping requests, so you only need to override and customize the handler //! functions relevant to your specific needs. //! - //! Refer to [examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) for an example. + //! Refer to [examples/hello-world-mcp-server-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) for an example. //! //! //! - **server_runtime_core**: If you need more control over MCP messages, consider using @@ -62,7 +62,7 @@ pub mod mcp_server { //! While still providing type-safe objects in these methods, it allows you to determine how to //! handle each message based on its type and parameters. //! - //! Refer to [examples/hello-world-mcp-server-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core) for an example. + //! Refer to [examples/hello-world-mcp-server-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core) for an example. pub use super::mcp_handlers::mcp_server_handler::ServerHandler; pub use super::mcp_handlers::mcp_server_handler_core::ServerHandlerCore; @@ -93,4 +93,5 @@ pub mod macros { pub use rust_mcp_macros::*; } +pub mod id_generator; pub mod schema; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index 89aebf5..9b9577e 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -1,6 +1,7 @@ use crate::schema::{schema_utils::CallToolError, *}; use async_trait::async_trait; use serde_json::Value; +use std::sync::Arc; use crate::{mcp_traits::mcp_server::McpServer, utils::enforce_compatible_protocol_version}; @@ -15,7 +16,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, runtime: &dyn McpServer) {} + async fn on_initialized(&self, runtime: Arc) {} /// Handles the InitializeRequest from a client. /// @@ -29,7 +30,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialize_request( &self, initialize_request: InitializeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let mut server_info = runtime.server_info().to_owned(); // Provide compatibility for clients using older MCP protocol versions. @@ -65,7 +66,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_ping_request( &self, _: PingRequest, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result { Ok(Result::default()) } @@ -77,7 +78,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resources_request( &self, request: ListResourcesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -93,7 +94,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resource_templates_request( &self, request: ListResourceTemplatesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -109,7 +110,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_read_resource_request( &self, request: ReadResourceRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -125,7 +126,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_subscribe_request( &self, request: SubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -141,7 +142,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_unsubscribe_request( &self, request: UnsubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -157,7 +158,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_prompts_request( &self, request: ListPromptsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -173,7 +174,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_get_prompt_request( &self, request: GetPromptRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -189,7 +190,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -205,7 +206,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) @@ -220,7 +221,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_set_level_request( &self, request: SetLevelRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -236,7 +237,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_complete_request( &self, request: CompleteRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -252,7 +253,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_custom_request( &self, request: Value, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Err(RpcError::method_not_found() .with_message("No handler is implemented for custom requests.".to_string())) @@ -265,7 +266,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialized_notification( &self, notification: InitializedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -275,7 +276,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_cancelled_notification( &self, notification: CancelledNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -285,7 +286,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_progress_notification( &self, notification: ProgressNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -295,7 +296,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_roots_list_changed_notification( &self, notification: RootsListChangedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -320,18 +321,8 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_error( &self, error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } - - /// Called when the server has successfully started. - /// - /// Sends a "Server started successfully" message to stderr. - /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index e7b0e6d..9275da7 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -1,8 +1,8 @@ +use crate::mcp_traits::mcp_server::McpServer; use crate::schema::schema_utils::*; use crate::schema::*; use async_trait::async_trait; - -use crate::mcp_traits::mcp_server::McpServer; +use std::sync::Arc; /// Defines the `ServerHandlerCore` trait for handling Model Context Protocol (MCP) server operations. /// Unlike `ServerHandler`, this trait offers no default implementations, providing full control over MCP message handling @@ -14,7 +14,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, _runtime: &dyn McpServer) {} + async fn on_initialized(&self, _runtime: Arc) {} /// Asynchronously handles an incoming request from the client. /// @@ -26,7 +26,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; /// Asynchronously handles an incoming notification from the client. @@ -36,7 +36,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_notification( &self, notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError>; /// Asynchronously handles an error received from the client. @@ -46,11 +46,6 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_error( &self, error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError>; - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 7ee0815..9961b84 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -1,12 +1,17 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; - +use crate::error::{McpSdkError, SdkResult}; +use crate::id_generator::FastIdGenerator; +use crate::mcp_traits::mcp_client::McpClient; +use crate::mcp_traits::mcp_handler::McpClientHandler; +use crate::mcp_traits::IdGenerator; +use crate::utils::ensure_server_protocole_compatibility; use crate::{ mcp_traits::{RequestIdGen, RequestIdGenNumeric}, schema::{ schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, - ServerMessages, + self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, + ServerMessage, ServerMessages, }, InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, RequestId, RpcError, ServerResult, @@ -16,63 +21,100 @@ use async_trait::async_trait; use futures::future::{join_all, try_join_all}; use futures::StreamExt; -use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; -use std::{ - sync::{Arc, RwLock}, - time::Duration, -}; +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::{ClientStreamableTransport, StreamableTransportOptions}; +use rust_mcp_transport::{IoStream, SessionId, StreamId, Transport, TransportDispatcher}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::sync::Mutex; +use tokio::sync::{watch, Mutex}; -use crate::error::{McpSdkError, SdkResult}; -use crate::mcp_traits::mcp_client::McpClient; -use crate::mcp_traits::mcp_handler::McpClientHandler; -use crate::utils::ensure_server_protocole_compatibility; +pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; + +// Define a type alias for the TransportDispatcher trait object +type TransportDispatcherType = dyn TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, +>; +type TransportType = Arc; pub struct ClientRuntime { - // The transport interface for handling messages between client and server - transport: Arc< - dyn Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, - >, + // A thread-safe map storing transport types + transport_map: tokio::sync::RwLock>, // The handler for processing MCP messages handler: Box, - // // Information about the server + // Information about the server client_details: InitializeRequestParams, - // Details about the connected server - server_details: Arc>>, handlers: Mutex>>>, + // Generator for unique request IDs request_id_gen: Box, + // Generator for stream IDs + stream_id_gen: FastIdGenerator, + #[cfg(feature = "streamable-http")] + // Optional configuration for streamable transport + transport_options: Option, + // Flag indicating whether the client has been shut down + is_shut_down: Mutex, + // Session ID + session_id: tokio::sync::RwLock>, + // Details about the connected server + server_details_tx: watch::Sender>, + server_details_rx: watch::Receiver>, } impl ClientRuntime { pub(crate) fn new( client_details: InitializeRequestParams, - transport: impl Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, + transport: TransportType, handler: Box, ) -> Self { + let mut map: HashMap = HashMap::new(); + map.insert(DEFAULT_STREAM_ID.to_string(), transport); + let (server_details_tx, server_details_rx) = + watch::channel::>(None); Self { - transport: Arc::new(transport), + transport_map: tokio::sync::RwLock::new(map), handler, client_details, - server_details: Arc::new(RwLock::new(None)), handlers: Mutex::new(vec![]), request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + #[cfg(feature = "streamable-http")] + transport_options: None, + is_shut_down: Mutex::new(false), + session_id: tokio::sync::RwLock::new(None), + stream_id_gen: FastIdGenerator::new(Some("s_")), + server_details_tx, + server_details_rx, } } - async fn initialize_request(&self) -> SdkResult<()> { + #[cfg(feature = "streamable-http")] + pub(crate) fn new_instance( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: Box, + ) -> Self { + let map: HashMap = HashMap::new(); + let (server_details_tx, server_details_rx) = + watch::channel::>(None); + Self { + transport_map: tokio::sync::RwLock::new(map), + handler, + client_details, + handlers: Mutex::new(vec![]), + transport_options: Some(transport_options), + is_shut_down: Mutex::new(false), + session_id: tokio::sync::RwLock::new(None), + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + stream_id_gen: FastIdGenerator::new(Some("s_")), + server_details_tx, + server_details_rx, + } + } + + async fn initialize_request(self: Arc) -> SdkResult<()> { let request = InitializeRequest::new(self.client_details.clone()); let result: ServerResult = self.request(request.into(), None).await?.try_into()?; @@ -81,9 +123,15 @@ impl ClientRuntime { &self.client_details.protocol_version, &initialize_result.protocol_version, )?; - // store server details self.set_server_details(initialize_result)?; + + #[cfg(feature = "streamable-http")] + // try to create a sse stream for server initiated messages , if supported by the server + if let Err(error) = self.clone().create_sse_stream().await { + tracing::warn!("{error}"); + } + // send a InitializedNotification to the server self.send_notification(InitializedNotification::new(None).into()) .await?; @@ -92,21 +140,14 @@ impl ClientRuntime { .with_message("Incorrect response to InitializeRequest!".into()) .into()); } + Ok(()) } pub(crate) async fn handle_message( &self, message: ServerMessage, - transport: &Arc< - dyn Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, - >, + transport: &TransportType, ) -> SdkResult> { let response = match message { ServerMessage::Request(jsonrpc_request) => { @@ -162,28 +203,26 @@ impl ClientRuntime { }; Ok(response) } -} -#[async_trait] -impl McpClient for ClientRuntime { - fn sender(&self) -> Arc>>> - where - MessageDispatcher: - McpDispatch, - { - (self.transport.message_sender().clone()) as _ - } + async fn start_standalone(self: Arc) -> SdkResult<()> { + let self_clone = self.clone(); + let transport_map = self_clone.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; - async fn start(self: Arc) -> SdkResult<()> { //TODO: improve the flow - let mut stream = self.transport.start().await?; - let transport = self.transport.clone(); + let mut stream = transport.start().await?; + + let transport_clone = transport.clone(); let mut error_io_stream = transport.error_stream().write().await; let error_io_stream = error_io_stream.take(); let self_clone = Arc::clone(&self); let self_clone_err = Arc::clone(&self); + // task reading from the error stream let err_task = tokio::spawn(async move { let self_ref = &*self_clone_err; @@ -191,7 +230,7 @@ impl McpClient for ClientRuntime { let mut reader = BufReader::new(error_input).lines(); loop { tokio::select! { - should_break = self_ref.transport.is_shut_down() =>{ + should_break = transport_clone.is_shut_down() =>{ if should_break { break; } @@ -221,14 +260,10 @@ impl McpClient for ClientRuntime { Ok::<(), McpSdkError>(()) }); - let transport = self.transport.clone(); + let transport = transport.clone(); + // main task reading from mcp_message stream let main_task = tokio::spawn(async move { - let sender = self_clone.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; while let Some(mcp_messages) = stream.next().await { let self_ref = &*self_clone; @@ -239,7 +274,7 @@ impl McpClient for ClientRuntime { match result { Ok(result) => { if let Some(result) = result { - sender + transport .send_message(ClientMessages::Single(result), None) .await?; } @@ -260,7 +295,7 @@ impl McpClient for ClientRuntime { let results: Vec<_> = results.into_iter().flatten().collect(); if !results.is_empty() { - sender + transport .send_message(ClientMessages::Batch(results), None) .await?; } @@ -271,71 +306,349 @@ impl McpClient for ClientRuntime { }); // send initialize request to the MCP server - self.initialize_request().await?; + self.clone().initialize_request().await?; let mut lock = self.handlers.lock().await; lock.push(main_task); lock.push(err_task); + Ok(()) + } + pub(crate) async fn store_transport( + &self, + stream_id: &str, + transport: TransportType, + ) -> SdkResult<()> { + let mut transport_map = self.transport_map.write().await; + tracing::trace!("save transport for stream id : {}", stream_id); + transport_map.insert(stream_id.to_string(), transport); Ok(()) } - fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> { - match self.server_details.write() { - Ok(mut details) => { - *details = Some(server_details); - Ok(()) - } - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - Err(_) => Err(RpcError::internal_error() - .with_message("Internal Error: Failed to acquire write lock.".to_string()) - .into()), - } + pub(crate) async fn transport_by_stream(&self, stream_id: &str) -> SdkResult { + let transport_map = self.transport_map.read().await; + transport_map.get(stream_id).cloned().ok_or_else(|| { + RpcError::internal_error() + .with_message(format!("Transport for key {stream_id} not found")) + .into() + }) } - fn client_info(&self) -> &InitializeRequestParams { - &self.client_details + + #[cfg(feature = "streamable-http")] + pub(crate) async fn new_transport( + &self, + session_id: Option, + standalone: bool, + ) -> SdkResult< + impl TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + > { + let options = self + .transport_options + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + let transport = ClientStreamableTransport::new(options, session_id, standalone)?; + + Ok(transport) } - fn server_info(&self) -> Option { - if let Ok(details) = self.server_details.read() { - details.clone() - } else { - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - None + + #[cfg(feature = "streamable-http")] + pub(crate) async fn create_sse_stream(self: Arc) -> SdkResult<()> { + let stream_id: StreamId = DEFAULT_STREAM_ID.into(); + let session_id = self.session_id.read().await.clone(); + let transport: Arc< + dyn TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + > = Arc::new(self.new_transport(session_id, true).await?); + let mut stream = transport.start().await?; + self.store_transport(&stream_id, transport.clone()).await?; + + let self_clone = Arc::clone(&self); + + let main_task = tokio::spawn(async move { + loop { + if let Some(mcp_messages) = stream.next().await { + match mcp_messages { + ServerMessages::Single(server_message) => { + let result = self.handle_message(server_message, &transport).await?; + + if let Some(result) = result { + transport + .send_message(ClientMessages::Single(result), None) + .await?; + } + } + ServerMessages::Batch(server_messages) => { + let handling_tasks: Vec<_> = server_messages + .into_iter() + .map(|server_message| { + self.handle_message(server_message, &transport) + }) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + transport + .send_message(ClientMessages::Batch(results), None) + .await?; + } + } + } + // close the stream after all messages are sent, unless it is a standalone stream + if !stream_id.eq(DEFAULT_STREAM_ID) { + return Ok::<_, McpSdkError>(()); + } + } else { + // end of stream + return Ok::<_, McpSdkError>(()); + } + } + }); + + let mut lock = self_clone.handlers.lock().await; + lock.push(main_task); + + Ok(()) + } + + #[cfg(feature = "streamable-http")] + pub(crate) async fn start_stream( + &self, + messages: ClientMessages, + timeout: Option, + ) -> SdkResult> { + use futures::stream::{AbortHandle, Abortable}; + let stream_id: StreamId = self.stream_id_gen.generate(); + let session_id = self.session_id.read().await.clone(); + let no_session_id = session_id.is_none(); + + let has_request = match &messages { + ClientMessages::Single(client_message) => client_message.is_request(), + ClientMessages::Batch(client_messages) => { + client_messages.iter().any(|m| m.is_request()) + } + }; + + let transport = Arc::new(self.new_transport(session_id, false).await?); + + let mut stream = transport.start().await?; + + self.store_transport(&stream_id, transport).await?; + + let transport = self.transport_by_stream(&stream_id).await?; //TODO: remove + + let send_task = async { + let result = transport.send_message(messages, timeout).await?; + + if no_session_id { + if let Some(resquest_id) = transport.session_id().await.clone() { + let mut guard = self.session_id.write().await; + *guard = Some(resquest_id) + } + } + + Ok::<_, McpSdkError>(result) + }; + + if !has_request { + return send_task.await; } + + let (abort_recv_handle, abort_recv_reg) = AbortHandle::new_pair(); + + let receive_task = async { + loop { + tokio::select! { + Some(mcp_messages) = stream.next() =>{ + + match mcp_messages { + ServerMessages::Single(server_message) => { + let result = self.handle_message(server_message, &transport).await?; + if let Some(result) = result { + transport.send_message(ClientMessages::Single(result), None).await?; + } + } + ServerMessages::Batch(server_messages) => { + + let handling_tasks: Vec<_> = server_messages + .into_iter() + .map(|server_message| self.handle_message(server_message, &transport)) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + transport.send_message(ClientMessages::Batch(results), None).await?; + } + } + } + // close the stream after all messages are sent, unless it is a standalone stream + if !stream_id.eq(DEFAULT_STREAM_ID){ + return Ok::<_, McpSdkError>(()); + } + } + } + } + }; + + let receive_task = Abortable::new(receive_task, abort_recv_reg); + + // Pin the tasks to ensure they are not moved + tokio::pin!(send_task); + tokio::pin!(receive_task); + + // Run both tasks with cancellation logic + let (send_res, _) = tokio::select! { + res = &mut send_task => { + // cancel the receive_task task, to cover the case where sned_task returns with error + abort_recv_handle.abort(); + (res, receive_task.await) // Wait for receive_task to finish (it should exit due to cancellation) + } + res = &mut receive_task => { + (send_task.await, res) + } + }; + send_res } +} +#[async_trait] +impl McpClient for ClientRuntime { async fn send( &self, message: MessageFromClient, request_id: Option, - timeout: Option, + request_timeout: Option, ) -> SdkResult> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + let outgoing_request_id = self + .request_id_gen + .request_id_for_message(&message, request_id); + let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + + let response = self + .start_stream(ClientMessages::Single(mcp_message), request_timeout) + .await?; + return response + .map(|r| r.as_single()) + .transpose() + .map_err(|err| err.into()); + } + } + + let transport_map = self.transport_map.read().await; + + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; let outgoing_request_id = self .request_id_gen .request_id_for_message(&message, request_id); let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + let response = transport + .send_message(ClientMessages::Single(mcp_message), request_timeout) + .await?; + response + .map(|r| r.as_single()) + .transpose() + .map_err(|err| err.into()) + } - let response = sender - .send_message(ClientMessages::Single(mcp_message), timeout) - .await? - .map(|res| res.as_single()) - .transpose()?; + async fn send_batch( + &self, + messages: Vec, + timeout: Option, + ) -> SdkResult>> { + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + let result = self + .start_stream(ClientMessages::Batch(messages), timeout) + .await?; + // let response = self.start_stream(&stream_id, request_id, message).await?; + return result + .map(|r| r.as_batch()) + .transpose() + .map_err(|err| err.into()); + } + } - Ok(response) + let transport_map = self.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; + transport + .send_batch(messages, timeout) + .await + .map_err(|err| err.into()) + } + + async fn start(self: Arc) -> SdkResult<()> { + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + self.initialize_request().await?; + return Ok(()); + } + } + + self.start_standalone().await + } + + fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> { + self.server_details_tx + .send(Some(server_details)) + .map_err(|_| { + RpcError::internal_error() + .with_message("Failed to set server details".to_string()) + .into() + }) + } + + fn client_info(&self) -> &InitializeRequestParams { + &self.client_details + } + + fn server_info(&self) -> Option { + self.server_details_rx.borrow().clone() } async fn is_shut_down(&self) -> bool { - self.transport.is_shut_down().await + let result = self.is_shut_down.lock().await; + *result } + async fn shut_down(&self) -> SdkResult<()> { - self.transport.shut_down().await?; + let mut is_shut_down_lock = self.is_shut_down.lock().await; + *is_shut_down_lock = true; + + let mut transport_map = self.transport_map.write().await; + let transports: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); + drop(transport_map); + for transport in transports { + let _ = transport.shut_down().await; + } // wait for tasks let mut tasks_lock = self.handlers.lock().await; @@ -344,4 +657,18 @@ impl McpClient for ClientRuntime { Ok(()) } + + async fn terminate_session(&self) { + #[cfg(feature = "streamable-http")] + { + if let Some(transport_options) = self.transport_options.as_ref() { + let session_id = self.session_id.read().await.clone(); + transport_options + .terminate_session(session_id.as_ref()) + .await; + let _ = self.shut_down().await; + } + } + let _ = self.shut_down().await; + } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 7925f07..43a7079 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -8,7 +8,10 @@ use crate::schema::{ InitializeRequestParams, RpcError, ServerNotification, ServerRequest, }; use async_trait::async_trait; -use rust_mcp_transport::Transport; + +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::StreamableTransportOptions; +use rust_mcp_transport::TransportDispatcher; use crate::{ error::SdkResult, mcp_client::ClientHandler, mcp_traits::mcp_handler::McpClientHandler, @@ -37,10 +40,10 @@ use super::ClientRuntime; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport< + transport: impl TransportDispatcher< ServerMessages, MessageFromClient, ServerMessage, @@ -51,7 +54,20 @@ pub fn create_client( ) -> Arc { Arc::new(ClientRuntime::new( client_details, - transport, + Arc::new(transport), + Box::new(ClientInternalHandler::new(Box::new(handler))), + )) +} + +#[cfg(feature = "streamable-http")] +pub fn with_transport_options( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: impl ClientHandler, +) -> Arc { + Arc::new(ClientRuntime::new_instance( + client_details, + transport_options, Box::new(ClientInternalHandler::new(Box::new(handler))), )) } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index 8cb8cff..884de9d 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -1,5 +1,4 @@ -use std::sync::Arc; - +use super::ClientRuntime; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, MessageFromClient, NotificationFromServer, @@ -7,17 +6,16 @@ use crate::schema::{ }, InitializeRequestParams, RpcError, }; -use async_trait::async_trait; - -use rust_mcp_transport::Transport; - use crate::{ error::SdkResult, mcp_handlers::mcp_client_handler_core::ClientHandlerCore, mcp_traits::{mcp_client::McpClient, mcp_handler::McpClientHandler}, }; - -use super::ClientRuntime; +use async_trait::async_trait; +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::StreamableTransportOptions; +use rust_mcp_transport::TransportDispatcher; +use std::sync::Arc; /// Creates a new MCP client runtime with the specified configuration. /// @@ -39,10 +37,10 @@ use super::ClientRuntime; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport< + transport: impl TransportDispatcher< ServerMessages, MessageFromClient, ServerMessage, @@ -53,7 +51,20 @@ pub fn create_client( ) -> Arc { Arc::new(ClientRuntime::new( client_details, - transport, + Arc::new(transport), + Box::new(ClientCoreInternalHandler::new(Box::new(handler))), + )) +} + +#[cfg(feature = "streamable-http")] +pub fn with_transport_options( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: impl ClientHandlerCore, +) -> Arc { + Arc::new(ClientRuntime::new_instance( + client_details, + transport_options, Box::new(ClientCoreInternalHandler::new(Box::new(handler))), )) } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 44f3e53..1b24b57 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -19,12 +19,15 @@ use futures::{StreamExt, TryFutureExt}; use rust_mcp_transport::SessionId; use rust_mcp_transport::{IoStream, TransportDispatcher}; use std::collections::HashMap; +use std::panic; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; -use tokio::sync::{oneshot, watch}; + +use tokio::sync::{mpsc, oneshot, watch}; pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; +const TASK_CHANNEL_CAPACITY: usize = 500; // Define a type alias for the TransportDispatcher trait object type TransportType = Arc< @@ -45,7 +48,7 @@ pub struct ServerRuntime { server_details: Arc, #[cfg(feature = "hyper-server")] session_id: Option, - transport_map: tokio::sync::RwLock>, + transport_map: tokio::sync::RwLock>, //TODO: remove the transport_map, we do not need a hashmap for it request_id_gen: Box, client_details_tx: watch::Sender>, client_details_rx: watch::Receiver>, @@ -55,8 +58,6 @@ pub struct ServerRuntime { impl McpServer for ServerRuntime { /// Set the client details, storing them in client_details async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> { - self.handler.on_server_started(self).await; - self.client_details_tx .send(Some(client_details)) .map_err(|_| { @@ -132,8 +133,9 @@ impl McpServer for ServerRuntime { } /// Main runtime loop, processes incoming messages and handles requests - async fn start(&self) -> SdkResult<()> { - let transport_map = self.transport_map.read().await; + async fn start(self: Arc) -> SdkResult<()> { + let self_clone = self.clone(); + let transport_map = self_clone.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() @@ -142,43 +144,88 @@ impl McpServer for ServerRuntime { let mut stream = transport.start().await?; + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); + // Process incoming messages from the client while let Some(mcp_messages) = stream.next().await { match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, transport).await; - - match result { - Ok(result) => { - if let Some(result) = result { - transport - .send_message(ServerMessages::Single(result), None) - .await?; + let transport = transport.clone(); + let self = self.clone(); + let tx = tx.clone(); + + // Handle incoming messages in a separate task to avoid blocking the stream. + tokio::spawn(async move { + let result = self.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + // Send result to the main loop + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send result to channel: {}", error); } - Err(error) => { - tracing::error!("Error handling message : {}", error) - } - } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); + let transport = transport.clone(); + let self = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport + .send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => Err(error), + }; - if !results.is_empty() { - transport - .send_message(ServerMessages::Batch(results), None) - .await?; - } + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } } + + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } + return Ok(()); } @@ -223,7 +270,7 @@ impl ServerRuntime { } pub(crate) async fn handle_message( - &self, + self: &Arc, message: ClientMessage, transport: &Arc< dyn TransportDispatcher< @@ -240,7 +287,7 @@ impl ServerRuntime { ClientMessage::Request(client_jsonrpc_request) => { let result = self .handler - .handle_request(client_jsonrpc_request.request, self) + .handle_request(client_jsonrpc_request.request, self.clone()) .await; // create a response to send back to the client let response: MessageFromServer = match result { @@ -262,13 +309,13 @@ impl ServerRuntime { } ClientMessage::Notification(client_jsonrpc_notification) => { self.handler - .handle_notification(client_jsonrpc_notification.notification, self) + .handle_notification(client_jsonrpc_notification.notification, self.clone()) .await?; None } ClientMessage::Error(jsonrpc_error) => { self.handler - .handle_error(&jsonrpc_error.error, self) + .handle_error(&jsonrpc_error.error, self.clone()) .await?; if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { tx_response @@ -282,7 +329,6 @@ impl ServerRuntime { } None } - // The response is the result of a request, it is processed at the transport level. ClientMessage::Response(response) => { if let Some(tx_response) = transport.pending_request_tx(&response.id).await { tx_response @@ -313,6 +359,9 @@ impl ServerRuntime { >, >, ) -> SdkResult<()> { + if stream_id != DEFAULT_STREAM_ID { + return Ok(()); + } let mut transport_map = self.transport_map.write().await; tracing::trace!("save transport for stream id : {}", stream_id); transport_map.insert(stream_id.to_string(), transport); @@ -320,34 +369,18 @@ impl ServerRuntime { } pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> { + if stream_id != DEFAULT_STREAM_ID { + return Ok(()); + } let mut transport_map = self.transport_map.write().await; tracing::trace!("removing transport for stream id : {}", stream_id); + if let Some(transport) = transport_map.get(stream_id) { + transport.shut_down().await?; + } transport_map.remove(stream_id); Ok(()) } - pub(crate) async fn transport_by_stream( - &self, - stream_id: &str, - ) -> SdkResult< - Arc< - dyn TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, - >, - >, - > { - let transport_map = self.transport_map.read().await; - transport_map.get(stream_id).cloned().ok_or_else(|| { - RpcError::internal_error() - .with_message(format!("Transport for key {stream_id} not found")) - .into() - }) - } - pub(crate) async fn shutdown(&self) { let mut transport_map = self.transport_map.write().await; let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); @@ -359,17 +392,24 @@ impl ServerRuntime { pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool { let transport_map = self.transport_map.read().await; - transport_map.contains_key(stream_id) + let live_transport = if let Some(t) = transport_map.get(stream_id) { + !t.is_shut_down().await + } else { + false + }; + live_transport } pub(crate) async fn start_stream( self: Arc, - transport: impl TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, + transport: Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, >, stream_id: &str, ping_interval: Duration, @@ -377,9 +417,11 @@ impl ServerRuntime { ) -> SdkResult<()> { let mut stream = transport.start().await?; - self.store_transport(stream_id, Arc::new(transport)).await?; + if stream_id == DEFAULT_STREAM_ID { + self.store_transport(stream_id, transport.clone()).await?; + } - let transport = self.transport_by_stream(stream_id).await?; + let self_clone = self.clone(); let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>(); let abort_alive_task = transport @@ -394,43 +436,102 @@ impl ServerRuntime { // in case there is a payload, we consume it by transport to get processed if let Some(payload) = payload { - transport.consume_string_payload(&payload).await?; + if let Err(err) = transport.consume_string_payload(&payload).await { + let _ = self.remove_transport(stream_id).await; + return Err(err.into()); + } } + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); + loop { tokio::select! { Some(mcp_messages) = stream.next() =>{ match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, &transport).await?; - if let Some(result) = result { - transport.send_message(ServerMessages::Single(result), None).await?; - } + let transport = transport.clone(); + let self_clone = self.clone(); + let tx = tx.clone(); + tokio::spawn(async move { + + let result = self_clone.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, &transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); - - - if !results.is_empty() { - transport.send_message(ServerMessages::Batch(results), None).await?; - } + let transport = transport.clone(); + let self_clone = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self_clone.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport.send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + }else { + Ok(None) + } + }, + Err(error) => Err(error), + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } + // close the stream after all messages are sent, unless it is a standalone stream if !stream_id.eq(DEFAULT_STREAM_ID){ + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } return Ok(()); } } _ = &mut disconnect_rx => { + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } self.remove_transport(stream_id).await?; // Disconnection detected by keep-alive task return Err(SdkError::connection_closed().into()); @@ -445,10 +546,10 @@ impl ServerRuntime { server_details: Arc, handler: Arc, session_id: SessionId, - ) -> Self { + ) -> Arc { let (client_details_tx, client_details_rx) = watch::channel::>(None); - Self { + Arc::new(Self { server_details, handler, session_id: Some(session_id), @@ -456,7 +557,7 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), - } + }) } pub(crate) fn new( @@ -469,12 +570,12 @@ impl ServerRuntime { ServerMessage, >, handler: Arc, - ) -> Self { + ) -> Arc { let mut map: HashMap = HashMap::new(); map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport)); let (client_details_tx, client_details_rx) = watch::channel::>(None); - Self { + Arc::new(Self { server_details: Arc::new(server_details), handler, #[cfg(feature = "hyper-server")] @@ -483,6 +584,6 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), - } + }) } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index ea19e19..62fd31f 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -38,7 +38,7 @@ use crate::{ /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) pub fn create_server( server_details: InitializeResult, transport: impl TransportDispatcher< @@ -49,7 +49,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandler, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -62,7 +62,7 @@ pub(crate) fn create_server_instance( server_details: Arc, handler: Arc, session_id: SessionId, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new_instance(server_details, handler, session_id) } @@ -80,7 +80,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { match client_jsonrpc_request { schema_utils::RequestFromClient::ClientRequest(client_request) => { @@ -178,7 +178,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -187,7 +187,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { match client_jsonrpc_notification { schema_utils::NotificationFromClient::ClientNotification(client_notification) => { @@ -199,7 +199,10 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } ClientNotification::InitializedNotification(initialized_notification) => { self.handler - .handle_initialized_notification(initialized_notification, runtime) + .handle_initialized_notification( + initialized_notification, + runtime.clone(), + ) .await?; self.handler.on_initialized(runtime).await; } @@ -226,8 +229,4 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } Ok(()) } - - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index e0e7108..110b20b 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -32,7 +32,7 @@ use std::sync::Arc; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core) pub fn create_server( server_details: InitializeResult, transport: impl TransportDispatcher< @@ -43,7 +43,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandlerCore, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -66,7 +66,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // store the client details if the request is a client initialization request if let schema_utils::RequestFromClient::ClientRequest(ClientRequest::InitializeRequest( @@ -88,7 +88,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -96,11 +96,11 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { // Trigger the `on_initialized()` callback if an `initialized_notification` is received from the client. if client_jsonrpc_notification.is_initialized_notification() { - self.handler.on_initialized(runtime).await; + self.handler.on_initialized(runtime.clone()).await; } // handle notification @@ -109,7 +109,4 @@ impl McpServerHandler for RuntimeCoreInternalHandler> .await?; Ok(()) } - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_traits.rs b/crates/rust-mcp-sdk/src/mcp_traits.rs index 2b155fa..b66ba93 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits.rs @@ -1,3 +1,4 @@ +pub(super) mod id_generator; #[cfg(feature = "client")] pub mod mcp_client; pub mod mcp_handler; @@ -5,4 +6,5 @@ pub mod mcp_handler; pub mod mcp_server; mod request_id_gen; +pub use id_generator::*; pub use request_id_gen::*; diff --git a/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs b/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs new file mode 100644 index 0000000..e7cb8d3 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs @@ -0,0 +1,12 @@ +/// Trait for generating unique identifiers. +/// +/// This trait is generic over the target ID type, allowing it to be used for +/// generating different kinds of identifiers such as `SessionId` or +/// transport-scoped `StreamId`. +/// +pub trait IdGenerator: Send + Sync +where + T: From, +{ + fn generate(&self) -> T; +} diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 1883581..5fe3fba 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -1,9 +1,7 @@ -use std::{sync::Arc, time::Duration}; - use crate::schema::{ schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, - NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages, + ClientMessage, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient, + ResultFromServer, ServerMessage, }, CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams, CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation, @@ -17,21 +15,18 @@ use crate::schema::{ }; use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; -use rust_mcp_transport::{McpDispatch, MessageDispatcher}; +use std::{sync::Arc, time::Duration}; #[async_trait] pub trait McpClient: Sync + Send { async fn start(self: Arc) -> SdkResult<()>; fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>; + async fn terminate_session(&self); + async fn shut_down(&self) -> SdkResult<()>; async fn is_shut_down(&self) -> bool; - fn sender(&self) -> Arc>>> - where - MessageDispatcher: - McpDispatch; - fn client_info(&self) -> &InitializeRequestParams; fn server_info(&self) -> Option; @@ -170,48 +165,20 @@ pub trait McpClient: Sync + Send { &self, message: MessageFromClient, request_id: Option, - timeout: Option, + request_timeout: Option, ) -> SdkResult>; async fn send_batch( &self, messages: Vec, timeout: Option, - ) -> SdkResult>> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let response = sender - .send_message(ClientMessages::Batch(messages), timeout) - .await?; - - match response { - Some(res) => { - let server_results = res.as_batch()?; - Ok(Some(server_results)) - } - None => Ok(None), - } - } + ) -> SdkResult>>; /// Sends a notification. This is a one-way message that is not expected /// to return any response. The method asynchronously sends the notification using /// the transport layer and does not wait for any acknowledgement or result. async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let mcp_message = ClientMessage::from_message(MessageFromClient::from(notification), None)?; - - sender - .send_message(ClientMessages::Single(mcp_message), None) - .await?; + self.send(notification.into(), None, None).await?; Ok(()) } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index 2974bfc..cb37f2a 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -6,9 +6,9 @@ use crate::schema::schema_utils::{NotificationFromClient, RequestFromClient, Res #[cfg(feature = "client")] use crate::schema::schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}; -use crate::schema::RpcError; - use crate::error::SdkResult; +use crate::schema::RpcError; +use std::sync::Arc; #[cfg(feature = "client")] use super::mcp_client::McpClient; @@ -18,21 +18,20 @@ use super::mcp_server::McpServer; #[cfg(feature = "server")] #[async_trait] pub trait McpServerHandler: Send + Sync { - async fn on_server_started(&self, runtime: &dyn McpServer); async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()>; async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()>; } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index 2eab9db..dc860b6 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -13,16 +13,15 @@ use crate::schema::{ ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams, }; +use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; use rust_mcp_transport::SessionId; -use std::time::Duration; - -use crate::{error::SdkResult, utils::format_assertion_message}; +use std::{sync::Arc, time::Duration}; //TODO: support options , such as enforceStrictCapabilities #[async_trait] pub trait McpServer: Sync + Send { - async fn start(&self) -> SdkResult<()>; + async fn start(self: Arc) -> SdkResult<()>; async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>; fn server_info(&self) -> &InitializeResult; fn client_info(&self) -> Option; diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index e98a1ed..16fe7c7 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -1,6 +1,6 @@ use crate::schema::schema_utils::{ClientMessages, SdkError}; -use crate::error::{McpSdkError, SdkResult}; +use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; @@ -71,20 +71,20 @@ pub fn format_assertion_message(entity: &str, capability: &str, method_name: &st /// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05"); /// assert!(result.is_ok()); /// -/// // Incompatible versions (client < server) +/// // Incompatible versions (requested < current) /// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2024_11_05" && server == "2025_03_26" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2024_11_05" && current == "2025_03_26" /// )); /// -/// // Incompatible versions (client > server) +/// // Incompatible versions (requested > current) /// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2025_03_26" && server == "2024_11_05" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2025_03_26" && current == "2024_11_05" /// )); /// ``` #[allow(unused)] @@ -93,10 +93,12 @@ pub fn ensure_server_protocole_compatibility( server_protocol_version: &str, ) -> SdkResult<()> { match client_protocol_version.cmp(server_protocol_version) { - Ordering::Less | Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion( - client_protocol_version.to_string(), - server_protocol_version.to_string(), - )), + Ordering::Less | Ordering::Greater => Err(McpSdkError::Protocol { + kind: ProtocolErrorKind::IncompatibleVersion { + requested: client_protocol_version.to_string(), + current: server_protocol_version.to_string(), + }, + }), Ordering::Equal => Ok(()), } } @@ -140,8 +142,8 @@ pub fn ensure_server_protocole_compatibility( /// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2025_03_26" && server == "2024_11_05" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2025_03_26" && current == "2024_11_05" /// )); /// ``` #[allow(unused)] @@ -151,10 +153,12 @@ pub fn enforce_compatible_protocol_version( ) -> SdkResult> { match client_protocol_version.cmp(server_protocol_version) { // if client protocol version is higher - Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion( - client_protocol_version.to_string(), - server_protocol_version.to_string(), - )), + Ordering::Greater => Err(McpSdkError::Protocol { + kind: ProtocolErrorKind::IncompatibleVersion { + requested: client_protocol_version.to_string(), + current: server_protocol_version.to_string(), + }, + }), Ordering::Equal => Ok(None), Ordering::Less => { // return the same version that was received from the client @@ -164,7 +168,10 @@ pub fn enforce_compatible_protocol_version( } pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> { - let _mcp_protocol_version = ProtocolVersion::try_from(mcp_protocol_version)?; + let _mcp_protocol_version = + ProtocolVersion::try_from(mcp_protocol_version).map_err(|err| McpSdkError::Protocol { + kind: ProtocolErrorKind::ParseError(err), + })?; Ok(()) } diff --git a/crates/rust-mcp-sdk/tests/check_imports.rs b/crates/rust-mcp-sdk/tests/check_imports.rs index cda7d0c..207644e 100644 --- a/crates/rust-mcp-sdk/tests/check_imports.rs +++ b/crates/rust-mcp-sdk/tests/check_imports.rs @@ -37,13 +37,12 @@ mod tests { // Check for `use rust_mcp_schema` if content.contains("use rust_mcp_schema") { errors.push(format!( - "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", - abs_path + "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." )); } } Err(e) => { - errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + errors.push(format!("Failed to read file `{path_str}`: {e}")); } } } diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 564db0d..f330dda 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -1,5 +1,8 @@ +mod mock_server; +mod test_client; mod test_server; use async_trait::async_trait; +pub use mock_server::*; use reqwest::{Client, Response, Url}; use rust_mcp_macros::{mcp_tool, JsonSchema}; use rust_mcp_schema::ProtocolVersion; @@ -8,9 +11,12 @@ use rust_mcp_sdk::mcp_client::ClientHandler; use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; use std::collections::HashMap; use std::process; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::time::timeout; use tokio_stream::StreamExt; +use wiremock::{MockServer, Request, ResponseTemplate}; +pub use test_client::*; pub use test_server::*; pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything"; @@ -337,3 +343,52 @@ pub mod sample_tools { } } } + +pub async fn wiremock_request(mock_server: &MockServer, index: usize) -> Request { + let requests = mock_server.received_requests().await.unwrap(); + requests[index].clone() +} + +pub async fn debug_wiremock(mock_server: &MockServer) { + let requests = mock_server.received_requests().await.unwrap(); + let len = requests.len(); + println!(">>> {len} request(s) received <<<"); + + for (index, request) in requests.iter().enumerate() { + println!("\n--- #{index} of {len} ---"); + println!("Method: {}", request.method); + println!("Path: {}", request.url.path()); + // println!("Headers: {:#?}", request.headers); + println!("---- headers ----"); + for (key, values) in &request.headers { + println!("{key}: {values:?}"); + } + + let body_str = String::from_utf8_lossy(&request.body); + println!("Body: {body_str}\n"); + } +} + +pub fn create_sse_response(payload: &str) -> ResponseTemplate { + let sse_body = format!(r#"data: {}{}"#, payload, "\n\n"); + ResponseTemplate::new(200).set_body_raw(sse_body.into_bytes(), "text/event-stream") +} + +pub async fn wait_for_n_requests( + mock_server: &MockServer, + num_requests: usize, + duration: Option, +) { + let duration = duration.unwrap_or(Duration::from_secs(1)); + timeout(duration, async { + loop { + let requests = mock_server.received_requests().await.unwrap(); + if requests.len() >= num_requests { + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + }) + .await + .unwrap(); +} diff --git a/crates/rust-mcp-sdk/tests/common/mock_server.rs b/crates/rust-mcp-sdk/tests/common/mock_server.rs new file mode 100644 index 0000000..f5b533a --- /dev/null +++ b/crates/rust-mcp-sdk/tests/common/mock_server.rs @@ -0,0 +1,528 @@ +use axum::{ + body::Body, + extract::Request, + http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, Method, StatusCode}, + response::{ + sse::{Event, KeepAlive}, + IntoResponse, Response, Sse, + }, + routing::any, + Router, +}; +use core::fmt; +use futures::stream; +use std::collections::VecDeque; +use std::{future::Future, net::SocketAddr, pin::Pin}; +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; +use tokio::net::TcpListener; + +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, +} + +impl ToString for SseEvent { + fn to_string(&self) -> String { + let mut s = String::new(); + + if let Some(id) = &self.id { + s.push_str("id: "); + s.push_str(id); + s.push('\n'); + } + + if let Some(event) = &self.event { + s.push_str("event: "); + s.push_str(event); + s.push('\n'); + } + + if let Some(data) = &self.data { + // Convert bytes to string safely, fallback if invalid UTF-8 + for line in data.lines() { + s.push_str("data: "); + s.push_str(line); + s.push('\n'); + } + } + + s.push('\n'); // End of event + s + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self.data.as_ref(); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .finish() + } +} + +// RequestRecord stores the history of incoming requests +#[derive(Clone, Debug)] +pub struct RequestRecord { + pub method: Method, + pub path: String, + pub headers: HeaderMap, + pub body: String, +} + +#[derive(Clone, Debug)] +pub struct ResponseRecord { + pub status: StatusCode, + pub headers: HeaderMap, + pub body: String, +} + +// pub type BoxedStream = +// Pin> + Send>>; +// pub type BoxedSseResponse = Sse; + +// pub type AsyncResponseFn = +// Box Pin + Send>> + Send + Sync>; + +type AsyncResponseFn = + Box Pin + Send>> + Send + Sync>; + +// Mock defines a single mock response configuration +// #[derive(Clone)] +pub struct Mock { + method: Method, + path: String, + response: String, + response_func: Option, + header_map: HeaderMap, + matcher: Option bool + Send + Sync>>, + remaining_calls: Option>>, + status: StatusCode, +} + +// MockBuilder is a factory for creating Mock instances +pub struct MockBuilder { + method: Method, + path: String, + response: String, + header_map: HeaderMap, + response_func: Option, + matcher: Option bool + Send + Sync>>, + remaining_calls: Option>>, + status: StatusCode, +} + +impl MockBuilder { + fn new(method: Method, path: String, response: String, header_map: HeaderMap) -> Self { + Self { + method, + path, + response, + response_func: None, + header_map, + matcher: None, + status: StatusCode::OK, + remaining_calls: None, // Default to unlimited calls + } + } + + fn new_with_func( + method: Method, + path: String, + response_func: AsyncResponseFn, + header_map: HeaderMap, + ) -> Self { + Self { + method, + path, + response: String::new(), + response_func: Some(response_func), + header_map, + matcher: None, + status: StatusCode::OK, + remaining_calls: None, // Default to unlimited calls + } + } + + pub fn new_breakable_sse( + method: Method, + path: String, + repeating_message: SseEvent, + interval: Duration, + repeat: usize, + ) -> Self { + let message = Arc::new(repeating_message); + let interval = interval; + let max_repeats = repeat; + + let response_fn: AsyncResponseFn = Box::new({ + let message = Arc::clone(&message); + move || { + let message = Arc::clone(&message); + + Box::pin(async move { + // Construct SSE stream with 10 static messages using unfold + let message_stream = stream::unfold(0, move |count| { + let message = Arc::clone(&message); + + async move { + if count >= max_repeats { + return Some(( + Err(std::io::Error::other("Message limit reached")), + count, + )); + } + tokio::time::sleep(interval).await; + + Some(( + Ok(Event::default() + .data(message.data.clone().unwrap_or("".into())) + .id(message.id.clone().unwrap_or(format!("msg-id_{count}"))) + .event(message.event.clone().unwrap_or("message".into()))), + count + 1, + )) + } + }); + + let sse_stream = Sse::new(message_stream) + .keep_alive(KeepAlive::new().interval(Duration::from_secs(10))); + + sse_stream.into_response() + }) + } + }); + + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + Self::new_with_func(method, path, response_fn, header_map) + } + + pub fn with_matcher(mut self, matcher: F) -> Self + where + F: Fn(&str, &HeaderMap) -> bool + Send + Sync + 'static, + { + self.matcher = Some(Arc::new(matcher)); + self + } + + pub fn add_header(mut self, key: HeaderName, val: HeaderValue) -> Self { + self.header_map.insert(key, val); + self + } + + pub fn without_matcher(mut self) -> Self { + self.matcher = None; + self + } + + pub fn expect(mut self, num_calls: usize) -> Self { + self.remaining_calls = Some(Arc::new(Mutex::new(num_calls))); + self + } + + pub fn unlimited_calls(mut self) -> Self { + self.remaining_calls = None; + self + } + + pub fn with_status(mut self, status: StatusCode) -> Self { + self.status = status; + self + } + + pub fn build(self) -> Mock { + Mock { + method: self.method, + path: self.path, + response: self.response, + header_map: self.header_map, + matcher: self.matcher, + remaining_calls: self.remaining_calls, + status: self.status, + response_func: self.response_func, + } + } + + // add_string with text/plain + pub fn new_text(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + + Self::new(method, path, response.into(), header_map) + } + + /** + MockBuilder::new_response( + Method::GET, + "/mcp".to_string(), + Box::new(|| { + // tokio::time::sleep(Duration::from_secs(1)).await; + let json_response = Json(json!({ + "status": "ok", + "data": [1, 2, 3] + })) + .into_response(); + Box::pin(async move { json_response }) + }), + ) + .build(), + */ + pub fn new_response(method: Method, path: String, response_func: AsyncResponseFn) -> Self { + Self::new_with_func(method, path, response_func, HeaderMap::new()) + } + + // new_json with application/json + pub fn new_json(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + Self::new(method, path, response.into(), header_map) + } + + // new_sse with text/event-stream + pub fn new_sse(method: Method, path: String, response: impl Into) -> Self { + let response = format!(r#"data: {}{}"#, response.into(), '\n'); + + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + // ensure message ends with a \n\n , if needed + let cr = if response.ends_with("\n\n") { + "" + } else { + "\n\n" + }; + Self::new(method, path, format!("{response}{cr}"), header_map) + } + + // new_raw with application/octet-stream + pub fn new_raw(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + Self::new(method, path, response.into(), header_map) + } +} + +// MockServerHandle provides access to the request history after the server starts +pub struct MockServerHandle { + history: Arc>>, +} + +impl MockServerHandle { + pub async fn get_history(&self) -> Vec<(RequestRecord, ResponseRecord)> { + let history = self.history.lock().unwrap(); + history.iter().cloned().collect() + } + + pub async fn print(&self) { + let requests = self.get_history().await; + + let len = requests.len(); + println!("\n>>> {len} request(s) received <<<"); + + for (index, (request, response)) in requests.iter().enumerate() { + println!( + "\n--- Request {} of {len} ------------------------------------", + index + 1 + ); + println!("Method: {}", request.method); + println!("Path: {}", request.path); + // println!("Headers: {:#?}", request.headers); + println!("> headers "); + for (key, values) in &request.headers { + println!("{key}: {values:?}"); + } + + println!("\n> Body"); + println!("{}\n", &request.body); + + println!(">>>>> Response <<<<<"); + println!("> status: {}", response.status); + println!("> headers"); + for (key, values) in &response.headers { + println!("{key}: {values:?}"); + } + println!("> Body"); + println!("{}", &response.body); + } + } +} + +// MockServer is the main struct for configuring and starting the mock server +pub struct SimpleMockServer { + mocks: Vec, + history: Arc>>, +} + +impl Default for SimpleMockServer { + fn default() -> Self { + Self::new() + } +} + +impl SimpleMockServer { + pub fn new() -> Self { + Self { + mocks: Vec::new(), + history: Arc::new(Mutex::new(VecDeque::new())), + } + } + + pub async fn start_with_mocks(mocks: Vec) -> (String, MockServerHandle) { + let mut server = SimpleMockServer::new(); + server.add_mocks(mocks); + server.start().await + } + + // Generic add function + pub fn add_mock_builder(&mut self, builder: MockBuilder) -> &mut Self { + self.mocks.push(builder.build()); + self + } + + pub fn add_mock(&mut self, mock: Mock) -> &mut Self { + self.mocks.push(mock); + self + } + + pub fn add_mocks(&mut self, mock: Vec) -> &mut Self { + mock.into_iter().for_each(|m| self.mocks.push(m)); + self + } + + pub async fn start(self) -> (String, MockServerHandle) { + let mocks = Arc::new(self.mocks); + let history = Arc::clone(&self.history); + + async fn handler( + mocks: Arc>, + history: Arc>>, + mut req: Request, + ) -> impl IntoResponse { + // Take ownership of the body using std::mem::take + let body = std::mem::take(req.body_mut()); + let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap(); + let body_str = String::from_utf8_lossy(&body_bytes).to_string(); + + let request_record = RequestRecord { + method: req.method().clone(), + path: req.uri().path().to_string(), + headers: req.headers().clone(), + body: body_str.clone(), + }; + + for m in mocks.iter() { + if m.method != *req.method() || m.path != req.uri().path() { + continue; + } + + if let Some(matcher) = &m.matcher { + if !(matcher)(&body_str, req.headers()) { + continue; + } + } + + if let Some(remaining) = &m.remaining_calls { + let mut rem = remaining.lock().unwrap(); + if *rem == 0 { + continue; + } + *rem -= 1; + } + + let mut resp = match m.response_func.as_ref() { + Some(get_response) => get_response().await.into_response(), + None => Response::new(Body::from(m.response.clone())), + }; + + // if let Some(resp_box) = &mut m.response_func.take() { + // let response = resp_box.into_response(); + // // *response.status_mut() = m.status; + // // m.response_func = Some(Box::new(response)); + // } + + // let mut resp = m.response_func.as_ref().unwrap().clone().to_owned(); + // let resp = *resp; + // *resp.into_response().status_mut() = m.status; + + // let mut response = m.response_func.as_ref().unwrap().clone(); + // let mut response = m.response_func.as_ref().unwrap().clone().to_owned(); + // let mut m = *response; + // *response.status_mut() = m.status; + // let resp = &*m.response_func.as_ref().unwrap().to_owned().clone().deref(); + + // let response = boxed_response.into_response(); + + // let mut resp = Response::new(Body::from(m.response.clone())); + *resp.status_mut() = m.status; + m.header_map.iter().for_each(|(k, v)| { + resp.headers_mut().insert(k, v.clone()); + }); + + let response_record = ResponseRecord { + status: resp.status(), + headers: resp.headers().clone(), + body: m.response.clone(), + }; + + { + let mut hist = history.lock().unwrap(); + hist.push_back((request_record, response_record)); + } + + return resp; + } + + let resp = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap(); + + let response_record = ResponseRecord { + status: resp.status(), + headers: resp.headers().clone(), + body: "".into(), + }; + + { + let mut hist = history.lock().unwrap(); + hist.push_back((request_record, response_record)); + } + + resp + } + + let app = Router::new().route( + "/{*path}", + any(move |req: Request| handler(Arc::clone(&mocks), Arc::clone(&history), req)), + ); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let url = format!("http://{local_addr}"); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + ( + url, + MockServerHandle { + history: self.history, + }, + ) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/test_client.rs b/crates/rust-mcp-sdk/tests/common/test_client.rs new file mode 100644 index 0000000..21678c7 --- /dev/null +++ b/crates/rust-mcp-sdk/tests/common/test_client.rs @@ -0,0 +1,163 @@ +use async_trait::async_trait; +use rust_mcp_schema::{schema_utils::MessageFromServer, PingRequest, RpcError}; +use rust_mcp_sdk::{mcp_client::ClientHandler, McpClient}; +use serde_json::json; +use std::sync::Arc; +use tokio::sync::RwLock; + +#[cfg(feature = "hyper-server")] +pub mod test_client_common { + use rust_mcp_schema::{ + schema_utils::MessageFromServer, ClientCapabilities, Implementation, + InitializeRequestParams, LATEST_PROTOCOL_VERSION, + }; + use rust_mcp_sdk::{ + mcp_client::{client_runtime, ClientRuntime}, + McpClient, RequestOptions, SessionId, StreamableTransportOptions, + }; + use std::{collections::HashMap, sync::Arc, time::Duration}; + use tokio::sync::RwLock; + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + use wiremock::{ + matchers::{body_json_string, method, path}, + Mock, MockServer, ResponseTemplate, + }; + + use crate::common::{ + create_sse_response, test_server_common::INITIALIZE_RESPONSE, wait_for_n_requests, + }; + + pub struct InitializedClient { + pub client: Arc, + pub mcp_url: String, + pub mock_server: MockServer, + } + + pub const TEST_SESSION_ID: &str = "test-session-id"; + pub const INITIALIZE_REQUEST: &str = r#"{"id":0,"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{},"clientInfo":{"name":"simple-rust-mcp-client-sse","title":"Simple Rust MCP Client (SSE)","version":"0.1.0"},"protocolVersion":"2025-06-18"}}"#; + + pub fn test_client_details() -> InitializeRequestParams { + InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + } + } + + pub async fn create_client( + mcp_url: &str, + custom_headers: Option>, + ) -> (Arc, Arc>>) { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let client_details: InitializeRequestParams = test_client_details(); + + let transport_options = StreamableTransportOptions { + mcp_url: mcp_url.to_string(), + request_options: RequestOptions { + request_timeout: Duration::from_secs(2), + custom_headers, + ..RequestOptions::default() + }, + }; + + let message_history = Arc::new(RwLock::new(vec![])); + let handler = super::TestClientHandler { + message_history: message_history.clone(), + }; + + let client = + client_runtime::with_transport_options(client_details, transport_options, handler); + + // client.clone().start().await.unwrap(); + (client, message_history) + } + + pub async fn initialize_client( + session_id: Option, + custom_headers: Option>, + ) -> InitializedClient { + let mock_server = MockServer::start().await; + + // intialize response + let mut response = create_sse_response(INITIALIZE_RESPONSE); + + if let Some(session_id) = session_id { + response = response.append_header("mcp-session-id", session_id.as_str()); + } + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, custom_headers).await; + + client.clone().start().await.unwrap(); + + wait_for_n_requests(&mock_server, 2, None).await; + + InitializedClient { + client, + mcp_url, + mock_server, + } + } +} + +// Custom responder for SSE with 10 ping messages +struct SsePingResponder; + +// Test handler +pub struct TestClientHandler { + message_history: Arc>>, +} + +impl TestClientHandler { + async fn register_message(&self, message: &MessageFromServer) { + let mut lock = self.message_history.write().await; + lock.push(message.clone()); + } +} + +#[async_trait] +impl ClientHandler for TestClientHandler { + async fn handle_ping_request( + &self, + request: PingRequest, + runtime: &dyn McpClient, + ) -> std::result::Result { + self.register_message(&request.into()).await; + + Ok(rust_mcp_schema::Result { + meta: Some(json!({"meta_number":1515}).as_object().unwrap().to_owned()), + extra: None, + }) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index aa8e2fb..769f8c6 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -1,30 +1,30 @@ #[cfg(feature = "hyper-server")] pub mod test_server_common { + use crate::common::sample_tools::SayHelloTool; use async_trait::async_trait; use rust_mcp_schema::schema_utils::CallToolError; use rust_mcp_schema::{ CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, ProtocolVersion, RpcError, }; + use rust_mcp_sdk::id_generator::IdGenerator; use rust_mcp_sdk::mcp_server::hyper_runtime::HyperRuntime; - use tokio_stream::StreamExt; - use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequest, InitializeRequestParams, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, }; use rust_mcp_sdk::{ - mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, + mcp_server::{hyper_server, HyperServer, HyperServerOptions, ServerHandler}, McpServer, SessionId, }; - use std::sync::RwLock; + use std::sync::{Arc, RwLock}; use std::time::Duration; use tokio::time::timeout; - - use crate::common::sample_tools::SayHelloTool; + use tokio_stream::StreamExt; pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; + pub const INITIALIZE_RESPONSE: &str = r#"{"result":{"protocolVersion":"2025-06-18","capabilities":{"prompts":{},"resources":{"subscribe":true},"tools":{},"logging":{}},"serverInfo":{"name":"example-servers/everything","version":"1.0.0"}},"jsonrpc":"2.0","id":0}"#; pub struct LaunchedServer { pub hyper_runtime: HyperRuntime, @@ -71,16 +71,10 @@ pub mod test_server_common { #[async_trait] impl ServerHandler for TestServerHandler { - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } - async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; @@ -94,7 +88,7 @@ pub mod test_server_common { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) @@ -156,14 +150,17 @@ pub mod test_server_common { } } - impl IdGenerator for TestIdGenerator { - fn generate(&self) -> SessionId { + impl IdGenerator for TestIdGenerator + where + T: From, + { + fn generate(&self) -> T { let mut lock = self.generated.write().unwrap(); *lock += 1; if *lock > self.constant_ids.len() { *lock = 1; } - self.constant_ids[*lock - 1].to_owned() + T::from(self.constant_ids[*lock - 1].to_owned()) } } diff --git a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs index 5c184cf..9f2fd95 100644 --- a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs +++ b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs @@ -30,7 +30,7 @@ mod protocol_compatibility_on_server { ); handler - .handle_initialize_request(InitializeRequest::new(initialize_request), &runtime) + .handle_initialize_request(InitializeRequest::new(initialize_request), runtime) .await } diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs new file mode 100644 index 0000000..a0a2804 --- /dev/null +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -0,0 +1,823 @@ +#[path = "common/common.rs"] +pub mod common; + +use common::test_client_common::create_client; +use hyper::{Method, StatusCode}; +use rust_mcp_schema::{ + schema_utils::{ + ClientJsonrpcRequest, ClientMessage, MessageFromServer, RequestFromClient, + RequestFromServer, ResultFromServer, RpcMessage, ServerMessage, + }, + RequestId, ServerRequest, ServerResult, +}; +use rust_mcp_sdk::{ + error::McpSdkError, mcp_server::HyperServerOptions, McpClient, TransportError, + MCP_LAST_EVENT_ID_HEADER, +}; +use serde_json::{json, Value}; +use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration}; +use wiremock::{ + http::{HeaderName, HeaderValue}, + matchers::{body_json_string, header, method, path}, + Mock, MockServer, ResponseTemplate, +}; + +use crate::common::{ + create_sse_response, debug_wiremock, random_port, + test_client_common::{ + initialize_client, InitializedClient, INITIALIZE_REQUEST, TEST_SESSION_ID, + }, + test_server_common::{ + create_start_server, LaunchedServer, TestIdGenerator, INITIALIZE_RESPONSE, + }, + wait_for_n_requests, wiremock_request, MockBuilder, SimpleMockServer, SseEvent, +}; + +// should send JSON-RPC messages via POST +#[tokio::test] +async fn should_send_json_rpc_messages_via_post() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let received_request = wiremock_request(&mock_server, 0).await; + let header_values = received_request + .headers + .get(&HeaderName::from_str("accept").unwrap()) + .unwrap(); + + assert!(header_values.contains(&HeaderValue::from_str("application/json").unwrap())); + assert!(header_values.contains(&HeaderValue::from_str("text/event-stream").unwrap())); + + wait_for_n_requests(&mock_server, 2, None).await; +} + +// should send batch messages +#[tokio::test] +async fn should_send_batch_messages() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + let response = create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + // .expect(1) + .mount(&mock_server) + .await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + let result = client + .send_batch(vec![message_1, message_2], None) + .await + .unwrap() + .unwrap(); + + // two results for two requests + assert_eq!(result.len(), 2); + assert!(result.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); + + // not an Error + assert!(result + .iter() + .all(|r| matches!(r, ServerMessage::Response(_)))); + + // debug_wiremock(&mock_server).await; +} + +// should store session ID received during initialization +#[tokio::test] +async fn should_store_session_id_received_during_initialization() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = + create_sse_response(INITIALIZE_RESPONSE).append_header("mcp-session-id", "test-session-id"); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let received_request = wiremock_request(&mock_server, 0).await; + let header_values = received_request + .headers + .get(&HeaderName::from_str("accept").unwrap()) + .unwrap(); + + assert!(header_values.contains(&HeaderValue::from_str("application/json").unwrap())); + assert!(header_values.contains(&HeaderValue::from_str("text/event-stream").unwrap())); + + wait_for_n_requests(&mock_server, 2, None).await; +} + +// should terminate session with DELETE request +#[tokio::test] +async fn should_terminate_session_with_delete_request() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + client.terminate_session().await; +} + +// should handle 405 response when server doesn't support session termination +#[tokio::test] +async fn should_handle_405_unsupported_session_termination() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(405)) + .expect(1) + .mount(&mock_server) + .await; + + client.terminate_session().await; +} + +// should handle 404 response when session expires +#[tokio::test] +async fn should_handle_404_response_when_session_expires() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(404)) + .expect(1) + .mount(&mock_server) + .await; + + let result = client.ping(None).await; + + matches!( + result, + Err(McpSdkError::Transport(TransportError::SessionExpired)) + ); +} + +// should handle non-streaming JSON response +#[tokio::test] +async fn should_handle_non_streaming_json_response() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + let response = ResponseTemplate::new(200) + .set_body_json(json!({ + "id":1,"jsonrpc":"2.0", "result":{"something":"good"} + })) + .insert_header("Content-Type", "application/json"); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + let request = RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})); + + let result = client.request(request, None).await.unwrap(); + + let ResultFromServer::ServerResult(ServerResult::Result(result)) = result else { + panic!("Wrong result variant!") + }; + + let extra = result.extra.unwrap(); + assert_eq!(extra.get("something").unwrap(), "good"); +} + +// should handle successful initial GET connection for SSE +#[tokio::test] +async fn should_handle_successful_initial_get_connection_for_sse() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; + // + let mut body = String::new(); + body.push_str(&"data: Connection established\n\n".to_string()); + + let response = ResponseTemplate::new(200) + .set_body_raw(body.into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let requests = mock_server.received_requests().await.unwrap(); + let get_request = requests + .iter() + .find(|r| r.method == wiremock::http::Method::Get); + + assert!(get_request.is_some()) +} + +#[tokio::test] +async fn should_receive_server_initiated_messaged() { + let server_options = HyperServerOptions { + port: random_port(), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + enable_json_response: Some(false), + ..Default::default() + }; + let LaunchedServer { + hyper_runtime, + streamable_url, + sse_url, + sse_message_url, + } = create_start_server(server_options).await; + + let (client, message_history) = create_client(&streamable_url, None).await; + + client.clone().start().await.unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let result = hyper_runtime + .ping(&"AAA-BBB-CCC".to_string(), None) + .await + .unwrap(); + + let lock = message_history.read().await; + let ping_request = lock + .iter() + .find(|m| { + matches!( + m, + MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( + ServerRequest::PingRequest(_) + )) + ) + }) + .unwrap(); + let MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( + ServerRequest::PingRequest(_), + )) = ping_request + else { + panic!("Request is not a match!") + }; + assert!(result.meta.is_some()); + + let v = result.meta.unwrap().get("meta_number").unwrap().clone(); + + assert!(matches!(v, Value::Number(value) if value.as_i64().unwrap()==1515)) //1515 is passed from TestClientHandler +} + +// should attempt initial GET connection and handle 405 gracefully +#[tokio::test] +async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(405)) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; + // + let mut body = String::new(); + body.push_str(&"data: Connection established\n\n".to_string()); + + let response = ResponseTemplate::new(405) + .set_body_raw(body.into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let requests = mock_server.received_requests().await.unwrap(); + let get_request = requests + .iter() + .find(|r| r.method == wiremock::http::Method::Get); + + assert!(get_request.is_some()); + + // send a batch message, runtime should work as expected with no isse + + let response = create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + // .expect(1) + .mount(&mock_server) + .await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + let result = client + .send_batch(vec![message_1, message_2], None) + .await + .unwrap() + .unwrap(); + + // two results for two requests + assert_eq!(result.len(), 2); + assert!(result.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); +} + +// should handle multiple concurrent SSE streams +#[tokio::test] +async fn should_handle_multiple_concurrent_sse_streams() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(|req: &wiremock::Request| { + let body_string = String::from_utf8(req.body.clone()).unwrap(); + if body_string.contains("test3") { + create_sse_response(r#"{"id":1,"jsonrpc":"2.0", "result":{}}"#) + } else { + create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ) + } + }) + .expect(2) + .mount(&mock_server) + .await; + + let message_3 = RequestFromClient::CustomRequest(json!({"method": "test3", "params": {}})); + let request1 = client.send_batch(vec![message_1, message_2], None); + let request2 = client.send(message_3.into(), None, None); + + // Run them concurrently and wait for both + let (res_batch, res_single) = tokio::join!(request1, request2); + + let res_batch = res_batch.unwrap().unwrap(); + // two results for two requests in the batch + assert_eq!(res_batch.len(), 2); + assert!(res_batch.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); + + // not an Error + assert!(res_batch + .iter() + .all(|r| matches!(r, ServerMessage::Response(_)))); + + let res_single = res_single.unwrap().unwrap(); + let ServerMessage::Response(res_single) = res_single else { + panic!("invalid respinse type, expected Result!") + }; + + assert!(matches!(res_single.id, RequestId::Integer(id) if id==1)); +} + +// should throw error when invalid content-type is received +#[tokio::test] +async fn should_throw_error_when_invalid_content_type_is_received() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + r#"{"id":0,"jsonrpc":"2.0", "result":{}}"#.to_string().into_bytes(), + "text/plain", + )) + .expect(1) + .mount(&mock_server) + .await; + + let result = client.ping(None).await; + + let Err(McpSdkError::Transport(TransportError::UnexpectedContentType(content_type))) = result + else { + panic!("Expected a TransportError::UnexpectedContentType error!"); + }; + + assert_eq!(content_type, "text/plain"); +} + +// should always send specified custom headers +#[tokio::test] +async fn should_always_send_specified_custom_headers() { + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, Some(headers)).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + r#"{"id":1,"jsonrpc":"2.0", "result":{}}"#.to_string().into_bytes(), + "application/json", + )) + .expect(1) + .mount(&mock_server) + .await; + + let _result = client.ping(None).await; + + let requests = mock_server.received_requests().await.unwrap(); + + assert_eq!(requests.len(), 4); + assert!(requests + .iter() + .all(|r| r.headers.get(&"X-Custom-Header".into()).unwrap().as_str() == "CustomValue")); + + debug_wiremock(&mock_server).await +} + +// should reconnect a GET-initiated notification stream that fails + +#[tokio::test] +async fn should_reconnect_a_get_initiated_notification_stream_that_fails() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // two GET Mock, each expects one call , first time it fails, second retry it succeeds + let response = ResponseTemplate::new(502) + .set_body_raw("".to_string().into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .up_to_n_times(1) + .mount(&mock_server) + .await; + + let response = ResponseTemplate::new(200) + .set_body_raw( + "data: Connection established\n\n".to_string().into_bytes(), + "text/event-stream", + ) + .append_header("Connection", "keep-alive"); + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); +} + +//****************** Resumability ****************** +// should pass lastEventId when reconnecting +#[tokio::test] +async fn should_pass_last_event_id_when_reconnecting() { + let msg = r#"{"jsonrpc":"2.0","method":"notifications/message","params":{"data":{},"level":"debug"}}"#; + + let mocks = vec![ + MockBuilder::new_sse(Method::POST, "/mcp".to_string(), INITIALIZE_RESPONSE).build(), + MockBuilder::new_breakable_sse( + Method::GET, + "/mcp".to_string(), + SseEvent { + data: Some(msg.into()), + event: Some("message".to_string()), + id: None, + }, + Duration::from_millis(100), + 5, + ) + .expect(2) + .build(), + MockBuilder::new_sse( + Method::POST, + "/mcp".to_string(), + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + ) + .build(), + ]; + + let (url, handle) = SimpleMockServer::start_with_mocks(mocks).await; + let mcp_url = format!("{url}/mcp"); + + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let (client, _) = create_client(&mcp_url, Some(headers)).await; + + client.clone().start().await.unwrap(); + + assert!(client.is_initialized()); + + // give it time for re-connection + tokio::time::sleep(Duration::from_secs(2)).await; + + let request_history = handle.get_history().await; + + let get_requests: Vec<_> = request_history + .iter() + .filter(|r| r.0.method == Method::GET) + .collect(); + + // there should be more than one GET reueat, indicating reconnection + assert!(get_requests.len() > 1); + + let Some(last_get_request) = get_requests.last() else { + panic!("Unable to find last GET reuest!"); + }; + + let last_event_id = last_get_request + .0 + .headers + .get(axum::http::HeaderName::from_static( + MCP_LAST_EVENT_ID_HEADER, + )); + + // last-event-id should be sent + assert!( + matches!(last_event_id, Some(last_event_id) if last_event_id.to_str().unwrap().starts_with("msg-id")) + ); + + // custom headers should be passed for all GET requests + assert!(get_requests.iter().all(|r| r + .0 + .headers + .get(axum::http::HeaderName::from_str("X-Custom-Header").unwrap()) + .unwrap() + .to_str() + .unwrap() + == "CustomValue")); + + println!("last_event_id {:?} ", last_event_id.unwrap()); +} + +// should NOT reconnect a POST-initiated stream that fails +#[tokio::test] +async fn should_not_reconnect_a_post_initiated_stream_that_fails() { + let mocks = vec![ + MockBuilder::new_sse(Method::POST, "/mcp".to_string(), INITIALIZE_RESPONSE) + .expect(1) + .build(), + MockBuilder::new_sse(Method::GET, "/mcp".to_string(), "".to_string()) + .with_status(StatusCode::METHOD_NOT_ALLOWED) + .build(), + MockBuilder::new_sse( + Method::POST, + "/mcp".to_string(), + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + ) + .expect(1) + .build(), + MockBuilder::new_breakable_sse( + Method::POST, + "/mcp".to_string(), + SseEvent { + data: Some("msg".to_string()), + event: None, + id: None, + }, + Duration::ZERO, + 0, + ) + .build(), + ]; + + let (url, handle) = SimpleMockServer::start_with_mocks(mocks).await; + let mcp_url = format!("{url}/mcp"); + + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let (client, _) = create_client(&mcp_url, Some(headers)).await; + + client.clone().start().await.unwrap(); + + assert!(client.is_initialized()); + + let result = client.send_roots_list_changed(None).await; + + assert!(result.is_err()); + + tokio::time::sleep(Duration::from_secs(2)).await; + + let request_history = handle.get_history().await; + let post_requests: Vec<_> = request_history + .iter() + .filter(|r| r.0.method == Method::POST) + .collect(); + assert_eq!(post_requests.len(), 3); // initialize, initialized, root_list_changed +} + +//****************** Auth ****************** +// attempts auth flow on 401 during POST request +// invalidates all credentials on InvalidClientError during auth +// invalidates all credentials on UnauthorizedClientError during auth +//invalidates tokens on InvalidGrantError during auth + +//****************** Others ****************** +// custom fetch in auth code paths +// should support custom reconnection options +// uses custom fetch implementation if provided +// should have exponential backoff with configurable maxRetries diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs similarity index 99% rename from crates/rust-mcp-sdk/tests/test_streamable_http.rs rename to crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 23ca27f..4809d6d 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -8,13 +8,12 @@ use rust_mcp_schema::{ SdkErrorCodes, ServerJsonrpcNotification, ServerJsonrpcRequest, ServerJsonrpcResponse, ServerMessages, }, - CallToolRequest, CallToolRequestParams, ListPromptsRequestParams, ListRootsRequestParams, - ListRootsResult, ListToolsRequest, LoggingLevel, LoggingMessageNotificationParams, RequestId, - RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult, + CallToolRequest, CallToolRequestParams, ListRootsResult, ListToolsRequest, LoggingLevel, + LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, + ServerRequest, ServerResult, }; use rust_mcp_sdk::mcp_server::HyperServerOptions; use serde_json::{json, Map, Value}; -use tokio_stream::StreamExt; use crate::common::{ random_port, read_sse_event, read_sse_event_from_stream, send_delete_request, send_get_request, diff --git a/crates/rust-mcp-transport/Cargo.toml b/crates/rust-mcp-transport/Cargo.toml index ec061bb..2f03580 100644 --- a/crates/rust-mcp-transport/Cargo.toml +++ b/crates/rust-mcp-transport/Cargo.toml @@ -42,10 +42,12 @@ workspace = true ### FEATURES ################################################################# [features] -default = ["stdio", "sse", "2025_06_18"] # Default features +default = ["stdio", "sse", "streamable-http", "2025_06_18"] # Default features stdio = [] sse = ["reqwest"] +streamable-http = ["reqwest"] + # enabled mcp protocol version 2025_06_18 2025_06_18 = ["rust-mcp-schema/2025_06_18", "rust-mcp-schema/schema_utils"] diff --git a/crates/rust-mcp-transport/README.md b/crates/rust-mcp-transport/README.md index 23b78bf..30cad83 100644 --- a/crates/rust-mcp-transport/README.md +++ b/crates/rust-mcp-transport/README.md @@ -14,7 +14,7 @@ let transport = StdioTransport::new(TransportOptions { timeout: 60_000 })?; ``` -Refer to the [Hello World MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) example for a complete demonstration. +Refer to the [Hello World MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) example for a complete demonstration. ### For MCP Client @@ -51,7 +51,7 @@ let transport = StdioTransport::create_with_server_launch( )?; ``` -Refer to the [Simple MCP Client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) example for a complete demonstration. +Refer to the [Simple MCP Client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) example for a complete demonstration. --- diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index f201aa0..8d55bd0 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -5,7 +5,7 @@ use crate::transport::Transport; use crate::utils::{ extract_origin, http_post, CancellationTokenSource, ReadableChannel, SseStream, WritableChannel, }; -use crate::{IoStream, McpDispatch, TransportOptions}; +use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions}; use async_trait::async_trait; use bytes::Bytes; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; @@ -13,8 +13,13 @@ use reqwest::Client; use tokio::sync::oneshot::Sender; use tokio::task::JoinHandle; -use crate::schema::schema_utils::McpMessage; -use crate::schema::RequestId; +use crate::schema::{ + schema_utils::{ + ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage, + ServerMessages, + }, + RequestId, +}; use std::cmp::Ordering; use std::collections::HashMap; use std::pin::Pin; @@ -25,7 +30,7 @@ use tokio::sync::{mpsc, oneshot, Mutex}; const DEFAULT_CHANNEL_CAPACITY: usize = 64; const DEFAULT_MAX_RETRY: usize = 5; -const DEFAULT_RETRY_TIME_SECONDS: u64 = 3; +const DEFAULT_RETRY_TIME_SECONDS: u64 = 1; const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5; /// Configuration options for the Client SSE Transport @@ -102,10 +107,9 @@ where let base_url = match extract_origin(server_url) { Some(url) => url, None => { - let error_message = - format!("Failed to extract origin from server URL: {server_url}"); - tracing::error!(error_message); - return Err(TransportError::InvalidOptions(error_message)); + let message = format!("Failed to extract origin from server URL: {server_url}"); + tracing::error!(message); + return Err(TransportError::Configuration { message }); } }; @@ -145,12 +149,15 @@ where let mut header_map = HeaderMap::new(); for (key, value) in headers { - let header_name = key - .parse::() - .map_err(|e| TransportError::InvalidOptions(format!("Invalid header name: {e}")))?; - let header_value = HeaderValue::from_str(value).map_err(|e| { - TransportError::InvalidOptions(format!("Invalid header value: {e}")) - })?; + let header_name = + key.parse::() + .map_err(|e| TransportError::Configuration { + message: format!("Invalid header name: {e}"), + })?; + let header_value = + HeaderValue::from_str(value).map_err(|e| TransportError::Configuration { + message: format!("Invalid header value: {e}"), + })?; header_map.insert(header_name, header_value); } @@ -172,10 +179,12 @@ where } if let Some(endpoint_origin) = extract_origin(&endpoint) { if endpoint_origin.cmp(&self.base_url) != Ordering::Equal { - return Err(TransportError::InvalidOptions(format!( + return Err(TransportError::Configuration { + message: format!( "Endpoint origin does not match connection origin. expected: {} , received: {}", self.base_url, endpoint_origin - ))); + ), + }); } return Ok(endpoint); } @@ -284,8 +293,8 @@ where Some(data) => { // trim the trailing \n before making a request let body = String::from_utf8_lossy(&data).trim().to_string(); - if let Err(e) = http_post(&client_clone, &post_url, body, &custom_headers).await { - tracing::error!("Failed to POST message: {e:?}"); + if let Err(e) = http_post(&client_clone, &post_url, body,None, custom_headers.as_ref()).await { + tracing::error!("Failed to POST message: {e}"); } }, None => break, // Exit if channel is closed @@ -335,7 +344,7 @@ where } async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of consume_string_payload() function for ClientSseTransport" .to_string(), )) @@ -346,7 +355,7 @@ where _: Duration, _: oneshot::Sender<()>, ) -> TransportResult> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of keep_alive() function for ClientSseTransport".to_string(), )) } @@ -413,3 +422,55 @@ where pending_requests.remove(request_id) } } + +#[async_trait] +impl McpDispatch + for ClientSseTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for ClientSseTransport +{ +} diff --git a/crates/rust-mcp-transport/src/client_streamable_http.rs b/crates/rust-mcp-transport/src/client_streamable_http.rs new file mode 100644 index 0000000..c318649 --- /dev/null +++ b/crates/rust-mcp-transport/src/client_streamable_http.rs @@ -0,0 +1,515 @@ +use crate::error::TransportError; +use crate::mcp_stream::MCPStream; + +use crate::schema::{ + schema_utils::{ + ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage, + ServerMessages, + }, + RequestId, +}; +use crate::utils::{ + http_delete, http_post, CancellationTokenSource, ReadableChannel, StreamableHttpStream, + WritableChannel, +}; +use crate::{error::TransportResult, IoStream, McpDispatch, MessageDispatcher, Transport}; +use crate::{SessionId, TransportDispatcher, TransportOptions}; +use async_trait::async_trait; +use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +use reqwest::Client; +use std::collections::HashMap; +use std::pin::Pin; +use std::{sync::Arc, time::Duration}; +use tokio::io::{BufReader, BufWriter}; +use tokio::sync::oneshot::Sender; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::task::JoinHandle; + +const DEFAULT_CHANNEL_CAPACITY: usize = 64; +const DEFAULT_MAX_RETRY: usize = 5; +const DEFAULT_RETRY_TIME_SECONDS: u64 = 1; +const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5; + +pub struct StreamableTransportOptions { + pub mcp_url: String, + pub request_options: RequestOptions, +} + +impl StreamableTransportOptions { + pub async fn terminate_session(&self, session_id: Option<&SessionId>) { + let client = Client::new(); + match http_delete(&client, &self.mcp_url, session_id, None).await { + Ok(_) => {} + Err(TransportError::Http(status_code)) => { + tracing::info!("Session termination failed with status code {status_code}",); + } + Err(error) => { + tracing::info!("Session termination failed with error :{error}"); + } + }; + } +} + +pub struct RequestOptions { + pub request_timeout: Duration, + pub retry_delay: Option, + pub max_retries: Option, + pub custom_headers: Option>, +} + +impl Default for RequestOptions { + fn default() -> Self { + Self { + request_timeout: TransportOptions::default().timeout, + retry_delay: None, + max_retries: None, + custom_headers: None, + } + } +} + +pub struct ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + /// Optional cancellation token source for shutting down the transport + shutdown_source: tokio::sync::RwLock>, + /// Flag indicating if the transport is shut down + is_shut_down: Mutex, + /// Timeout duration for MCP messages + request_timeout: Duration, + /// HTTP client for making requests + client: Client, + /// URL for the SSE endpoint + mcp_server_url: String, + /// Delay between retry attempts + retry_delay: Duration, + /// Maximum number of retry attempts + max_retries: usize, + /// Optional custom HTTP headers + custom_headers: Option, + sse_task: tokio::sync::RwLock>>, + post_task: tokio::sync::RwLock>>, + message_sender: Arc>>>, + error_stream: tokio::sync::RwLock>, + pending_requests: Arc>>>, + session_id: Arc>>, + standalone: bool, +} + +impl ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + pub fn new( + options: &StreamableTransportOptions, + session_id: Option, + standalone: bool, + ) -> TransportResult { + let client = Client::new(); + + let headers = match &options.request_options.custom_headers { + Some(h) => Some(Self::validate_headers(h)?), + None => None, + }; + + let mcp_server_url = options.mcp_url.to_owned(); + Ok(Self { + shutdown_source: tokio::sync::RwLock::new(None), + is_shut_down: Mutex::new(false), + request_timeout: options.request_options.request_timeout, + client, + mcp_server_url, + retry_delay: options + .request_options + .retry_delay + .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)), + max_retries: options + .request_options + .max_retries + .unwrap_or(DEFAULT_MAX_RETRY), + sse_task: tokio::sync::RwLock::new(None), + post_task: tokio::sync::RwLock::new(None), + custom_headers: headers, + message_sender: Arc::new(tokio::sync::RwLock::new(None)), + error_stream: tokio::sync::RwLock::new(None), + pending_requests: Arc::new(Mutex::new(HashMap::new())), + session_id: Arc::new(tokio::sync::RwLock::new(session_id)), + standalone, + }) + } + + fn validate_headers(headers: &HashMap) -> TransportResult { + let mut header_map = HeaderMap::new(); + for (key, value) in headers { + let header_name = + key.parse::() + .map_err(|e| TransportError::Configuration { + message: format!("Invalid header name: {e}"), + })?; + let header_value = + HeaderValue::from_str(value).map_err(|e| TransportError::Configuration { + message: format!("Invalid header value: {e}"), + })?; + header_map.insert(header_name, header_value); + } + Ok(header_map) + } + + pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher) { + let mut lock = self.message_sender.write().await; + *lock = Some(sender); + } + + pub(crate) async fn set_error_stream( + &self, + error_stream: Pin>, + ) { + let mut lock = self.error_stream.write().await; + *lock = Some(IoStream::Readable(error_stream)); + } +} + +#[async_trait] +impl Transport for ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, + M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + OR: Clone + Send + Sync + serde::Serialize + 'static, + OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + async fn start(&self) -> TransportResult> + where + MessageDispatcher: McpDispatch, + { + if self.standalone { + // Create CancellationTokenSource and token + let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); + let mut lock = self.shutdown_source.write().await; + *lock = Some(cancellation_source); + + let (write_tx, mut write_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + let (read_tx, read_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + + let max_retries = self.max_retries; + let retry_delay = self.retry_delay; + + let post_url = self.mcp_server_url.clone(); + let custom_headers = self.custom_headers.clone(); + let cancellation_token_post = cancellation_token.clone(); + let cancellation_token_sse = cancellation_token.clone(); + + let session_id_clone = self.session_id.clone(); + + let mut streamable_http = StreamableHttpStream { + client: self.client.clone(), + mcp_url: post_url, + max_retries, + retry_delay, + read_tx, + session_id: session_id_clone, //Arc>> + }; + + let session_id = self.session_id.read().await.to_owned(); + + let sse_response = streamable_http + .make_standalone_stream_connection(&cancellation_token_sse, &custom_headers, None) + .await?; + + let sse_task_handle = tokio::spawn(async move { + if let Err(error) = streamable_http + .run_standalone(&cancellation_token_sse, &custom_headers, sse_response) + .await + { + if !matches!(error, TransportError::Cancelled(_)) { + tracing::warn!("{error}"); + } + } + }); + + let mut sse_task_lock = self.sse_task.write().await; + *sse_task_lock = Some(sse_task_handle); + + let post_url = self.mcp_server_url.clone(); + let client = self.client.clone(); + let custom_headers = self.custom_headers.clone(); + + // Initiate a task to process POST requests from messages received via the writable stream. + let post_task_handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = cancellation_token_post.cancelled() => + { + break; + }, + data = write_rx.recv() => { + match data{ + Some(data) => { + // trim the trailing \n before making a request + let payload = String::from_utf8_lossy(&data).trim().to_string(); + + if let Err(e) = http_post( + &client, + &post_url, + payload.to_string(), + session_id.as_ref(), + custom_headers.as_ref(), + ) + .await{ + tracing::error!("Failed to POST message: {e}") + } + }, + None => break, // Exit if channel is closed + } + } + } + } + }); + let mut post_task_lock = self.post_task.write().await; + *post_task_lock = Some(post_task_handle); + + // Create writable stream + let writable: Mutex>> = + Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx }))); + + // Create readable stream + let readable: Pin> = + Box::pin(BufReader::new(ReadableChannel { + read_rx, + buffer: Bytes::new(), + })); + + let (stream, sender, error_stream) = MCPStream::create( + readable, + writable, + IoStream::Writable(Box::pin(tokio::io::stderr())), + self.pending_requests.clone(), + self.request_timeout, + cancellation_token, + ); + + self.set_message_sender(sender).await; + + if let IoStream::Readable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + Ok(stream) + } else { + // Create CancellationTokenSource and token + let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); + let mut lock = self.shutdown_source.write().await; + *lock = Some(cancellation_source); + + // let (write_tx, mut write_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + let (write_tx, mut write_rx): ( + tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + tokio::sync::mpsc::Receiver<( + String, + tokio::sync::oneshot::Sender>, + )>, + ) = tokio::sync::mpsc::channel(DEFAULT_CHANNEL_CAPACITY); // Buffer size as needed + let (read_tx, read_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + + let max_retries = self.max_retries; + let retry_delay = self.retry_delay; + + let post_url = self.mcp_server_url.clone(); + let custom_headers = self.custom_headers.clone(); + let cancellation_token_post = cancellation_token.clone(); + let cancellation_token_sse = cancellation_token.clone(); + + let session_id_clone = self.session_id.clone(); + + let mut streamable_http = StreamableHttpStream { + client: self.client.clone(), + mcp_url: post_url, + max_retries, + retry_delay, + read_tx, + session_id: session_id_clone, //Arc>> + }; + + // Initiate a task to process POST requests from messages received via the writable stream. + let post_task_handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = cancellation_token_post.cancelled() => + { + break; + }, + data = write_rx.recv() => { + match data{ + Some((data, ack_tx)) => { + // trim the trailing \n before making a request + let payload = data.trim().to_string(); + let result = streamable_http.run(payload, &cancellation_token_sse, &custom_headers).await; + let _ = ack_tx.send(result);// Ignore error if receiver dropped + }, + None => break, // Exit if channel is closed + } + } + } + } + }); + let mut post_task_lock = self.post_task.write().await; + *post_task_lock = Some(post_task_handle); + + // Create readable stream + let readable: Pin> = + Box::pin(BufReader::new(ReadableChannel { + read_rx, + buffer: Bytes::new(), + })); + + let (stream, sender, error_stream) = MCPStream::create_with_ack( + readable, + write_tx, + IoStream::Writable(Box::pin(tokio::io::stderr())), + self.pending_requests.clone(), + self.request_timeout, + cancellation_token, + ); + + self.set_message_sender(sender).await; + + if let IoStream::Readable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + + Ok(stream) + } + } + + fn message_sender(&self) -> Arc>>> { + self.message_sender.clone() as _ + } + + fn error_stream(&self) -> &tokio::sync::RwLock> { + &self.error_stream as _ + } + async fn shut_down(&self) -> TransportResult<()> { + // Trigger cancellation + let mut cancellation_lock = self.shutdown_source.write().await; + if let Some(source) = cancellation_lock.as_ref() { + source.cancel()?; + } + *cancellation_lock = None; // Clear cancellation_source + + // Mark as shut down + let mut is_shut_down_lock = self.is_shut_down.lock().await; + *is_shut_down_lock = true; + + // Get task handle + let post_task = self.post_task.write().await.take(); + + // // Wait for tasks to complete with a timeout + let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS); + let shutdown_future = async { + if let Some(post_handle) = post_task { + let _ = post_handle.await; + } + Ok::<(), TransportError>(()) + }; + + tokio::select! { + result = shutdown_future => { + result // result of task completion + } + _ = tokio::time::sleep(timeout) => { + tracing::warn!("Shutdown timed out after {:?}", timeout); + Err(TransportError::ShutdownTimeout) + } + } + } + async fn is_shut_down(&self) -> bool { + let result = self.is_shut_down.lock().await; + *result + } + async fn consume_string_payload(&self, _: &str) -> TransportResult<()> { + Err(TransportError::Internal( + "Invalid invocation of consume_string_payload() function for ClientStreamableTransport" + .to_string(), + )) + } + + async fn pending_request_tx(&self, request_id: &RequestId) -> Option> { + let mut pending_requests = self.pending_requests.lock().await; + pending_requests.remove(request_id) + } + + async fn keep_alive( + &self, + _: Duration, + _: oneshot::Sender<()>, + ) -> TransportResult> { + Err(TransportError::Internal( + "Invalid invocation of keep_alive() function for ClientStreamableTransport".to_string(), + )) + } + + async fn session_id(&self) -> Option { + let guard = self.session_id.read().await; + guard.clone() + } +} + +#[async_trait] +impl McpDispatch + for ClientStreamableTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for ClientStreamableTransport +{ +} diff --git a/crates/rust-mcp-transport/src/constants.rs b/crates/rust-mcp-transport/src/constants.rs new file mode 100644 index 0000000..6ae0342 --- /dev/null +++ b/crates/rust-mcp-transport/src/constants.rs @@ -0,0 +1,3 @@ +pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; +pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; +pub const MCP_LAST_EVENT_ID_HEADER: &str = "last-event-id"; diff --git a/crates/rust-mcp-transport/src/error.rs b/crates/rust-mcp-transport/src/error.rs index 8f8b62f..a244456 100644 --- a/crates/rust-mcp-transport/src/error.rs +++ b/crates/rust-mcp-transport/src/error.rs @@ -1,11 +1,14 @@ use crate::schema::{schema_utils::SdkError, RpcError}; -use thiserror::Error; - use crate::utils::CancellationError; use core::fmt; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use reqwest::Error as ReqwestError; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use reqwest::StatusCode; use std::any::Any; +use std::io::Error as IoError; +use thiserror::Error; use tokio::sync::{broadcast, mpsc}; - /// A wrapper around a broadcast send error. This structure allows for generic error handling /// by boxing the underlying error into a type-erased form. #[derive(Debug)] @@ -80,31 +83,53 @@ pub type TransportResult = core::result::Result; #[derive(Debug, Error)] pub enum TransportError { - #[error("{0}")] - InvalidOptions(String), + #[error("Session expired or not found")] + SessionExpired, + + #[error("Failed to open SSE stream: {0}")] + FailedToOpenSSEStream(String), + + #[error("Unexpected content type: '{0}'")] + UnexpectedContentType(String), + + #[error("Failed to send message: {0}")] + SendFailure(String), + + #[error("I/O error: {0}")] + Io(#[from] IoError), + + #[cfg(any(feature = "sse", feature = "streamable-http"))] + #[error("HTTP connection error: {0}")] + HttpConnection(#[from] ReqwestError), + + #[cfg(any(feature = "sse", feature = "streamable-http"))] + #[error("HTTP error: {0}")] + Http(StatusCode), + + #[error("SDK error: {0}")] + Sdk(#[from] SdkError), + + #[error("Operation cancelled: {0}")] + Cancelled(#[from] CancellationError), + + #[error("Channel closed: {0}")] + ChannelClosed(#[from] tokio::sync::oneshot::error::RecvError), + + #[error("Configuration error: {message}")] + Configuration { message: String }, + #[error("{0}")] SendError(#[from] GenericSendError), - #[error("{0}")] - WatchSendError(#[from] GenericWatchSendError), - #[error("Send Error: {0}")] - StdioError(#[from] std::io::Error), + #[error("{0}")] JsonrpcError(#[from] RpcError), - #[error("{0}")] - SdkError(#[from] SdkError), - #[error("Process error{0}")] + + #[error("Process error: {0}")] ProcessError(String), - #[error("{0}")] - FromString(String), - #[error("{0}")] - OneshotRecvError(#[from] tokio::sync::oneshot::error::RecvError), - #[cfg(feature = "sse")] - #[error("{0}")] - SendMessageError(#[from] reqwest::Error), - #[error("Http Error: {0}")] - HttpError(u16), + + #[error("Internal error: {0}")] + Internal(String), + #[error("Shutdown timed out")] ShutdownTimeout, - #[error("Cancellation error : {0}")] - CancellationError(#[from] CancellationError), } diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 1634922..4a918db 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -1,25 +1,38 @@ // Copyright (c) 2025 mcp-rust-stack // Licensed under the MIT License. See LICENSE file for details. // Modifications to this file must be documented with a description of the changes made. + #[cfg(feature = "sse")] mod client_sse; +#[cfg(feature = "streamable-http")] +mod client_streamable_http; +mod constants; pub mod error; mod mcp_stream; mod message_dispatcher; mod schema; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod sse; +#[cfg(feature = "stdio")] mod stdio; mod transport; mod utils; #[cfg(feature = "sse")] pub use client_sse::*; +#[cfg(feature = "streamable-http")] +pub use client_streamable_http::*; +pub use constants::*; pub use message_dispatcher::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub use sse::*; +#[cfg(feature = "stdio")] pub use stdio::*; pub use transport::*; // Type alias for session identifier, represented as a String pub type SessionId = String; +// Type alias for stream identifier (that will be used at the transport scope), represented as a String +pub type StreamId = String; +// Type alias for event (MCP message) identifier, represented as a String +pub type EventId = String; diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 08bdc21..0b10918 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -57,6 +57,43 @@ impl MCPStream { (stream, sender, error_io) } + pub fn create_with_ack( + readable: Pin>, + writable: tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + error_io: IoStream, + pending_requests: Arc>>>, + request_timeout: Duration, + cancellation_token: CancellationToken, + ) -> ( + tokio_stream::wrappers::ReceiverStream, + MessageDispatcher, + IoStream, + ) + where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + X: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + { + let (tx, rx) = tokio::sync::mpsc::channel::(CHANNEL_CAPACITY); + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + + // Clone cancellation_token for reader + let reader_token = cancellation_token.clone(); + + #[allow(clippy::let_underscore_future)] + let _ = Self::spawn_reader(readable, tx, reader_token); + + let sender = MessageDispatcher::new_with_acknowledgement( + pending_requests, + writable, + request_timeout, + ); + + (stream, sender, error_io) + } + /// Creates a new task that continuously reads from the readable stream. /// The received data is deserialized into a JsonrpcMessage. If the deserialization is successful, /// the object is transmitted. If the object is a response or error corresponding to a pending request, diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index ea1eb04..7c7c93e 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -29,7 +29,13 @@ use crate::McpDispatch; /// a configurable timeout mechanism for asynchronous responses. pub struct MessageDispatcher { pending_requests: Arc>>>, - writable_std: Mutex>>, + writable_std: Option>>>, + writable_tx: Option< + tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + >, request_timeout: Duration, } @@ -51,7 +57,24 @@ impl MessageDispatcher { ) -> Self { Self { pending_requests, - writable_std, + writable_std: Some(writable_std), + writable_tx: None, + request_timeout, + } + } + + pub fn new_with_acknowledgement( + pending_requests: Arc>>>, + writable_tx: tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + request_timeout: Duration, + ) -> Self { + Self { + pending_requests, + writable_tx: Some(writable_tx), + writable_std: None, request_timeout, } } @@ -125,7 +148,7 @@ impl McpDispatch match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { Ok(response) => Ok(Some(ServerMessages::Single(response))), Err(error) => match error { - TransportError::OneshotRecvError(_) => { + TransportError::ChannelClosed(_) => { Err(schema_utils::SdkError::connection_closed().into()) } _ => Err(error), @@ -147,6 +170,9 @@ impl McpDispatch }) .unzip(); + // Ensure all request IDs are stored before sending the request + let tasks = join_all(pending_tasks).await; + // send the batch messages to the server let message_payload = serde_json::to_string(&client_messages).map_err(|_| { crate::error::TransportError::JsonrpcError(RpcError::parse_error()) @@ -154,12 +180,10 @@ impl McpDispatch self.write_str(message_payload.as_str()).await?; // no request in the batch, no need to wait for the result - if pending_tasks.is_empty() { + if request_ids.is_empty() { return Ok(None); } - let tasks = join_all(pending_tasks).await; - let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| { rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout))) }); @@ -210,11 +234,24 @@ impl McpDispatch /// appending a newline character and flushing the stream afterward. /// async fn write_str(&self, payload: &str) -> TransportResult<()> { - let mut writable_std = self.writable_std.lock().await; - writable_std.write_all(payload.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; - Ok(()) + if let Some(writable_std) = self.writable_std.as_ref() { + let mut writable_std = writable_std.lock().await; + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + return Ok(()); + }; + + if let Some(writable_tx) = self.writable_tx.as_ref() { + let (resp_tx, resp_rx) = oneshot::channel(); + writable_tx + .send((payload.to_string(), resp_tx)) + .await + .map_err(|err| TransportError::Internal(format!("{err}")))?; // Send fails if channel closed + return resp_rx.await?; // Await the POST result; propagates the error if POST failed + } + + Err(TransportError::Internal("Invalid dispatcher!".to_string())) } } @@ -339,10 +376,23 @@ impl McpDispatch /// appending a newline character and flushing the stream afterward. /// async fn write_str(&self, payload: &str) -> TransportResult<()> { - let mut writable_std = self.writable_std.lock().await; - writable_std.write_all(payload.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; - Ok(()) + if let Some(writable_std) = self.writable_std.as_ref() { + let mut writable_std = writable_std.lock().await; + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + return Ok(()); + }; + + if let Some(writable_tx) = self.writable_tx.as_ref() { + let (resp_tx, resp_rx) = oneshot::channel(); + writable_tx + .send((payload.to_string(), resp_tx)) + .await + .map_err(|err| TransportError::Internal(err.to_string()))?; // Send fails if channel closed + return resp_rx.await?; // Await the POST result; propagates the error if POST failed + } + + Err(TransportError::Internal("Invalid dispatcher!".to_string())) } } diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index 50dbb32..09809e4 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -156,7 +156,7 @@ impl Transport {} - Err(TransportError::StdioError(error)) => { + Err(TransportError::Io(error)) => { if error.kind() == std::io::ErrorKind::BrokenPipe { let _ = disconnect_tx.send(()); break; diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 582af5d..11bd0a6 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -1,5 +1,6 @@ use crate::schema::schema_utils::{ - ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, + ClientMessage, ClientMessages, MessageFromClient, MessageFromServer, SdkError, ServerMessage, + ServerMessages, }; use crate::schema::RequestId; use async_trait::async_trait; @@ -193,30 +194,29 @@ where #[cfg(unix)] command.process_group(0); - let mut process = command.spawn().map_err(TransportError::StdioError)?; + let mut process = command.spawn().map_err(TransportError::Io)?; let stdin = process .stdin .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stdin.".into()))?; let stdout = process .stdout .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stdout.".into()))?; let stderr = process .stderr .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stderr.".into()))?; - let pending_requests_clone1 = self.pending_requests.clone(); - let pending_requests_clone2 = self.pending_requests.clone(); + let pending_requests_clone = self.pending_requests.clone(); tokio::spawn(async move { let _ = process.wait().await; // clean up pending requests to cancel waiting tasks - let mut pending_requests = pending_requests_clone1.lock().await; + let mut pending_requests = pending_requests_clone.lock().await; pending_requests.clear(); }); @@ -224,7 +224,7 @@ where Box::pin(stdout), Mutex::new(Box::pin(stdin)), IoStream::Readable(Box::pin(stderr)), - pending_requests_clone2, + self.pending_requests.clone(), self.options.timeout, cancellation_token, ); @@ -275,7 +275,7 @@ where } async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(), )) } @@ -285,7 +285,7 @@ where _interval: Duration, _disconnect_tx: oneshot::Sender<()>, ) -> TransportResult> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of keep_alive() function for StdioTransport".to_string(), )) } @@ -365,3 +365,55 @@ impl > for StdioTransport { } + +#[async_trait] +impl McpDispatch + for StdioTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for StdioTransport +{ +} diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index 3d17ebd..b8e3ddc 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -1,15 +1,12 @@ -use std::{pin::Pin, sync::Arc, time::Duration}; - -use crate::schema::RequestId; +use crate::{error::TransportResult, message_dispatcher::MessageDispatcher}; +use crate::{schema::RequestId, SessionId}; use async_trait::async_trait; - +use std::{pin::Pin, sync::Arc, time::Duration}; use tokio::{ sync::oneshot::{self, Sender}, task::JoinHandle, }; -use crate::{error::TransportResult, message_dispatcher::MessageDispatcher}; - /// Default Timeout in milliseconds const DEFAULT_TIMEOUT_MSEC: u64 = 60_000; @@ -125,6 +122,9 @@ where interval: Duration, disconnect_tx: oneshot::Sender<()>, ) -> TransportResult>; + async fn session_id(&self) -> Option { + None + } } /// A composite trait that combines both transport and dispatch capabilities for the MCP protocol. @@ -160,3 +160,26 @@ where OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { } + +// pub trait IntoClientTransport { +// type TransportType: Transport< +// ServerMessages, +// MessageFromClient, +// ServerMessage, +// ClientMessages, +// ClientMessage, +// >; + +// fn into_transport(self, session_id: Option) -> TransportResult; +// } + +// impl IntoClientTransport for T +// where +// T: Transport, +// { +// type TransportType = T; + +// fn into_transport(self, _: Option) -> TransportResult { +// Ok(self) +// } +// } diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 218d517..82d7326 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -1,21 +1,29 @@ mod cancellation_token; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod http_utils; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod readable_channel; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +mod sse_parser; #[cfg(feature = "sse")] mod sse_stream; -#[cfg(feature = "sse")] +#[cfg(feature = "streamable-http")] +mod streamable_http_stream; +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod writable_channel; pub(crate) use cancellation_token::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use http_utils::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use readable_channel::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +pub(crate) use sse_parser::*; #[cfg(feature = "sse")] pub(crate) use sse_stream::*; -#[cfg(feature = "sse")] +#[cfg(feature = "streamable-http")] +pub(crate) use streamable_http_stream::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use writable_channel::*; use crate::schema::schema_utils::SdkError; @@ -23,16 +31,16 @@ use tokio::time::{timeout, Duration}; use crate::error::{TransportError, TransportResult}; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] use crate::SessionId; pub async fn await_timeout(operation: F, timeout_duration: Duration) -> TransportResult where F: std::future::Future>, // The operation returns a Result - E: Into, // The error type must be convertible to TransportError + E: Into, { match timeout(timeout_duration, operation).await { - Ok(result) => result.map_err(|err| err.into()), // Convert the error type into TransportError + Ok(result) => result.map_err(|err| err.into()), Err(_) => Err(SdkError::request_timeout(timeout_duration.as_millis()).into()), // Timeout error } } @@ -46,7 +54,7 @@ where /// # Returns /// A String containing the endpoint with the session ID added as a query parameter /// -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) fn endpoint_with_session_id(endpoint: &str, session_id: &SessionId) -> String { // Handle empty endpoint let base = if endpoint.is_empty() { "/" } else { endpoint }; diff --git a/crates/rust-mcp-transport/src/utils/http_utils.rs b/crates/rust-mcp-transport/src/utils/http_utils.rs index 701dcb0..84b62dd 100644 --- a/crates/rust-mcp-transport/src/utils/http_utils.rs +++ b/crates/rust-mcp-transport/src/utils/http_utils.rs @@ -1,7 +1,35 @@ use crate::error::{TransportError, TransportResult}; +use crate::{SessionId, MCP_SESSION_ID_HEADER}; -use reqwest::header::{HeaderMap, CONTENT_TYPE}; -use reqwest::Client; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE}; +use reqwest::{Client, Response}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ResponseType { + EventStream, + Json, +} + +/// Determines the response type based on the `Content-Type` header. +pub async fn validate_response_type(response: &Response) -> TransportResult { + match response.headers().get(reqwest::header::CONTENT_TYPE) { + Some(content_type) => { + let content_type_str = content_type.to_str().map_err(|_| { + TransportError::UnexpectedContentType("".to_string()) + })?; + + // Normalize to lowercase for case-insensitive comparison + let content_type_normalized = content_type_str.to_ascii_lowercase(); + + match content_type_normalized.as_str() { + "text/event-stream" => Ok(ResponseType::EventStream), + "application/json" => Ok(ResponseType::Json), + other => Err(TransportError::UnexpectedContentType(other.to_string())), + } + } + None => Err(TransportError::UnexpectedContentType("".to_string())), + } +} /// Sends an HTTP POST request with the given body and headers /// @@ -17,21 +45,96 @@ pub async fn http_post( client: &Client, post_url: &str, body: String, - headers: &Option, -) -> TransportResult<()> { + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { let mut request = client .post(post_url) .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") .body(body); if let Some(map) = headers { request = request.headers(map.clone()); } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + let response = request.send().await?; if !response.status().is_success() { - return Err(TransportError::HttpError(response.status().as_u16())); + return Err(TransportError::Http(response.status())); } - Ok(()) + Ok(response) +} + +pub async fn http_get( + client: &Client, + url: &str, + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { + let mut request = client + .get(url) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream"); + + if let Some(map) = headers { + request = request.headers(map.clone()); + } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + + let response = request.send().await?; + if !response.status().is_success() { + return Err(TransportError::Http(response.status())); + } + Ok(response) +} + +pub async fn http_delete( + client: &Client, + post_url: &str, + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { + let mut request = client + .delete(post_url) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream"); + + if let Some(map) = headers { + request = request.headers(map.clone()); + } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + + let response = request.send().await?; + if !response.status().is_success() { + let status_code = response.status(); + return Err(TransportError::Http(status_code)); + } + Ok(response) +} + +#[allow(unused)] +pub fn get_header_value(response: &Response, header_name: HeaderName) -> Option { + let content_type = response.headers().get(header_name)?.to_str().ok()?; + Some(content_type.to_string()) } pub fn extract_origin(url: &str) -> Option { @@ -88,7 +191,7 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is Ok assert!(result.is_ok()); @@ -113,11 +216,11 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is an HttpError with status 400 match result { - Err(TransportError::HttpError(status)) => assert_eq!(status, 400), + Err(TransportError::Http(status)) => assert_eq!(status, 400), _ => panic!("Expected HttpError with status 400"), } } @@ -142,7 +245,7 @@ mod tests { let headers = Some(create_test_headers()); // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is Ok assert!(result.is_ok()); @@ -157,7 +260,7 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, url, body, &headers).await; + let result = http_post(&client, url, body, None, headers.as_ref()).await; // Assert the result is an error (likely a connection error) assert!(result.is_err()); diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs new file mode 100644 index 0000000..064d3c3 --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -0,0 +1,320 @@ +use core::fmt; +use std::collections::HashMap; + +use bytes::{Bytes, BytesMut}; +const BUFFER_CAPACITY: usize = 1024; + +/// Represents a single Server-Sent Event (SSE) as defined in the SSE protocol. +/// +/// Contains the event type, data payload, and optional event ID. +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, +} + +impl std::fmt::Display for SseEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(id) = &self.id { + writeln!(f, "id: {id}")?; + } + + if let Some(event) = &self.event { + writeln!(f, "event: {event}")?; + } + + if let Some(data) = &self.data { + match std::str::from_utf8(data) { + Ok(text) => { + for line in text.lines() { + writeln!(f, "data: {line}")?; + } + } + Err(_) => { + writeln!(f, "data: [binary data]")?; + } + } + } + + writeln!(f)?; // Trailing newline for SSE message end + Ok(()) + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self + .data + .as_ref() + .map(|b| String::from_utf8_lossy(b).to_string()); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .finish() + } +} + +/// A parser for Server-Sent Events (SSE) that processes incoming byte chunks into `SseEvent`s. +/// This Parser is specificly designed for MCP messages and with no multi-line data support +/// +/// This struct maintains a buffer to accumulate incoming data and parses it into SSE events +/// based on the SSE protocol. It handles fields like `event`, `data`, and `id` as defined +/// in the SSE specification. +#[derive(Debug)] +pub struct SseParser { + pub buffer: BytesMut, +} + +impl SseParser { + /// Creates a new `SseParser` with an empty buffer pre-allocated to a default capacity. + /// + /// The buffer is initialized with a capacity of `BUFFER_CAPACITY` to + /// optimize for typical SSE message sizes. + /// + /// # Returns + /// A new `SseParser` instance with an empty buffer. + pub fn new() -> Self { + Self { + buffer: BytesMut::with_capacity(BUFFER_CAPACITY), + } + } + + /// Processes a new chunk of bytes and parses it into a vector of `SseEvent`s. + /// + /// This method appends the incoming `bytes` to the internal buffer, splits it into + /// complete lines (delimited by `\n`), and parses each line according to the SSE + /// protocol. It supports `event`, `id`, and `data` fields, as well as comments + /// (lines starting with `:`). Empty lines are skipped, and incomplete lines remain + /// in the buffer for future processing. + /// + /// # Parameters + /// - `bytes`: The incoming chunk of bytes to parse. + /// + /// # Returns + /// A vector of `SseEvent`s parsed from the complete lines in the buffer. If no + /// complete events are found, an empty vector is returned. + pub fn process_new_chunk(&mut self, bytes: Bytes) -> Vec { + self.buffer.extend_from_slice(&bytes); + + // Collect complete lines (ending in \n)—keep ALL lines, including empty ones for \n\n detection + let mut lines = Vec::new(); + while let Some(pos) = self.buffer.iter().position(|&b| b == b'\n') { + let line = self.buffer.split_to(pos + 1).freeze(); + lines.push(line); + } + + let mut events = Vec::new(); + let mut current_message_lines: Vec = Vec::new(); + + for line in lines { + current_message_lines.push(line); + + // Check if we've hit a double newline (end of message) + if current_message_lines.len() >= 2 + && current_message_lines + .last() + .is_some_and(|b| b.as_ref() == b"\n") + { + // Process the complete message (exclude the last empty lines for parsing) + let message_lines: Vec<_> = current_message_lines + .drain(..current_message_lines.len() - 1) + .filter(|l| l.as_ref() != b"\n") // Filter internal empties + .collect(); + + if let Some(event) = self.parse_sse_message(&message_lines) { + events.push(event); + } + } + } + + // Put back any incomplete message + if !current_message_lines.is_empty() { + self.buffer.clear(); + for line in current_message_lines { + self.buffer.extend_from_slice(&line); + } + } + + events + } + + fn parse_sse_message(&self, lines: &[Bytes]) -> Option { + let mut fields: HashMap = HashMap::new(); + let mut data_parts: Vec = Vec::new(); + + for line_bytes in lines { + let line_str = String::from_utf8_lossy(line_bytes); + + // Skip comments and empty lines + if line_str.is_empty() || line_str.starts_with(':') { + continue; + } + + let (key, value) = if let Some(value) = line_str.strip_prefix("data: ") { + ("data", value.trim_start().to_string()) + } else if let Some(value) = line_str.strip_prefix("event: ") { + ("event", value.trim().to_string()) + } else if let Some(value) = line_str.strip_prefix("id: ") { + ("id", value.trim().to_string()) + } else if let Some(value) = line_str.strip_prefix("retry: ") { + ("retry", value.trim().to_string()) + } else { + // Invalid line; skip + continue; + }; + + if key == "data" { + if !value.is_empty() { + data_parts.push(value); + } + } else { + fields.insert(key.to_string(), value); + } + } + + // Build data (concat multi-line data with \n) , should not occur in MCP tho + let data = if data_parts.is_empty() { + None + } else { + let full_data = data_parts.join("\n"); + Some(Bytes::copy_from_slice(full_data.as_bytes())) // Use copy_from_slice for efficiency + }; + + // Skip invalid message with no data + let data = data?; + + // Get event (default to None) + let event = fields.get("event").cloned(); + let id = fields.get("id").cloned(); + + Some(SseEvent { + event, + data: Some(data), + id, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + #[test] + fn test_single_data_event() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: hello\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + assert!(events[0].event.is_none()); + assert!(events[0].id.is_none()); + } + + #[test] + fn test_event_with_id_and_data() { + let mut parser = SseParser::new(); + let input = Bytes::from("event: message\nid: 123\ndata: hello\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!(events[0].event.as_deref(), Some("message")); + assert_eq!(events[0].id.as_deref(), Some("123")); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + } + + #[test] + fn test_event_chunks_in_different_orders() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: hello\nevent: message\nid: 123\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!(events[0].event.as_deref(), Some("message")); + assert_eq!(events[0].id.as_deref(), Some("123")); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + } + + #[test] + fn test_comment_line_ignored() { + let mut parser = SseParser::new(); + let input = Bytes::from(": this is a comment\n\n"); + let events = parser.process_new_chunk(input); + assert_eq!(events.len(), 0); + } + + #[test] + fn test_event_with_empty_data() { + let mut parser = SseParser::new(); + let input = Bytes::from("data:\n\n"); + let events = parser.process_new_chunk(input); + // Your parser skips data lines with empty content + assert_eq!(events.len(), 0); + } + + #[test] + fn test_partial_chunks() { + let mut parser = SseParser::new(); + + let part1 = Bytes::from("data: hello"); + let part2 = Bytes::from(" world\n\n"); + + let events1 = parser.process_new_chunk(part1); + assert_eq!(events1.len(), 0); // incomplete + + let events2 = parser.process_new_chunk(part2); + assert_eq!(events2.len(), 1); + assert_eq!( + events2[0].data.as_deref(), + Some(Bytes::from("hello world\n").as_ref()) + ); + } + + #[test] + fn test_malformed_lines() { + let mut parser = SseParser::new(); + let input = Bytes::from("something invalid\ndata: ok\n\n"); + + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("ok\n").as_ref()) + ); + } + + #[test] + fn test_multiple_events_in_one_chunk() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: first\n\ndata: second\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 2); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("first\n").as_ref()) + ); + assert_eq!( + events[1].data.as_deref(), + Some(Bytes::from("second\n").as_ref()) + ); + } +} diff --git a/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs new file mode 100644 index 0000000..ae9c69c --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs @@ -0,0 +1,374 @@ +use super::CancellationToken; +use crate::error::{TransportError, TransportResult}; +use crate::utils::SseParser; +use crate::utils::{http_get, validate_response_type, ResponseType}; +use crate::{utils::http_post, MCP_SESSION_ID_HEADER}; +use crate::{EventId, MCP_LAST_EVENT_ID_HEADER}; +use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest::{Client, Response, StatusCode}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{mpsc, RwLock}; +use tokio::time; +use tokio_stream::StreamExt; + +//-----------------------------------------------------------------------------------// +pub(crate) struct StreamableHttpStream { + /// HTTP client for making SSE requests + pub client: Client, + /// URL of the SSE endpoint + pub mcp_url: String, + /// Maximum number of retry attempts for failed connections + pub max_retries: usize, + /// Delay between retry attempts + pub retry_delay: Duration, + /// Sender for transmitting received data to the readable channel + pub read_tx: mpsc::Sender, + /// Session id will be received from the server in the http + pub session_id: Arc>>, +} + +impl StreamableHttpStream { + pub(crate) async fn run( + &mut self, + payload: String, + cancellation_token: &CancellationToken, + custom_headers: &Option, + ) -> TransportResult<()> { + let mut stream_parser = SseParser::new(); + let mut _last_event_id: Option = None; + + let session_id = self.session_id.read().await.clone(); + + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::info!( + "StreamableHttp cancelled before connection attempt {}", + payload + ); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + //TODO: simplify + let response = match http_post( + &self.client, + &self.mcp_url, + payload.to_string(), + session_id.as_ref(), + custom_headers.as_ref(), + ) + .await + { + Ok(response) => { + // if session_id_clone.read().await.is_none() { + let session_id = response + .headers() + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + let mut guard = self.session_id.write().await; + *guard = session_id; + response + } + + Err(error) => { + tracing::error!("Failed to connect to MCP endpoint: {error}"); + return Err(error); + } + }; + + // return if status code != 200 and no result is expected + if response.status() != StatusCode::OK { + return Ok(()); + } + + let response_type = validate_response_type(&response).await?; + + // Handle non-streaming JSON response + if response_type == ResponseType::Json { + return match response.bytes().await { + Ok(bytes) => { + // Send the message + self.read_tx.send(bytes).await.map_err(|_| { + tracing::error!("Readable stream closed, shutting down MCP task"); + TransportError::SendFailure( + "Failed to send message: channel closed or full".to_string(), + ) + })?; + + // Send the newline + self.read_tx + .send(Bytes::from_static(b"\n")) + .await + .map_err(|_| { + tracing::error!( + "Failed to send newline, channel may be closed or full" + ); + TransportError::SendFailure( + "Failed to send newline: channel closed or full".to_string(), + ) + })?; + + Ok(()) + } + Err(error) => Err(error.into()), + }; + } + + // Create a stream from the response bytes + let mut stream = response.bytes_stream(); + + // Inner loop for processing stream chunks + loop { + let next_chunk = tokio::select! { + // Wait for the next stream chunk + chunk = stream.next() => { + match chunk { + Some(chunk) => chunk, + None => { + // stream ended, unline SSE, so no retry attempt here needed to reconnect + return Err(TransportError::Internal("Stream has ended.".to_string())); + } + } + } + // Wait for cancellation + _ = cancellation_token.cancelled() => { + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + }; + + match next_chunk { + Ok(bytes) => { + let events = stream_parser.process_new_chunk(bytes); + + if !events.is_empty() { + for event in events { + if let Some(bytes) = event.data { + if event.id.is_some() { + _last_event_id = event.id.clone(); + } + + if self.read_tx.send(bytes).await.is_err() { + tracing::error!( + "Readable stream closed, shutting down MCP task" + ); + return Err(TransportError::SendFailure( + "Failed to send message: stream closed".to_string(), + )); + } + } + } + // break after receiving the message(s) + return Ok(()); + } + } + Err(error) => { + tracing::error!("Error reading stream: {error}"); + return Err(error.into()); + } + } + } + } + + pub(crate) async fn make_standalone_stream_connection( + &self, + cancellation_token: &CancellationToken, + custom_headers: &Option, + last_event_id: Option, + ) -> TransportResult { + let mut retry_count = 0; + let session_id = self.session_id.read().await.clone(); + + let headers = if let Some(event_id) = last_event_id.as_ref() { + let mut headers = HeaderMap::new(); + if let Some(custom) = custom_headers { + headers.extend(custom.iter().map(|(k, v)| (k.clone(), v.clone()))); + } + if let Ok(event_id_value) = HeaderValue::from_str(event_id) { + headers.insert(MCP_LAST_EVENT_ID_HEADER, event_id_value); + } + &Some(headers) + } else { + custom_headers + }; + + loop { + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::info!("Standalone StreamableHttp cancelled."); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + match http_get( + &self.client, + &self.mcp_url, + session_id.as_ref(), + headers.as_ref(), + ) + .await + { + Ok(response) => { + let is_event_stream = validate_response_type(&response) + .await + .is_ok_and(|response_type| response_type == ResponseType::EventStream); + + if !is_event_stream { + let message = + "SSE stream response returned an unexpected Content-Type.".to_string(); + tracing::warn!("{message}"); + return Err(TransportError::FailedToOpenSSEStream(message)); + } + + return Ok(response); + } + + Err(error) => { + match error { + crate::error::TransportError::HttpConnection(_) => { + // A reqwest::Error happened, we do not return ans instead retry the operation + } + crate::error::TransportError::Http(status_code) => match status_code { + StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED => { + return Err(crate::error::TransportError::FailedToOpenSSEStream( + format!("Not supported (code: {status_code})"), + )); + } + other => { + tracing::warn!( + "Failed to open SSE stream: {error} (code: {other})" + ); + } + }, + error => { + return Err(error); // return the error where the retry wont help + } + } + + if retry_count >= self.max_retries { + tracing::warn!("Max retries ({}) reached, giving up", self.max_retries); + return Err(error); + } + retry_count += 1; + time::sleep(self.retry_delay).await; + continue; + } + }; + } + } + + pub(crate) async fn run_standalone( + &mut self, + cancellation_token: &CancellationToken, + custom_headers: &Option, + response: Response, + ) -> TransportResult<()> { + let mut retry_count = 0; + let mut stream_parser = SseParser::new(); + let mut _last_event_id: Option = None; + + let mut response = Some(response); + + // Main loop for reconnection attempts + loop { + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::debug!("Standalone StreamableHttp cancelled."); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + // use initially passed response, otherwise try to make a new sse connection + let response = match response.take() { + Some(response) => response, + None => { + tracing::debug!( + "Reconnecting to SSE stream... (try {} of {})", + retry_count, + self.max_retries + ); + self.make_standalone_stream_connection( + cancellation_token, + custom_headers, + _last_event_id.clone(), + ) + .await? + } + }; + + // Create a stream from the response bytes + let mut stream = response.bytes_stream(); + + // Inner loop for processing stream chunks + loop { + let next_chunk = tokio::select! { + // Wait for the next stream chunk + chunk = stream.next() => { + match chunk { + Some(chunk) => chunk, + None => { + // stream ended, unline SSE, so no retry attempt here needed to reconnect + return Err(TransportError::Internal("Stream has ended.".to_string())); + } + } + } + // Wait for cancellation + _ = cancellation_token.cancelled() => { + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + }; + + match next_chunk { + Ok(bytes) => { + let events = stream_parser.process_new_chunk(bytes); + + if !events.is_empty() { + for event in events { + if let Some(bytes) = event.data { + if event.id.is_some() { + _last_event_id = event.id.clone(); + } + + if self.read_tx.send(bytes).await.is_err() { + tracing::error!( + "Readable stream closed, shutting down MCP task" + ); + return Err(TransportError::SendFailure( + "Failed to send message: stream closed".to_string(), + )); + } + } + } + } + retry_count = 0; // Reset retry count on successful chunk + } + Err(error) => { + if retry_count >= self.max_retries { + tracing::error!("Error reading stream: {error}"); + tracing::warn!("Max retries ({}) reached, giving up", self.max_retries); + return Err(error.into()); + } + + tracing::debug!( + "The standalone SSE stream encountered an error: '{}'", + error + ); + retry_count += 1; + time::sleep(self.retry_delay).await; + break; // Break inner loop to reconnect + } + } + } + } + } +} diff --git a/crates/rust-mcp-transport/tests/check_imports.rs b/crates/rust-mcp-transport/tests/check_imports.rs index cda7d0c..207644e 100644 --- a/crates/rust-mcp-transport/tests/check_imports.rs +++ b/crates/rust-mcp-transport/tests/check_imports.rs @@ -37,13 +37,12 @@ mod tests { // Check for `use rust_mcp_schema` if content.contains("use rust_mcp_schema") { errors.push(format!( - "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", - abs_path + "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." )); } } Err(e) => { - errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + errors.push(format!("Failed to read file `{path_str}`: {e}")); } } } diff --git a/development.md b/development.md index e3673cc..e17dd17 100644 --- a/development.md +++ b/development.md @@ -33,14 +33,14 @@ Build and run instructions are available in their respective README.md files. You can run examples by passing the example project name to Cargo using the `-p` argument, like this: ```sh -cargo run -p simple-mcp-client +cargo run -p simple-mcp-client-stdio ``` -You can build the examples in a similar way. The following command builds the project and generates the binary at `target/release/hello-world-mcp-server`: +You can build the examples in a similar way. The following command builds the project and generates the binary at `target/release/hello-world-mcp-server-stdio`: ```sh -cargo build -p hello-world-mcp-server --release +cargo build -p hello-world-mcp-server-stdio --release ``` ## Code Formatting diff --git a/doc/getting-started-mcp-server.md b/doc/getting-started-mcp-server.md index 358b1b4..6fac258 100644 --- a/doc/getting-started-mcp-server.md +++ b/doc/getting-started-mcp-server.md @@ -160,7 +160,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, _request: ListToolsRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -173,7 +173,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = diff --git a/examples/hello-world-mcp-server-core/.gitignore b/examples/hello-world-mcp-server-stdio-core/.gitignore similarity index 100% rename from examples/hello-world-mcp-server-core/.gitignore rename to examples/hello-world-mcp-server-stdio-core/.gitignore diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-stdio-core/Cargo.toml similarity index 83% rename from examples/hello-world-mcp-server-core/Cargo.toml rename to examples/hello-world-mcp-server-stdio-core/Cargo.toml index bbab301..14eb904 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-mcp-server-core" -version = "0.1.22" +name = "hello-world-mcp-server-stdio-core" +version = "0.1.19" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/hello-world-mcp-server-core/README.md b/examples/hello-world-mcp-server-stdio-core/README.md similarity index 81% rename from examples/hello-world-mcp-server-core/README.md rename to examples/hello-world-mcp-server-stdio-core/README.md index af9d703..cf57884 100644 --- a/examples/hello-world-mcp-server-core/README.md +++ b/examples/hello-world-mcp-server-stdio-core/README.md @@ -23,14 +23,14 @@ cd rust-mcp-sdk 2. Build the project: ```bash -cargo build -p hello-world-mcp-server-core --release +cargo build -p hello-world-mcp-server-stdio-core --release ``` -3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-core` +3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-stdio-core` You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. ```bash -npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-core +npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-stdio-core ``` ``` @@ -41,4 +41,4 @@ Starting MCP inspector... Here you can see it in action : -![hello-world-mcp-server-core]![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) +![hello-world-mcp-server-stdio-core]![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-stdio-core/src/handler.rs similarity index 97% rename from examples/hello-world-mcp-server-core/src/handler.rs rename to examples/hello-world-mcp-server-stdio-core/src/handler.rs index f0bdefe..acf55ea 100644 --- a/examples/hello-world-mcp-server-core/src/handler.rs +++ b/examples/hello-world-mcp-server-stdio-core/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -90,7 +92,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -99,7 +101,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_error( &self, error: &RpcError, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-mcp-server-core/src/main.rs b/examples/hello-world-mcp-server-stdio-core/src/main.rs similarity index 100% rename from examples/hello-world-mcp-server-core/src/main.rs rename to examples/hello-world-mcp-server-stdio-core/src/main.rs diff --git a/examples/hello-world-mcp-server-core/src/tools.rs b/examples/hello-world-mcp-server-stdio-core/src/tools.rs similarity index 100% rename from examples/hello-world-mcp-server-core/src/tools.rs rename to examples/hello-world-mcp-server-stdio-core/src/tools.rs diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server-stdio/Cargo.toml similarity index 85% rename from examples/hello-world-mcp-server/Cargo.toml rename to examples/hello-world-mcp-server-stdio/Cargo.toml index 63a54af..9d15be3 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-mcp-server" -version = "0.1.31" +name = "hello-world-mcp-server-stdio" +version = "0.1.28" edition = "2021" publish = false license = "MIT" @@ -10,8 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", - "hyper-server", - "ssl", + "stdio", "2025_06_18", ] } diff --git a/examples/hello-world-mcp-server/README.md b/examples/hello-world-mcp-server-stdio/README.md similarity index 84% rename from examples/hello-world-mcp-server/README.md rename to examples/hello-world-mcp-server-stdio/README.md index 33a62af..9e0bdda 100644 --- a/examples/hello-world-mcp-server/README.md +++ b/examples/hello-world-mcp-server-stdio/README.md @@ -22,14 +22,14 @@ cd rust-mcp-sdk 2. Build the project: ```bash -cargo build -p hello-world-mcp-server --release +cargo build -p hello-world-mcp-server-stdio --release ``` -3. After building the project, the binary will be located at `target/release/hello-world-mcp-server` +3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-stdio` You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. ```bash -npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server +npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-stdio ``` ``` @@ -40,4 +40,4 @@ Starting MCP inspector... Here you can see it in action : -![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) +![hello-world-mcp-server-stdio](../../assets/examples/hello-world-mcp-server.gif) diff --git a/examples/hello-world-mcp-server/src/handler.rs b/examples/hello-world-mcp-server-stdio/src/handler.rs similarity index 94% rename from examples/hello-world-mcp-server/src/handler.rs rename to examples/hello-world-mcp-server-stdio/src/handler.rs index d9741a0..47925a0 100644 --- a/examples/hello-world-mcp-server/src/handler.rs +++ b/examples/hello-world-mcp-server-stdio/src/handler.rs @@ -4,6 +4,7 @@ use rust_mcp_sdk::schema::{ ListToolsResult, RpcError, }; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; +use std::sync::Arc; use crate::tools::GreetingTools; @@ -20,7 +21,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +34,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = diff --git a/examples/hello-world-mcp-server/src/main.rs b/examples/hello-world-mcp-server-stdio/src/main.rs similarity index 92% rename from examples/hello-world-mcp-server/src/main.rs rename to examples/hello-world-mcp-server-stdio/src/main.rs index 00ca6a7..98ff6f0 100644 --- a/examples/hello-world-mcp-server/src/main.rs +++ b/examples/hello-world-mcp-server-stdio/src/main.rs @@ -1,6 +1,8 @@ mod handler; mod tools; +use std::sync::Arc; + use handler::MyServerHandler; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, @@ -40,7 +42,8 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; // STEP 4: create a MCP server - let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); + let server: Arc = + server_runtime::create_server(server_details, transport, handler); // STEP 5: Start the server if let Err(start_error) = server.start().await { diff --git a/examples/hello-world-mcp-server/src/tools.rs b/examples/hello-world-mcp-server-stdio/src/tools.rs similarity index 100% rename from examples/hello-world-mcp-server/src/tools.rs rename to examples/hello-world-mcp-server-stdio/src/tools.rs diff --git a/examples/hello-world-server-core-streamable-http/.gitignore b/examples/hello-world-server-streamable-http-core/.gitignore similarity index 100% rename from examples/hello-world-server-core-streamable-http/.gitignore rename to examples/hello-world-server-streamable-http-core/.gitignore diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http-core/Cargo.toml similarity index 84% rename from examples/hello-world-server-core-streamable-http/Cargo.toml rename to examples/hello-world-server-streamable-http-core/Cargo.toml index 99d1011..a762058 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-server-core-streamable-http" -version = "0.1.22" +name = "hello-world-server-streamable-http-core" +version = "0.1.19" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "streamable-http", "hyper-server", "2025_06_18", ] } diff --git a/examples/hello-world-server-core-streamable-http/README.md b/examples/hello-world-server-streamable-http-core/README.md similarity index 95% rename from examples/hello-world-server-core-streamable-http/README.md rename to examples/hello-world-server-streamable-http-core/README.md index cd37623..49af2c2 100644 --- a/examples/hello-world-server-core-streamable-http/README.md +++ b/examples/hello-world-server-streamable-http-core/README.md @@ -37,7 +37,7 @@ cd rust-mcp-sdk 2. Build and start the server: ```bash -cargo run -p hello-world-server-core-streamable-http --release +cargo run -p hello-world-server-streamable-http-core --release ``` By default, both the Streamable HTTP and SSE endpoints are displayed in the terminal: @@ -65,4 +65,4 @@ Then , to test the server, visit one of the following URLs based on the desired Here you can see it in action : -![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-streamable-http-core.gif) diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http-core/src/handler.rs similarity index 97% rename from examples/hello-world-server-core-streamable-http/src/handler.rs rename to examples/hello-world-server-streamable-http-core/src/handler.rs index 1c69e8c..7941075 100644 --- a/examples/hello-world-server-core-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http-core/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -95,7 +97,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -104,7 +106,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_error( &self, error: &RpcError, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-server-core-streamable-http/src/main.rs b/examples/hello-world-server-streamable-http-core/src/main.rs similarity index 100% rename from examples/hello-world-server-core-streamable-http/src/main.rs rename to examples/hello-world-server-streamable-http-core/src/main.rs diff --git a/examples/hello-world-server-core-streamable-http/src/tools.rs b/examples/hello-world-server-streamable-http-core/src/tools.rs similarity index 100% rename from examples/hello-world-server-core-streamable-http/src/tools.rs rename to examples/hello-world-server-streamable-http-core/src/tools.rs diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index df4296d..17a87c8 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "streamable-http", "hyper-server", "2025_06_18", ] } diff --git a/examples/hello-world-server-streamable-http/README.md b/examples/hello-world-server-streamable-http/README.md index ac56a86..7e3f3b6 100644 --- a/examples/hello-world-server-streamable-http/README.md +++ b/examples/hello-world-server-streamable-http/README.md @@ -66,4 +66,4 @@ Then , to test the server, visit one of the following URLs based on the desired Here you can see it in action : -![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-streamable-http-core.gif) diff --git a/examples/hello-world-server-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http/src/handler.rs index b8ce355..3939d86 100644 --- a/examples/hello-world-server-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http/src/handler.rs @@ -1,12 +1,11 @@ +use crate::tools::GreetingTools; use async_trait::async_trait; use rust_mcp_sdk::schema::{ schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, RpcError, }; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; - -use crate::tools::GreetingTools; - +use std::sync::Arc; // Custom Handler to handle MCP Messages pub struct MyServerHandler; @@ -20,7 +19,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +32,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = @@ -45,6 +44,4 @@ impl ServerHandler for MyServerHandler { GreetingTools::SayGoodbyeTool(say_goodbye_tool) => say_goodbye_tool.call_tool(), } } - - async fn on_server_started(&self, runtime: &dyn McpServer) {} } diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-sse-core/Cargo.toml similarity index 88% rename from examples/simple-mcp-client-core-sse/Cargo.toml rename to examples/simple-mcp-client-sse-core/Cargo.toml index 0e32790..25dcd7d 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client-core-sse" -version = "0.1.22" +name = "simple-mcp-client-sse-core" +version = "0.1.19" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "sse", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-core-sse/README.md b/examples/simple-mcp-client-sse-core/README.md similarity index 97% rename from examples/simple-mcp-client-core-sse/README.md rename to examples/simple-mcp-client-sse-core/README.md index e7e10d2..a0852fb 100644 --- a/examples/simple-mcp-client-core-sse/README.md +++ b/examples/simple-mcp-client-sse-core/README.md @@ -32,7 +32,7 @@ npx @modelcontextprotocol/server-everything sse 2. Open a new terminal and run the project with: ```bash -cargo run -p simple-mcp-client-core-sse +cargo run -p simple-mcp-client-sse-core ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-sse-core/src/handler.rs similarity index 100% rename from examples/simple-mcp-client-core-sse/src/handler.rs rename to examples/simple-mcp-client-sse-core/src/handler.rs diff --git a/examples/simple-mcp-client-core-sse/src/inquiry_utils.rs b/examples/simple-mcp-client-sse-core/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client-core-sse/src/inquiry_utils.rs rename to examples/simple-mcp-client-sse-core/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client-core-sse/src/main.rs b/examples/simple-mcp-client-sse-core/src/main.rs similarity index 99% rename from examples/simple-mcp-client-core-sse/src/main.rs rename to examples/simple-mcp-client-sse-core/src/main.rs index 459f9ba..be8279b 100644 --- a/examples/simple-mcp-client-core-sse/src/main.rs +++ b/examples/simple-mcp-client-sse-core/src/main.rs @@ -44,6 +44,7 @@ async fn main() -> SdkResult<()> { // STEP 3: instantiate our custom handler that is responsible for handling MCP messages let handler = MyClientHandler {}; + // STEP 4: create the client let client = client_runtime_core::create_client(client_details, transport, handler); // STEP 5: start the MCP client diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 14fd96b..bf7174d 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -9,6 +9,8 @@ license = "MIT" [dependencies] rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", + "sse", + "streamable-http", "macros", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-sse/src/main.rs b/examples/simple-mcp-client-sse/src/main.rs index ce8850a..0a76caa 100644 --- a/examples/simple-mcp-client-sse/src/main.rs +++ b/examples/simple-mcp-client-sse/src/main.rs @@ -15,7 +15,9 @@ use std::sync::Arc; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; -const MCP_SERVER_URL: &str = "http://localhost:3001/sse"; +// Connect to a server started with the following command: +// npx @modelcontextprotocol/server-everything sse +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/sse"; #[tokio::main] async fn main() -> SdkResult<()> { @@ -44,6 +46,7 @@ async fn main() -> SdkResult<()> { // STEP 3: instantiate our custom handler that is responsible for handling MCP messages let handler = MyClientHandler {}; + // STEP 4: create the client let client = client_runtime::create_client(client_details, transport, handler); // STEP 5: start the MCP client @@ -57,6 +60,7 @@ async fn main() -> SdkResult<()> { let utils = InquiryUtils { client: Arc::clone(&client), }; + // Display server information (name and version) utils.print_server_info(); @@ -78,8 +82,11 @@ async fn main() -> SdkResult<()> { // Call add tool, and print the result utils.call_add_tool(100, 25).await?; - // Set the log level - utils.client.set_logging_level(LoggingLevel::Debug).await?; + // // Set the log level + match utils.client.set_logging_level(LoggingLevel::Debug).await { + Ok(_) => println!("Log level is set to \"Debug\""), + Err(err) => eprintln!("Error setting the Log level : {err}"), + } // Send 3 pings to the server, with a 2-second interval between each ping. utils.ping_n_times(3).await; diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-stdio-core/Cargo.toml similarity index 86% rename from examples/simple-mcp-client-core/Cargo.toml rename to examples/simple-mcp-client-stdio-core/Cargo.toml index 0dacc2d..6d95cf6 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client-core" -version = "0.1.31" +name = "simple-mcp-client-stdio-core" +version = "0.1.28" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-core/README.md b/examples/simple-mcp-client-stdio-core/README.md similarity index 97% rename from examples/simple-mcp-client-core/README.md rename to examples/simple-mcp-client-stdio-core/README.md index 52d8074..f3258aa 100644 --- a/examples/simple-mcp-client-core/README.md +++ b/examples/simple-mcp-client-stdio-core/README.md @@ -24,7 +24,7 @@ cd rust-mcp-sdk 2. RUn the project: ```bash -cargo run -p simple-mcp-client-core +cargo run -p simple-mcp-client-stdio-core ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client-core/src/handler.rs b/examples/simple-mcp-client-stdio-core/src/handler.rs similarity index 100% rename from examples/simple-mcp-client-core/src/handler.rs rename to examples/simple-mcp-client-stdio-core/src/handler.rs diff --git a/examples/simple-mcp-client-core/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client-core/src/inquiry_utils.rs rename to examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client-core/src/main.rs b/examples/simple-mcp-client-stdio-core/src/main.rs similarity index 100% rename from examples/simple-mcp-client-core/src/main.rs rename to examples/simple-mcp-client-stdio-core/src/main.rs diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client-stdio/Cargo.toml similarity index 87% rename from examples/simple-mcp-client/Cargo.toml rename to examples/simple-mcp-client-stdio/Cargo.toml index 9599c46..3597105 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client" -version = "0.1.31" +name = "simple-mcp-client-stdio" +version = "0.1.28" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/simple-mcp-client/README.md b/examples/simple-mcp-client-stdio/README.md similarity index 97% rename from examples/simple-mcp-client/README.md rename to examples/simple-mcp-client-stdio/README.md index c56a933..be17f02 100644 --- a/examples/simple-mcp-client/README.md +++ b/examples/simple-mcp-client-stdio/README.md @@ -24,7 +24,7 @@ cd rust-mcp-sdk 2. RUn the project: ```bash -cargo run -p simple-mcp-client +cargo run -p simple-mcp-client-stdio ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client/src/handler.rs b/examples/simple-mcp-client-stdio/src/handler.rs similarity index 100% rename from examples/simple-mcp-client/src/handler.rs rename to examples/simple-mcp-client-stdio/src/handler.rs diff --git a/examples/simple-mcp-client/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client/src/inquiry_utils.rs rename to examples/simple-mcp-client-stdio/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client/src/main.rs b/examples/simple-mcp-client-stdio/src/main.rs similarity index 100% rename from examples/simple-mcp-client/src/main.rs rename to examples/simple-mcp-client-stdio/src/main.rs diff --git a/examples/simple-mcp-client-streamable-http-core/Cargo.toml b/examples/simple-mcp-client-streamable-http-core/Cargo.toml new file mode 100644 index 0000000..68356e1 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "simple-mcp-client-streamable-http-core" +version = "0.1.0" +edition = "2021" +publish = false +license = "MIT" + + +[dependencies] +rust-mcp-sdk = { workspace = true, default-features = false, features = [ + "client", + "macros", + "streamable-http", + "2025_06_18", +] } + +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +colored = "3.0.0" +tracing-subscriber = { workspace = true } +tracing = { workspace = true } + + +[lints] +workspace = true diff --git a/examples/simple-mcp-client-streamable-http-core/README.md b/examples/simple-mcp-client-streamable-http-core/README.md new file mode 100644 index 0000000..a0852fb --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/README.md @@ -0,0 +1,40 @@ +# Simple MCP Client Core (SSE) + +This is a simple MCP (Model Context Protocol) client implemented with the rust-mcp-sdk, dmeonstrating SSE transport, showcasing fundamental MCP client operations like fetching the MCP server's capabilities and executing a tool call. + +## Overview + +This project demonstrates a basic MCP client implementation, showcasing the features of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). + +This example connects to a running instance of the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, which has already been started with the sse flag. + +It displays the server name and version, outlines the server's capabilities, and provides a list of available tools, prompts, templates, resources, and more offered by the server. Additionally, it will execute a tool call by utilizing the add tool from the server-everything package to sum two numbers and output the result. + +> Note that @modelcontextprotocol/server-everything is an npm package, so you must have Node.js and npm installed on your system, as this example attempts to start it. + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2- Start `@modelcontextprotocol/server-everything` with SSE argument: + +```bash +npx @modelcontextprotocol/server-everything sse +``` + +> It launches the server, making everything accessible via the SSE transport at http://localhost:3001/sse. + +2. Open a new terminal and run the project with: + +```bash +cargo run -p simple-mcp-client-sse-core +``` + +You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. + + diff --git a/examples/simple-mcp-client-streamable-http-core/src/handler.rs b/examples/simple-mcp-client-streamable-http-core/src/handler.rs new file mode 100644 index 0000000..ab86e9e --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/handler.rs @@ -0,0 +1,72 @@ +use async_trait::async_trait; +use rust_mcp_sdk::schema::{ + self, + schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, + RpcError, ServerRequest, +}; +use rust_mcp_sdk::{mcp_client::ClientHandlerCore, McpClient}; +pub struct MyClientHandler; + +// To check out a list of all the methods in the trait that you can override, take a look at +// https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs + +#[async_trait] +impl ClientHandlerCore for MyClientHandler { + async fn handle_request( + &self, + request: RequestFromServer, + _runtime: &dyn McpClient, + ) -> std::result::Result { + match request { + RequestFromServer::ServerRequest(server_request) => match server_request { + ServerRequest::PingRequest(_) => { + return Ok(schema::Result::default().into()); + } + ServerRequest::CreateMessageRequest(_create_message_request) => { + Err(RpcError::internal_error().with_message( + "CreateMessageRequest handler is not implemented".to_string(), + )) + } + ServerRequest::ListRootsRequest(_list_roots_request) => { + Err(RpcError::internal_error() + .with_message("ListRootsRequest handler is not implemented".to_string())) + } + ServerRequest::ElicitRequest(_elicit_request) => Err(RpcError::internal_error() + .with_message("ElicitRequest handler is not implemented".to_string())), + }, + RequestFromServer::CustomRequest(_value) => Err(RpcError::internal_error() + .with_message("CustomRequest handler is not implemented".to_string())), + } + } + + async fn handle_notification( + &self, + notification: NotificationFromServer, + _runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + if let NotificationFromServer::ServerNotification( + schema::ServerNotification::LoggingMessageNotification(logging_message_notification), + ) = notification + { + println!( + "Notification from server: {}", + logging_message_notification.params.data + ); + } else { + println!( + "A {} notification received from the server", + notification.method() + ); + }; + + Ok(()) + } + + async fn handle_error( + &self, + _error: &RpcError, + _runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) + } +} diff --git a/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs new file mode 100644 index 0000000..a8e7c9c --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs @@ -0,0 +1,222 @@ +//! This module contains utility functions for querying and displaying server capabilities. + +use colored::Colorize; +use rust_mcp_sdk::schema::CallToolRequestParams; +use rust_mcp_sdk::McpClient; +use rust_mcp_sdk::{error::SdkResult, mcp_client::ClientRuntime}; +use serde_json::json; +use std::io::Write; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; + +const GREY_COLOR: (u8, u8, u8) = (90, 90, 90); +const HEADER_SIZE: usize = 31; + +pub struct InquiryUtils { + pub client: Arc, +} + +impl InquiryUtils { + fn print_header(&self, title: &str) { + let pad = ((HEADER_SIZE as f32 / 2.0) + (title.len() as f32 / 2.0)).floor() as usize; + println!("\n{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + println!("{:>pad$}", title.custom_color(GREY_COLOR)); + println!("{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + } + + fn print_list(&self, list_items: Vec<(String, String)>) { + list_items.iter().enumerate().for_each(|(index, item)| { + println!("{}. {}: {}", index + 1, item.0.yellow(), item.1.cyan(),); + }); + } + + pub fn print_server_info(&self) { + self.print_header("Server info"); + let server_version = self.client.server_version().unwrap(); + println!("{} {}", "Server name:".bold(), server_version.name.cyan()); + println!( + "{} {}", + "Server version:".bold(), + server_version.version.cyan() + ); + } + + pub fn print_server_capabilities(&self) { + self.print_header("Capabilities"); + let capability_vec = [ + ("tools", self.client.server_has_tools()), + ("prompts", self.client.server_has_prompts()), + ("resources", self.client.server_has_resources()), + ("logging", self.client.server_supports_logging()), + ("experimental", self.client.server_has_experimental()), + ]; + + capability_vec.iter().for_each(|(tool_name, opt)| { + println!( + "{}: {}", + tool_name.bold(), + opt.map(|b| if b { "Yes" } else { "No" }) + .unwrap_or("Unknown") + .cyan() + ); + }); + } + + pub async fn print_tool_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support tools + if !self.client.server_has_tools().unwrap_or(false) { + return Ok(()); + } + + let tools = self.client.list_tools(None).await?; + self.print_header("Tools"); + self.print_list( + tools + .tools + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_prompts_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support prompts + if !self.client.server_has_prompts().unwrap_or(false) { + return Ok(()); + } + + let prompts = self.client.list_prompts(None).await?; + + self.print_header("Prompts"); + self.print_list( + prompts + .prompts + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn print_resource_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let resources = self.client.list_resources(None).await?; + + self.print_header("Resources"); + + self.print_list( + resources + .resources + .iter() + .map(|item| { + ( + item.name.clone(), + format!( + "( uri: {} , mime: {}", + item.uri, + item.mime_type.as_ref().unwrap_or(&"?".to_string()), + ), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_resource_templates(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let templates = self.client.list_resource_templates(None).await?; + + self.print_header("Resource Templates"); + + self.print_list( + templates + .resource_templates + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn call_add_tool(&self, a: i64, b: i64) -> SdkResult<()> { + // Invoke the "add" tool with 100 and 25 as arguments, and display the result + println!( + "{}", + format!("\nCalling the \"add\" tool with {a} and {b} ...").magenta() + ); + + // Create a `Map` to represent the tool parameters + let params = json!({ + "a": a, + "b": b + }) + .as_object() + .unwrap() + .clone(); + + // invoke the tool + let result = self + .client + .call_tool(CallToolRequestParams { + name: "add".to_string(), + arguments: Some(params), + }) + .await?; + + // Retrieve the result content and print it to the stdout + let result_content = result.content.first().unwrap().as_text_content()?; + println!("{}", result_content.text.green()); + + Ok(()) + } + + pub async fn ping_n_times(&self, n: i32) { + let max_pings = n; + println!(); + for ping_index in 1..=max_pings { + print!("Ping the server ({ping_index} out of {max_pings})..."); + std::io::stdout().flush().unwrap(); + let ping_result = self.client.ping(None).await; + print!( + "\rPing the server ({} out of {}) : {}", + ping_index, + max_pings, + if ping_result.is_ok() { + "success".bright_green() + } else { + "failed".bright_red() + } + ); + println!(); + sleep(Duration::from_secs(2)).await; + } + } +} diff --git a/examples/simple-mcp-client-streamable-http-core/src/main.rs b/examples/simple-mcp-client-streamable-http-core/src/main.rs new file mode 100644 index 0000000..e1a5849 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/main.rs @@ -0,0 +1,95 @@ +mod handler; +mod inquiry_utils; + +use handler::MyClientHandler; + +use inquiry_utils::InquiryUtils; +use rust_mcp_sdk::error::SdkResult; +use rust_mcp_sdk::mcp_client::client_runtime_core; +use rust_mcp_sdk::schema::{ + ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, + LATEST_PROTOCOL_VERSION, +}; +use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +// Assuming @modelcontextprotocol/server-everything is launched with streamableHttp argument and listening on port 3001 +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/mcp"; + +#[tokio::main] +async fn main() -> SdkResult<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Step1 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-core-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (Core,SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + // STEP 3: instantiate our custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 4: create the client + let client = + client_runtime_core::with_transport_options(client_details, transport_options, handler); + + // STEP 5: start the MCP client + client.clone().start().await?; + + // You can utilize the client and its methods to interact with the MCP Server. + // The following demonstrates how to use client methods to retrieve server information, + // and print them in the terminal, set the log level, invoke a tool, and more. + + // Create a struct with utility functions for demonstration purpose, to utilize different client methods and display the information. + let utils = InquiryUtils { + client: Arc::clone(&client), + }; + // Display server information (name and version) + utils.print_server_info(); + + // Display server capabilities + utils.print_server_capabilities(); + + // Display the list of tools available on the server + utils.print_tool_list().await?; + + // Display the list of prompts available on the server + utils.print_prompts_list().await?; + + // Display the list of resources available on the server + utils.print_resource_list().await?; + + // Display the list of resource templates available on the server + utils.print_resource_templates().await?; + + // Call add tool, and print the result + utils.call_add_tool(100, 25).await?; + + // Set the log level + utils.client.set_logging_level(LoggingLevel::Debug).await?; + + // Send 3 pings to the server, with a 2-second interval between each ping. + utils.ping_n_times(3).await; + client.shut_down().await?; + + Ok(()) +} diff --git a/examples/simple-mcp-client-streamable-http/Cargo.toml b/examples/simple-mcp-client-streamable-http/Cargo.toml new file mode 100644 index 0000000..0638aab --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "simple-mcp-client-streamable-http" +version = "0.1.0" +edition = "2021" +publish = false +license = "MIT" + + +[dependencies] +rust-mcp-sdk = { workspace = true, default-features = false, features = [ + "client", + "streamable-http", + "macros", + "2025_06_18", +] } + +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +colored = "3.0.0" +tracing-subscriber = { workspace = true } +tracing = { workspace = true } + + +[lints] +workspace = true diff --git a/examples/simple-mcp-client-streamable-http/README.md b/examples/simple-mcp-client-streamable-http/README.md new file mode 100644 index 0000000..5b4488e --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/README.md @@ -0,0 +1,40 @@ +# Simple MCP Client (SSE) + +This is a simple MCP (Model Context Protocol) client implemented with the rust-mcp-sdk, dmeonstrating SSE transport, showcasing fundamental MCP client operations like fetching the MCP server's capabilities and executing a tool call. + +## Overview + +This project demonstrates a basic MCP client implementation, showcasing the features of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). + +This example connects to a running instance of the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, which has already been started with the sse flag. + +It displays the server name and version, outlines the server's capabilities, and provides a list of available tools, prompts, templates, resources, and more offered by the server. Additionally, it will execute a tool call by utilizing the add tool from the server-everything package to sum two numbers and output the result. + +> Note that @modelcontextprotocol/server-everything is an npm package, so you must have Node.js and npm installed on your system, as this example attempts to start it. + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2- Start `@modelcontextprotocol/server-everything` with SSE argument: + +```bash +npx @modelcontextprotocol/server-everything sse +``` + +> It launches the server, making everything accessible via the SSE transport at http://localhost:3001/sse. + +2. Open a new terminal and run the project with: + +```bash +cargo run -p simple-mcp-client-sse +``` + +You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. + + diff --git a/examples/simple-mcp-client-streamable-http/src/handler.rs b/examples/simple-mcp-client-streamable-http/src/handler.rs new file mode 100644 index 0000000..19360f6 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/handler.rs @@ -0,0 +1,10 @@ +use async_trait::async_trait; +use rust_mcp_sdk::mcp_client::ClientHandler; + +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at + // https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} diff --git a/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs new file mode 100644 index 0000000..a8e7c9c --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs @@ -0,0 +1,222 @@ +//! This module contains utility functions for querying and displaying server capabilities. + +use colored::Colorize; +use rust_mcp_sdk::schema::CallToolRequestParams; +use rust_mcp_sdk::McpClient; +use rust_mcp_sdk::{error::SdkResult, mcp_client::ClientRuntime}; +use serde_json::json; +use std::io::Write; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; + +const GREY_COLOR: (u8, u8, u8) = (90, 90, 90); +const HEADER_SIZE: usize = 31; + +pub struct InquiryUtils { + pub client: Arc, +} + +impl InquiryUtils { + fn print_header(&self, title: &str) { + let pad = ((HEADER_SIZE as f32 / 2.0) + (title.len() as f32 / 2.0)).floor() as usize; + println!("\n{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + println!("{:>pad$}", title.custom_color(GREY_COLOR)); + println!("{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + } + + fn print_list(&self, list_items: Vec<(String, String)>) { + list_items.iter().enumerate().for_each(|(index, item)| { + println!("{}. {}: {}", index + 1, item.0.yellow(), item.1.cyan(),); + }); + } + + pub fn print_server_info(&self) { + self.print_header("Server info"); + let server_version = self.client.server_version().unwrap(); + println!("{} {}", "Server name:".bold(), server_version.name.cyan()); + println!( + "{} {}", + "Server version:".bold(), + server_version.version.cyan() + ); + } + + pub fn print_server_capabilities(&self) { + self.print_header("Capabilities"); + let capability_vec = [ + ("tools", self.client.server_has_tools()), + ("prompts", self.client.server_has_prompts()), + ("resources", self.client.server_has_resources()), + ("logging", self.client.server_supports_logging()), + ("experimental", self.client.server_has_experimental()), + ]; + + capability_vec.iter().for_each(|(tool_name, opt)| { + println!( + "{}: {}", + tool_name.bold(), + opt.map(|b| if b { "Yes" } else { "No" }) + .unwrap_or("Unknown") + .cyan() + ); + }); + } + + pub async fn print_tool_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support tools + if !self.client.server_has_tools().unwrap_or(false) { + return Ok(()); + } + + let tools = self.client.list_tools(None).await?; + self.print_header("Tools"); + self.print_list( + tools + .tools + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_prompts_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support prompts + if !self.client.server_has_prompts().unwrap_or(false) { + return Ok(()); + } + + let prompts = self.client.list_prompts(None).await?; + + self.print_header("Prompts"); + self.print_list( + prompts + .prompts + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn print_resource_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let resources = self.client.list_resources(None).await?; + + self.print_header("Resources"); + + self.print_list( + resources + .resources + .iter() + .map(|item| { + ( + item.name.clone(), + format!( + "( uri: {} , mime: {}", + item.uri, + item.mime_type.as_ref().unwrap_or(&"?".to_string()), + ), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_resource_templates(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let templates = self.client.list_resource_templates(None).await?; + + self.print_header("Resource Templates"); + + self.print_list( + templates + .resource_templates + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn call_add_tool(&self, a: i64, b: i64) -> SdkResult<()> { + // Invoke the "add" tool with 100 and 25 as arguments, and display the result + println!( + "{}", + format!("\nCalling the \"add\" tool with {a} and {b} ...").magenta() + ); + + // Create a `Map` to represent the tool parameters + let params = json!({ + "a": a, + "b": b + }) + .as_object() + .unwrap() + .clone(); + + // invoke the tool + let result = self + .client + .call_tool(CallToolRequestParams { + name: "add".to_string(), + arguments: Some(params), + }) + .await?; + + // Retrieve the result content and print it to the stdout + let result_content = result.content.first().unwrap().as_text_content()?; + println!("{}", result_content.text.green()); + + Ok(()) + } + + pub async fn ping_n_times(&self, n: i32) { + let max_pings = n; + println!(); + for ping_index in 1..=max_pings { + print!("Ping the server ({ping_index} out of {max_pings})..."); + std::io::stdout().flush().unwrap(); + let ping_result = self.client.ping(None).await; + print!( + "\rPing the server ({} out of {}) : {}", + ping_index, + max_pings, + if ping_result.is_ok() { + "success".bright_green() + } else { + "failed".bright_red() + } + ); + println!(); + sleep(Duration::from_secs(2)).await; + } + } +} diff --git a/examples/simple-mcp-client-streamable-http/src/main.rs b/examples/simple-mcp-client-streamable-http/src/main.rs new file mode 100644 index 0000000..ab580db --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/main.rs @@ -0,0 +1,99 @@ +mod handler; +mod inquiry_utils; + +use handler::MyClientHandler; + +use rust_mcp_sdk::error::SdkResult; +use rust_mcp_sdk::mcp_client::client_runtime; +use rust_mcp_sdk::schema::{ + ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, + LATEST_PROTOCOL_VERSION, +}; +use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +use crate::inquiry_utils::InquiryUtils; + +const MCP_SERVER_URL: &str = "http://127.0.0.1:8080/mcp"; + +#[tokio::main] +async fn main() -> SdkResult<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Step1 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 3: instantiate our custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 4: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 5: start the MCP client + client.clone().start().await?; + + // You can utilize the client and its methods to interact with the MCP Server. + // The following demonstrates how to use client methods to retrieve server information, + // and print them in the terminal, set the log level, invoke a tool, and more. + + // Create a struct with utility functions for demonstration purpose, to utilize different client methods and display the information. + let utils = InquiryUtils { + client: Arc::clone(&client), + }; + + // Display server information (name and version) + utils.print_server_info(); + + // Display server capabilities + utils.print_server_capabilities(); + + // Display the list of tools available on the server + utils.print_tool_list().await?; + + // Display the list of prompts available on the server + utils.print_prompts_list().await?; + + // Display the list of resources available on the server + utils.print_resource_list().await?; + + // Display the list of resource templates available on the server + utils.print_resource_templates().await?; + + // Call add tool, and print the result + utils.call_add_tool(100, 25).await?; + + // Set the log level + match utils.client.set_logging_level(LoggingLevel::Debug).await { + Ok(_) => println!("Log level is set to \"Debug\""), + Err(err) => eprintln!("Error setting the Log level : {err}"), + } + + // Send 3 pings to the server, with a 2-second interval between each ping. + utils.ping_n_times(3).await; + client.shut_down().await?; + + Ok(()) +} From 39be611055bbe1dd95ecb2b25eb3b4878dab1cb4 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 13 Sep 2025 12:28:41 -0300 Subject: [PATCH 2/4] chore: typos --- .../src/mcp_runtimes/client_runtime.rs | 6 +++--- crates/rust-mcp-sdk/tests/common/test_client.rs | 2 +- .../tests/test_streamable_http_client.rs | 14 +++++++------- crates/rust-mcp-transport/src/utils/sse_parser.rs | 2 +- .../src/utils/streamable_http_stream.rs | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 9961b84..2093dc3 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -454,9 +454,9 @@ impl ClientRuntime { let result = transport.send_message(messages, timeout).await?; if no_session_id { - if let Some(resquest_id) = transport.session_id().await.clone() { + if let Some(request_id) = transport.session_id().await.clone() { let mut guard = self.session_id.write().await; - *guard = Some(resquest_id) + *guard = Some(request_id) } } @@ -515,7 +515,7 @@ impl ClientRuntime { // Run both tasks with cancellation logic let (send_res, _) = tokio::select! { res = &mut send_task => { - // cancel the receive_task task, to cover the case where sned_task returns with error + // cancel the receive_task task, to cover the case where send_task returns with error abort_recv_handle.abort(); (res, receive_task.await) // Wait for receive_task to finish (it should exit due to cancellation) } diff --git a/crates/rust-mcp-sdk/tests/common/test_client.rs b/crates/rust-mcp-sdk/tests/common/test_client.rs index 21678c7..46a8525 100644 --- a/crates/rust-mcp-sdk/tests/common/test_client.rs +++ b/crates/rust-mcp-sdk/tests/common/test_client.rs @@ -89,7 +89,7 @@ pub mod test_client_common { ) -> InitializedClient { let mock_server = MockServer::start().await; - // intialize response + // initialize response let mut response = create_sse_response(INITIALIZE_RESPONSE); if let Some(session_id) = session_id { diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs index a0a2804..cb82ff5 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -39,7 +39,7 @@ async fn should_send_json_rpc_messages_via_post() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE); // initialize request and response @@ -137,7 +137,7 @@ async fn should_store_session_id_received_during_initialization() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE).append_header("mcp-session-id", "test-session-id"); @@ -283,7 +283,7 @@ async fn should_handle_successful_initial_get_connection_for_sse() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE); // initialize request and response @@ -394,7 +394,7 @@ async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE); // initialize request and response @@ -445,7 +445,7 @@ async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { assert!(get_request.is_some()); - // send a batch message, runtime should work as expected with no isse + // send a batch message, runtime should work as expected with no issue let response = create_sse_response( r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, @@ -616,7 +616,7 @@ async fn should_reconnect_a_get_initiated_notification_stream_that_fails() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE); // initialize request and response @@ -726,7 +726,7 @@ async fn should_pass_last_event_id_when_reconnecting() { assert!(get_requests.len() > 1); let Some(last_get_request) = get_requests.last() else { - panic!("Unable to find last GET reuest!"); + panic!("Unable to find last GET request!"); }; let last_event_id = last_get_request diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs index 064d3c3..5933726 100644 --- a/crates/rust-mcp-transport/src/utils/sse_parser.rs +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -62,7 +62,7 @@ impl fmt::Debug for SseEvent { } /// A parser for Server-Sent Events (SSE) that processes incoming byte chunks into `SseEvent`s. -/// This Parser is specificly designed for MCP messages and with no multi-line data support +/// This Parser is specifically designed for MCP messages and with no multi-line data support /// /// This struct maintains a buffer to accumulate incoming data and parses it into SSE events /// based on the SSE protocol. It handles fields like `event`, `data`, and `id` as defined diff --git a/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs index ae9c69c..3362c71 100644 --- a/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs +++ b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs @@ -130,7 +130,7 @@ impl StreamableHttpStream { match chunk { Some(chunk) => chunk, None => { - // stream ended, unline SSE, so no retry attempt here needed to reconnect + // stream ended, unlike SSE, so no retry attempt here needed to reconnect return Err(TransportError::Internal("Stream has ended.".to_string())); } } @@ -315,7 +315,7 @@ impl StreamableHttpStream { match chunk { Some(chunk) => chunk, None => { - // stream ended, unline SSE, so no retry attempt here needed to reconnect + // stream ended, unlike SSE, so no retry attempt here needed to reconnect return Err(TransportError::Internal("Stream has ended.".to_string())); } } From 6204d79232ad876a9050e42028b2b5b38ded7279 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 13 Sep 2025 14:35:32 -0300 Subject: [PATCH 3/4] chore: update readme --- README.md | 145 +++++++++++++++++++++++++--------- crates/rust-mcp-sdk/README.md | 142 +++++++++++++++++++++++++-------- 2 files changed, 215 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index b1af670..ced3672 100644 --- a/README.md +++ b/README.md @@ -32,27 +32,15 @@ This project supports following transports: 🚀 The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - **MCP Streamable HTTP Support** - ✅ Streamable HTTP Support for MCP Servers - ✅ DNS Rebinding Protection - ✅ Batch Messages - ✅ Streaming & non-streaming JSON response -- ⬜ Streamable HTTP Support for MCP Clients +- ✅ Streamable HTTP Support for MCP Clients - ⬜ Resumability - ⬜ Authentication / Oauth - - -**MCP Streamable HTTP Support** -- [x] Streamable HTTP Support for MCP Servers -- [x] DNS Rebinding Protection -- [x] Batch Messages -- [x] Streaming & non-streaming JSON response -- [ ] Streamable HTTP Support for MCP Clients -- [ ] Resumability -- [ ] Authentication / Oauth - **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents @@ -60,6 +48,7 @@ This project supports following transports: - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) + - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) @@ -202,7 +191,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) @@ -294,6 +283,8 @@ async fn main() -> SdkResult<()> { println!("{}",result.content.first().unwrap().as_text_content()?.text); + client.shut_down().await?; + Ok(()) } @@ -305,8 +296,82 @@ Here is the output : > your results may vary slightly depending on the version of the MCP Server in use when you run it. +### MCP Client (Streamable HTTP) +```rs + +// STEP 1: Custom Handler to handle incoming MCP Messages +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + + // Step2 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 4: instantiate the custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 5: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 6: start the MCP client + client.clone().start().await?; + + // STEP 7: use client methods to communicate with the MCP Server as you wish + + // Retrieve and display the list of tools available on the server + let server_version = client.server_version().unwrap(); + let tools = client.list_tools(None).await?.tools; + println!("List of tools for {}@{}", server_version.name, server_version.version); + + tools.iter().enumerate().for_each(|(tool_index, tool)| { + println!(" {}. {} : {}", + tool_index + 1, + tool.name, + tool.description.clone().unwrap_or_default() + ); + }); + + println!("Call \"add\" tool with 100 and 28 ..."); + // Create a `Map` to represent the tool parameters + let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); + let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; + + // invoke the tool + let result = client.call_tool(request).await?; + + println!("{}",result.content.first().unwrap().as_text_content()?.text); + + client.shut_down().await?; + + Ok(()) +``` +👉 see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. + + ### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical, with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: +Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: ```diff - let transport = StdioTransport::create_with_server_launch( @@ -317,6 +382,8 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost + let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; ``` +👉 see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. + ## Getting Started @@ -355,9 +422,15 @@ pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>>, + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, + /// Shared transport configuration used by the server + pub transport_options: Arc, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -367,12 +440,6 @@ pub struct HyperServerOptions { /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, - /// Enables SSL/TLS if set to `true` pub enable_ssl: bool, @@ -384,17 +451,6 @@ pub struct HyperServerOptions { /// Required if `enable_ssl` is `true`. pub ssl_key_path: Option, - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - /// List of allowed host header values for DNS rebinding protection. /// If not specified, host validation is disabled. pub allowed_hosts: Option>, @@ -406,6 +462,17 @@ pub struct HyperServerOptions { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, } ``` @@ -427,9 +494,13 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `server`: Activates MCP server capabilities in `rust-mcp-sdk`, providing modules and APIs for building and managing MCP servers. - `client`: Activates MCP client capabilities, offering modules and APIs for client development and communicating with MCP servers. -- `hyper-server`: This feature enables the **sse** transport for MCP servers, supporting multiple simultaneous client connections out of the box. -- `ssl`: This feature enables TLS/SSL support for the **sse** transport when used with the `hyper-server`. +- `hyper-server`: This feature is necessary to enable `Streamable HTTP` or `Server-Sent Events (SSE)` transports for MCP servers. It must be used alongside the server feature to support the required server functionalities. +- `ssl`: This feature enables TLS/SSL support for the `Streamable HTTP` or `Server-Sent Events (SSE)` transport when used with the `hyper-server`. - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. +- `sse`: Enables support for the `Server-Sent Events (SSE)` transport. +- `streamable-http`: Enables support for the `Streamable HTTP` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport.. + #### MCP Protocol Versions with Corresponding Features @@ -460,9 +531,9 @@ If you only need the MCP Server functionality, you can disable the default featu ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros","stdio"] } ``` -Optionally add `hyper-server` for **sse** transport, and `ssl` feature for tls/ssl support of the `hyper-server` +Optionally add `hyper-server` and `streamable-http` for **Streamable HTTP** transport, and `ssl` feature for tls/ssl support of the `hyper-server` @@ -475,7 +546,7 @@ Add the following to your Cargo.toml: ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05","stdio"] } ``` diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 9df027d..ced3672 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -37,21 +37,10 @@ This project supports following transports: - ✅ DNS Rebinding Protection - ✅ Batch Messages - ✅ Streaming & non-streaming JSON response -- ⬜ Streamable HTTP Support for MCP Clients +- ✅ Streamable HTTP Support for MCP Clients - ⬜ Resumability - ⬜ Authentication / Oauth - - -**MCP Streamable HTTP Support** -- [x] Streamable HTTP Support for MCP Servers -- [x] DNS Rebinding Protection -- [x] Batch Messages -- [x] Streaming & non-streaming JSON response -- [ ] Streamable HTTP Support for MCP Clients -- [ ] Resumability -- [ ] Authentication / Oauth - **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents @@ -59,6 +48,7 @@ This project supports following transports: - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) + - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) @@ -293,6 +283,8 @@ async fn main() -> SdkResult<()> { println!("{}",result.content.first().unwrap().as_text_content()?.text); + client.shut_down().await?; + Ok(()) } @@ -304,8 +296,82 @@ Here is the output : > your results may vary slightly depending on the version of the MCP Server in use when you run it. +### MCP Client (Streamable HTTP) +```rs + +// STEP 1: Custom Handler to handle incoming MCP Messages +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + + // Step2 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 4: instantiate the custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 5: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 6: start the MCP client + client.clone().start().await?; + + // STEP 7: use client methods to communicate with the MCP Server as you wish + + // Retrieve and display the list of tools available on the server + let server_version = client.server_version().unwrap(); + let tools = client.list_tools(None).await?.tools; + println!("List of tools for {}@{}", server_version.name, server_version.version); + + tools.iter().enumerate().for_each(|(tool_index, tool)| { + println!(" {}. {} : {}", + tool_index + 1, + tool.name, + tool.description.clone().unwrap_or_default() + ); + }); + + println!("Call \"add\" tool with 100 and 28 ..."); + // Create a `Map` to represent the tool parameters + let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); + let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; + + // invoke the tool + let result = client.call_tool(request).await?; + + println!("{}",result.content.first().unwrap().as_text_content()?.text); + + client.shut_down().await?; + + Ok(()) +``` +👉 see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. + + ### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical, with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: +Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: ```diff - let transport = StdioTransport::create_with_server_launch( @@ -316,6 +382,8 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost + let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; ``` +👉 see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. + ## Getting Started @@ -354,9 +422,15 @@ pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>>, + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, + /// Shared transport configuration used by the server + pub transport_options: Arc, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -366,12 +440,6 @@ pub struct HyperServerOptions { /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, - /// Enables SSL/TLS if set to `true` pub enable_ssl: bool, @@ -383,17 +451,6 @@ pub struct HyperServerOptions { /// Required if `enable_ssl` is `true`. pub ssl_key_path: Option, - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - /// List of allowed host header values for DNS rebinding protection. /// If not specified, host validation is disabled. pub allowed_hosts: Option>, @@ -405,6 +462,17 @@ pub struct HyperServerOptions { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, } ``` @@ -426,9 +494,13 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `server`: Activates MCP server capabilities in `rust-mcp-sdk`, providing modules and APIs for building and managing MCP servers. - `client`: Activates MCP client capabilities, offering modules and APIs for client development and communicating with MCP servers. -- `hyper-server`: This feature enables the **sse** transport for MCP servers, supporting multiple simultaneous client connections out of the box. -- `ssl`: This feature enables TLS/SSL support for the **sse** transport when used with the `hyper-server`. +- `hyper-server`: This feature is necessary to enable `Streamable HTTP` or `Server-Sent Events (SSE)` transports for MCP servers. It must be used alongside the server feature to support the required server functionalities. +- `ssl`: This feature enables TLS/SSL support for the `Streamable HTTP` or `Server-Sent Events (SSE)` transport when used with the `hyper-server`. - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. +- `sse`: Enables support for the `Server-Sent Events (SSE)` transport. +- `streamable-http`: Enables support for the `Streamable HTTP` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport.. + #### MCP Protocol Versions with Corresponding Features @@ -459,9 +531,9 @@ If you only need the MCP Server functionality, you can disable the default featu ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros","stdio"] } ``` -Optionally add `hyper-server` for **sse** transport, and `ssl` feature for tls/ssl support of the `hyper-server` +Optionally add `hyper-server` and `streamable-http` for **Streamable HTTP** transport, and `ssl` feature for tls/ssl support of the `hyper-server` @@ -474,7 +546,7 @@ Add the following to your Cargo.toml: ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05","stdio"] } ``` From 5234d9cd7d657f3da9fec56cd9476d0091a82eaa Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Wed, 17 Sep 2025 20:04:30 -0300 Subject: [PATCH 4/4] merge main --- crates/rust-mcp-sdk/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index ced3672..8036022 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -499,8 +499,9 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. - `sse`: Enables support for the `Server-Sent Events (SSE)` transport. - `streamable-http`: Enables support for the `Streamable HTTP` transport. -- `stdio`: Enables support for the `standard input/output (stdio)` transport.. +- `stdio`: Enables support for the `standard input/output (stdio)` transport. +- `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. #### MCP Protocol Versions with Corresponding Features