diff --git a/Cargo.lock b/Cargo.lock index e14e129..3c18cf2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -392,6 +392,16 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "event-listener" version = "2.5.3" @@ -407,6 +417,12 @@ dependencies = [ "instant", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "find-msvc-tools" version = "0.1.4" @@ -498,7 +514,7 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" dependencies = [ - "fastrand", + "fastrand 1.9.0", "futures-core", "futures-io", "memchr", @@ -1085,6 +1101,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.1" @@ -1674,6 +1696,7 @@ dependencies = [ "rust-mcp-transport", "serde", "serde_json", + "tempfile", "thiserror 2.0.17", "tokio", "tokio-stream", @@ -1707,6 +1730,19 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustls" version = "0.23.35" @@ -2039,6 +2075,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand 2.3.0", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 6d105d5..4f4238a 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -47,6 +47,7 @@ reqwest = { workspace = true, default-features = false, features = [ "cookies", "multipart", ] } +tempfile = "3.23.0" tracing-subscriber = { workspace = true, features = [ "env-filter", "std", diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 74a3d77..0b30da6 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -200,7 +200,7 @@ impl HyperServerOptions { } pub fn streamable_http_endpoint(&self) -> &str { - self.custom_messages_endpoint + self.custom_streamable_http_endpoint .as_deref() .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT) } @@ -490,3 +490,170 @@ async fn shutdown_signal(handle: Handle, state: Arc) { // Trigger graceful shutdown with a timeout handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS))); } + +#[cfg(test)] +mod tests { + use super::*; + + use tempfile::NamedTempFile; + + #[test] + fn test_server_options_base_url_custom() { + let options = HyperServerOptions { + host: String::from("127.0.0.1"), + port: 8081, + enable_ssl: true, + ..Default::default() + }; + assert_eq!(options.base_url(), "https://127.0.0.1:8081"); + } + + #[test] + fn test_server_options_streamable_http_custom() { + let options = HyperServerOptions { + custom_streamable_http_endpoint: Some(String::from("/abcd/mcp")), + host: String::from("127.0.0.1"), + port: 8081, + enable_ssl: true, + ..Default::default() + }; + assert_eq!( + options.streamable_http_url(), + "https://127.0.0.1:8081/abcd/mcp" + ); + assert_eq!(options.streamable_http_endpoint(), "/abcd/mcp"); + } + + #[test] + fn test_server_options_sse_custom() { + let options = HyperServerOptions { + custom_sse_endpoint: Some(String::from("/abcd/sse")), + host: String::from("127.0.0.1"), + port: 8081, + enable_ssl: true, + ..Default::default() + }; + assert_eq!(options.sse_url(), "https://127.0.0.1:8081/abcd/sse"); + assert_eq!(options.sse_endpoint(), "/abcd/sse"); + } + + #[test] + fn test_server_options_sse_messages_custom() { + let options = HyperServerOptions { + custom_messages_endpoint: Some(String::from("/abcd/messages")), + ..Default::default() + }; + assert_eq!( + options.sse_message_url(), + "http://127.0.0.1:8080/abcd/messages" + ); + assert_eq!(options.sse_messages_endpoint(), "/abcd/messages"); + } + + #[test] + fn test_server_options_needs_dns_protection() { + let options = HyperServerOptions::default(); + + // should be false by default + assert!(!options.needs_dns_protection()); + + // should still be false unless allowed_hosts or allowed_origins are also provided + let options = HyperServerOptions { + dns_rebinding_protection: true, + ..Default::default() + }; + assert!(!options.needs_dns_protection()); + + // should be true when dns_rebinding_protection is true and allowed_hosts is provided + let options = HyperServerOptions { + dns_rebinding_protection: true, + allowed_hosts: Some(vec![String::from("127.0.0.1")]), + ..Default::default() + }; + assert!(options.needs_dns_protection()); + + // should be true when dns_rebinding_protection is true and allowed_origins is provided + let options = HyperServerOptions { + dns_rebinding_protection: true, + allowed_origins: Some(vec![String::from("http://127.0.0.1:8080")]), + ..Default::default() + }; + assert!(options.needs_dns_protection()); + } + + #[test] + fn test_server_options_validate() { + let options = HyperServerOptions::default(); + assert!(options.validate().is_ok()); + + // with ssl enabled but no cert or key provided, validate should fail + let options = HyperServerOptions { + enable_ssl: true, + ..Default::default() + }; + assert!(options.validate().is_err()); + + // with ssl enabled and invalid cert/key paths, validate should fail + let options = HyperServerOptions { + enable_ssl: true, + ssl_cert_path: Some(String::from("/invalid/path/to/cert.pem")), + ssl_key_path: Some(String::from("/invalid/path/to/key.pem")), + ..Default::default() + }; + assert!(options.validate().is_err()); + + // with ssl enabled and valid cert/key paths, validate should succeed + let cert_file = + NamedTempFile::with_suffix(".pem").expect("Expected to create test cert file"); + let ssl_cert_path = cert_file + .path() + .to_str() + .expect("Expected to get cert path") + .to_string(); + let key_file = + NamedTempFile::with_suffix(".pem").expect("Expected to create test key file"); + let ssl_key_path = key_file + .path() + .to_str() + .expect("Expected to get key path") + .to_string(); + + let options = HyperServerOptions { + enable_ssl: true, + ssl_cert_path: Some(ssl_cert_path), + ssl_key_path: Some(ssl_key_path), + ..Default::default() + }; + assert!(options.validate().is_ok()); + } + + #[tokio::test] + async fn test_server_options_resolve_server_address() { + let options = HyperServerOptions::default(); + assert!(options.resolve_server_address().await.is_ok()); + + // valid host should still work + let options = HyperServerOptions { + host: String::from("8.6.7.5"), + port: 309, + ..Default::default() + }; + assert!(options.resolve_server_address().await.is_ok()); + + // valid host (prepended with http://) should still work + let options = HyperServerOptions { + host: String::from("http://8.6.7.5"), + port: 309, + ..Default::default() + }; + assert!(options.resolve_server_address().await.is_ok()); + + // invalid host should raise an error + let options = HyperServerOptions { + host: String::from("invalid-host"), + port: 309, + ..Default::default() + }; + assert!(options.resolve_server_address().await.is_err()); + } +}