diff --git a/Cargo.lock b/Cargo.lock index b434baefb901..799ae1710554 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3444,6 +3444,7 @@ dependencies = [ "mcp-server", "reqwest 0.12.12", "rmcp", + "schemars", "serde", "serde_json", "serde_yaml", @@ -6166,6 +6167,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "process-wrap" +version = "8.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d35f4dc9988d1326b065b4def5e950c3ed727aa03e3151b86cc9e2aec6b03f54" +dependencies = [ + "futures", + "indexmap 2.7.1", + "nix 0.29.0", + "tokio", + "tracing", + "windows 0.59.0", +] + [[package]] name = "profiling" version = "1.0.16" @@ -6803,14 +6818,19 @@ dependencies = [ "base64 0.22.1", "chrono", "futures", + "http 1.2.0", "paste", "pin-project-lite", + "process-wrap", + "reqwest 0.12.12", "rmcp-macros", "schemars", "serde", "serde_json", + "sse-stream", "thiserror 2.0.12", "tokio", + "tokio-stream", "tokio-util", "tracing", ] @@ -7253,6 +7273,7 @@ version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ + "indexmap 2.7.1", "itoa", "memchr", "ryu", @@ -7572,6 +7593,19 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "sse-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" +dependencies = [ + "bytes", + "futures-util", + "http-body 1.0.1", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -9020,6 +9054,16 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f919aee0a93304be7f62e8e5027811bbba96bcb1de84d6618be56e43f8a32a1" +dependencies = [ + "windows-core 0.59.0", + "windows-targets 0.53.3", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -9050,10 +9094,23 @@ dependencies = [ "windows-implement 0.58.0", "windows-interface 0.58.0", "windows-result 0.2.0", - "windows-strings", + "windows-strings 0.1.0", "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "810ce18ed2112484b0d4e15d022e5f598113e220c53e373fb31e67e21670c1ce" +dependencies = [ + "windows-implement 0.59.0", + "windows-interface 0.59.1", + "windows-result 0.3.4", + "windows-strings 0.3.1", + "windows-targets 0.53.3", +] + [[package]] name = "windows-implement" version = "0.57.0" @@ -9076,6 +9133,17 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "windows-implement" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83577b051e2f49a058c308f17f273b570a6a758386fc291b5f6a934dd84e48c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + [[package]] name = "windows-interface" version = "0.57.0" @@ -9098,6 +9166,23 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + [[package]] name = "windows-registry" version = "0.2.0" @@ -9105,7 +9190,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" dependencies = [ "windows-result 0.2.0", - "windows-strings", + "windows-strings 0.1.0", "windows-targets 0.52.6", ] @@ -9127,6 +9212,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-strings" version = "0.1.0" @@ -9137,6 +9231,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-strings" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.45.0" @@ -9212,13 +9315,30 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", + "windows_i686_gnullvm 0.52.6", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-targets" +version = "0.53.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.2" @@ -9237,6 +9357,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" @@ -9255,6 +9381,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.42.2" @@ -9273,12 +9405,24 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.42.2" @@ -9297,6 +9441,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" @@ -9315,6 +9465,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" @@ -9333,6 +9489,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" @@ -9351,6 +9513,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" version = "0.7.3" diff --git a/crates/goose-cli/src/scenario_tests/mock_client.rs b/crates/goose-cli/src/scenario_tests/mock_client.rs index cdc35ef8e9ff..006795ce9d94 100644 --- a/crates/goose-cli/src/scenario_tests/mock_client.rs +++ b/crates/goose-cli/src/scenario_tests/mock_client.rs @@ -1,13 +1,15 @@ //! MockClient is a mock implementation of the McpClientTrait for testing purposes. //! add a tool you want to have around and then add the client to the extension router -use mcp_client::client::{ClientCapabilities, ClientInfo, Error, McpClientTrait}; -use mcp_core::protocol::{ - CallToolResult, Implementation, InitializeResult, ListPromptsResult, ListResourcesResult, - ListToolsResult, ReadResourceResult, ServerCapabilities, ToolsCapability, +use mcp_client::client::{Error, McpClientTrait}; +use mcp_core::ToolError; +use rmcp::{ + model::{ + CallToolResult, Content, GetPromptResult, ListPromptsResult, ListResourcesResult, + ListToolsResult, ReadResourceResult, ServerNotification, Tool, + }, + object, }; -use mcp_core::{Tool, ToolError}; -use rmcp::model::{Content, GetPromptResult, ServerNotification}; use serde_json::Value; use std::collections::HashMap; use tokio::sync::mpsc::{self, Receiver}; @@ -38,26 +40,6 @@ impl MockClient { #[async_trait::async_trait] impl McpClientTrait for MockClient { - async fn initialize( - &mut self, - _: ClientInfo, - _: ClientCapabilities, - ) -> Result { - Ok(InitializeResult { - protocol_version: "2024-11-05".to_string(), - capabilities: ServerCapabilities { - prompts: None, - resources: None, - tools: Some(ToolsCapability { list_changed: None }), - }, - server_info: Implementation { - name: "MockClient".to_string(), - version: "1.0.0".to_string(), - }, - instructions: None, - }) - } - async fn list_resources( &self, _next_cursor: Option, @@ -68,10 +50,12 @@ impl McpClientTrait for MockClient { }) } + fn get_info(&self) -> std::option::Option<&rmcp::model::InitializeResult> { + todo!() + } + async fn read_resource(&self, _uri: &str) -> Result { - Err(Error::UnexpectedResponse( - "Resources not supported by mock client".to_string(), - )) + Err(Error::UnexpectedResponse) } async fn list_tools(&self, _: Option) -> Result { @@ -79,16 +63,10 @@ impl McpClientTrait for MockClient { .tools .values() .map(|tool| { - let input_schema = if let serde_json::Value::Object(obj) = &tool.input_schema { - std::sync::Arc::new(obj.clone()) - } else { - std::sync::Arc::new(serde_json::Map::new()) - }; - rmcp::model::Tool::new( tool.name.to_string(), - tool.description.to_string(), - input_schema, + tool.description.clone().unwrap_or_default(), + tool.input_schema.clone(), ) }) .collect(); @@ -106,24 +84,22 @@ impl McpClientTrait for MockClient { content, is_error: None, }), - Err(e) => Err(Error::UnexpectedResponse(e.to_string())), + Err(e) => Err(Error::UnexpectedResponse), } } else { - Err(Error::UnexpectedResponse(format!( - "Tool '{}' not found", - name - ))) + Err(Error::UnexpectedResponse) } } async fn list_prompts(&self, _next_cursor: Option) -> Result { - Ok(ListPromptsResult { prompts: vec![] }) + Ok(ListPromptsResult { + prompts: vec![], + next_cursor: None, + }) } async fn get_prompt(&self, _name: &str, _arguments: Value) -> Result { - Err(Error::UnexpectedResponse( - "Prompts not supported by mock client".to_string(), - )) + Err(Error::UnexpectedResponse) } async fn subscribe(&self) -> Receiver { @@ -137,7 +113,7 @@ pub fn weather_client() -> MockClient { let weather_tool = Tool::new( "get_weather", "Get the weather for a location", - serde_json::json!({ + object!({ "type": "object", "required": ["location"], "properties": { @@ -147,7 +123,6 @@ pub fn weather_client() -> MockClient { } } }), - None, // ToolAnnotations ); let mock_client = MockClient::new().add_tool(weather_tool, |args| { diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index f18b886c183c..ea619dfa69a3 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -1,5 +1,4 @@ use console::style; -use goose::agents::extension::ExtensionError; use goose::agents::types::RetryConfig; use goose::agents::Agent; use goose::config::{Config, ExtensionConfig, ExtensionConfigManager}; @@ -7,7 +6,6 @@ use goose::providers::create; use goose::recipe::{Response, SubRecipe}; use goose::session; use goose::session::Identifier; -use mcp_client::transport::Error as McpClientError; use rustyline::EditMode; use std::process; use std::sync::Arc; @@ -359,10 +357,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { for extension in extensions_to_run { if let Err(e) = agent.add_extension(extension.clone()).await { - let err = match e { - ExtensionError::Transport(McpClientError::StdioProcessError(inner)) => inner, - _ => e.to_string(), - }; + let err = e.to_string(); eprintln!( "{}", style(format!( diff --git a/crates/goose-server/Cargo.toml b/crates/goose-server/Cargo.toml index 483746c0b751..555d64f9a308 100644 --- a/crates/goose-server/Cargo.toml +++ b/crates/goose-server/Cargo.toml @@ -16,12 +16,13 @@ mcp-core = { path = "../mcp-core" } goose-mcp = { path = "../goose-mcp" } mcp-server = { path = "../mcp-server" } rmcp = { workspace = true } +schemars = "1.0" axum = { version = "0.8.1", features = ["ws", "macros"] } tokio = { version = "1.43", features = ["full"] } chrono = "0.4" tower-http = { version = "0.5", features = ["cors"] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0", features = ["preserve_order"] } futures = "0.3" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] } diff --git a/crates/goose-server/src/bin/generate_schema.rs b/crates/goose-server/src/bin/generate_schema.rs index 18de38e07695..8ddb867bef49 100644 --- a/crates/goose-server/src/bin/generate_schema.rs +++ b/crates/goose-server/src/bin/generate_schema.rs @@ -22,7 +22,7 @@ fn main() { fs::write(&output_path, &schema).unwrap(); eprintln!( "Successfully generated OpenAPI schema at {}", - output_path.display() + output_path.canonicalize().unwrap().display() ); // Output the schema to stdout for piping diff --git a/crates/goose-server/src/logging.rs b/crates/goose-server/src/logging.rs index 90db8f8a2369..7c21881038c6 100644 --- a/crates/goose-server/src/logging.rs +++ b/crates/goose-server/src/logging.rs @@ -66,6 +66,7 @@ pub fn setup_logging(name: Option<&str>) -> Result<()> { // Create console logging layer for development - INFO and above only let console_layer = fmt::layer() + .with_writer(std::io::stderr) .with_target(true) .with_level(true) .with_ansi(true) diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index ceb05db371f6..4fb67eb6dd93 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -17,7 +17,12 @@ reqwest = { version = "0.12.9", features = ["json", "rustls-tls-native-roots"], [dependencies] mcp-client = { path = "../mcp-client" } mcp-core = { path = "../mcp-core" } -rmcp = { workspace = true } +rmcp = { workspace = true, features = [ + "reqwest", + "transport-child-process", + "transport-sse-client", + "transport-streamable-http-client", +] } anyhow = "1.0" thiserror = "1.0" futures = "0.3" diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index fa431ed850a4..3550ac4e7086 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use mcp_client::client::Error as ClientError; use rmcp::model::Tool; +use rmcp::service::ClientInitializeError; use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::warn; @@ -11,25 +12,43 @@ use crate::config; use crate::config::extensions::name_to_key; use crate::config::permission::PermissionLevel; +#[derive(Error, Debug)] +#[error("process quit before initialization: stderr = {stderr}")] +pub struct ProcessExit { + stderr: String, + #[source] + source: ClientInitializeError, +} + +impl ProcessExit { + pub fn new(stderr: T, source: ClientInitializeError) -> Self + where + T: Into, + { + ProcessExit { + stderr: stderr.into(), + source, + } + } +} + /// Errors from Extension operation #[derive(Error, Debug)] pub enum ExtensionError { - #[error("Failed to start the MCP server from configuration `{0}` `{1}`")] - Initialization(Box, ClientError), - #[error("Failed a client call to an MCP server: {0}")] + #[error("failed a client call to an MCP server: {0}")] Client(#[from] ClientError), - #[error("User Message exceeded context-limit. History could not be truncated to accommodate.")] - ContextLimit, - #[error("Transport error: {0}")] - Transport(#[from] mcp_client::transport::Error), - #[error("Environment variable `{0}` is not allowed to be overridden.")] - InvalidEnvVar(String), - #[error("Error during extension setup: {0}")] + #[error("invalid config: {0}")] + ConfigError(String), + #[error("error during extension setup: {0}")] SetupError(String), - #[error("Join error occurred during task execution: {0}")] + #[error("join error occurred during task execution: {0}")] TaskJoinError(#[from] tokio::task::JoinError), #[error("IO error: {0}")] IoError(#[from] std::io::Error), + #[error("failed to initialize MCP client: {0}")] + InitializeError(#[from] ClientInitializeError), + #[error("{0}")] + ProcessExit(#[from] ProcessExit), } pub type ExtensionResult = Result; @@ -107,7 +126,10 @@ impl Envs { pub fn validate(&self) -> Result<(), Box> { for key in self.map.keys() { if Self::is_disallowed(key) { - return Err(Box::new(ExtensionError::InvalidEnvVar(key.clone()))); + return Err(Box::new(ExtensionError::ConfigError(format!( + "environment variable {} not allowed to be overwritten", + key + )))); } } Ok(()) diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 22d4c57c16ab..72328d2c86ed 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -1,13 +1,22 @@ use anyhow::Result; +use axum::http::{HeaderMap, HeaderName}; use chrono::{DateTime, TimeZone, Utc}; use futures::stream::{FuturesUnordered, StreamExt}; use futures::{future, FutureExt}; -use rmcp::model::GetPromptResult; +use mcp_core::{ToolCall, ToolError}; +use rmcp::service::ClientInitializeError; +use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; +use rmcp::transport::{ + ConfigureCommandExt, SseClientTransport, StreamableHttpClientTransport, TokioChildProcess, +}; use std::collections::{HashMap, HashSet}; +use std::process::Stdio; use std::sync::Arc; use std::sync::LazyLock; use std::time::Duration; use tempfile::tempdir; +use tokio::io::AsyncReadExt; +use tokio::process::Command; use tokio::sync::Mutex; use tokio::task; use tokio_stream::wrappers::ReceiverStream; @@ -15,13 +24,11 @@ use tracing::{error, warn}; use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo}; use super::tool_execution::ToolCallResult; -use crate::agents::extension::Envs; +use crate::agents::extension::{Envs, ProcessExit}; use crate::config::{Config, ExtensionConfigManager}; use crate::prompt_template; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -use mcp_client::transport::{SseTransport, StdioTransport, StreamableHttpTransport, Transport}; -use mcp_core::{ToolCall, ToolError}; -use rmcp::model::{Content, Prompt, Resource, ResourceContents, Tool}; +use mcp_client::client::{McpClient, McpClientTrait}; +use rmcp::model::{Content, GetPromptResult, Prompt, Resource, ResourceContents, Tool}; use serde_json::Value; // By default, we set it to Jan 1, 2020 if the resource does not have a timestamp @@ -167,7 +174,7 @@ impl ExtensionManager { error = %e, "Failed to fetch secret from config." ); - return Err(ExtensionError::SetupError(format!( + return Err(ExtensionError::ConfigError(format!( "Failed to fetch secret '{}' from config: {}", key, e ))); @@ -178,20 +185,19 @@ impl ExtensionManager { Ok(all_envs) } - let mut client: Box = match &config { - ExtensionConfig::Sse { - uri, - envs, - env_keys, - timeout, - .. - } => { - let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; - let transport = SseTransport::new(uri, all_envs); - let handle = transport.start().await?; + let client: Box = match &config { + ExtensionConfig::Sse { uri, timeout, .. } => { + let transport = SseClientTransport::start(uri.to_string()).await.map_err( + |transport_error| { + ClientInitializeError::transport::>( + transport_error, + "connect", + ) + }, + )?; Box::new( McpClient::connect( - handle, + transport, Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), @@ -201,25 +207,42 @@ impl ExtensionManager { } ExtensionConfig::StreamableHttp { uri, - envs, - env_keys, - headers, timeout, + headers, .. } => { - let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; - let transport = - StreamableHttpTransport::with_headers(uri, all_envs, headers.clone()); - let handle = transport.start().await?; - Box::new( - McpClient::connect( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ) - .await?, + let mut default_headers = HeaderMap::new(); + for (key, value) in headers { + default_headers.insert( + HeaderName::try_from(key).map_err(|_| { + ExtensionError::ConfigError(format!("invalid header: {}", key)) + })?, + value.parse().map_err(|_| { + ExtensionError::ConfigError(format!("invalid header value: {}", key)) + })?, + ); + } + let client = reqwest::Client::builder() + .default_headers(default_headers) + .build() + .map_err(|_| { + ExtensionError::ConfigError("could not construct http client".to_string()) + })?; + let transport = StreamableHttpClientTransport::with_client( + client, + StreamableHttpClientTransportConfig { + uri: uri.clone().into(), + ..Default::default() + }, + ); + let client = McpClient::connect( + transport, + Duration::from_secs( + timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), + ), ) + .await?; + Box::new(client) } ExtensionConfig::Stdio { cmd, @@ -230,17 +253,42 @@ impl ExtensionManager { .. } => { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; - let transport = StdioTransport::new(cmd, args.to_vec(), all_envs); - let handle = transport.start().await?; - Box::new( - McpClient::connect( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ) - .await?, + let command = Command::new(cmd).configure(|command| { + command.args(args).envs(all_envs); + }); + let (transport, mut stderr) = TokioChildProcess::builder(command) + .stderr(Stdio::piped()) + .spawn()?; + let mut stderr = stderr + .take() + .expect("should have a stderr handle because it was requested"); + + let stderr_task = tokio::spawn(async move { + let mut all_stderr = Vec::new(); + stderr.read_to_end(&mut all_stderr).await?; + Ok::(String::from_utf8_lossy(&all_stderr).into()) + }); + + let client_result = McpClient::connect( + transport, + Duration::from_secs( + timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), + ), ) + .await; + + let client = match client_result { + Ok(client) => Ok(client), + Err(error) => { + let error_task_out = stderr_task.await?; + Err::(match error_task_out { + Ok(stderr_content) => ProcessExit::new(stderr_content, error).into(), + Err(e) => e.into(), + }) + } + }?; + + Box::new(client) } ExtensionConfig::Builtin { name, @@ -254,15 +302,13 @@ impl ExtensionManager { .to_str() .expect("should resolve executable to string path") .to_string(); - let transport = StdioTransport::new( - &cmd, - vec!["mcp".to_string(), name.clone()], - HashMap::new(), - ); - let handle = transport.start().await?; + + let transport = TokioChildProcess::new(Command::new(cmd).configure(|command| { + command.arg("mcp").arg(name); + }))?; Box::new( McpClient::connect( - handle, + transport, Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), @@ -281,27 +327,20 @@ impl ExtensionManager { let file_path = temp_dir.path().join(format!("{}.py", name)); std::fs::write(&file_path, code)?; - let mut args = vec![]; - - let mut all_deps = vec!["mcp".to_string()]; + let command = Command::new("uvx").configure(|command| { + command.arg("--with").arg("mcp"); - if let Some(deps) = dependencies.as_ref() { - all_deps.extend(deps.iter().cloned()); - } - - for dep in all_deps { - args.push("--with".to_string()); - args.push(dep); - } + dependencies.iter().flatten().for_each(|dep| { + command.arg("--with").arg(dep); + }); - args.push("python".to_string()); - args.push(file_path.to_str().unwrap().to_string()); + command.arg("python").arg(file_path.to_str().unwrap()); + }); + let transport = TokioChildProcess::new(command)?; - let transport = StdioTransport::new("uvx", args, HashMap::new()); - let handle = transport.start().await?; let client = Box::new( McpClient::connect( - handle, + transport, Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), @@ -316,24 +355,13 @@ impl ExtensionManager { _ => unreachable!(), }; - // Initialize the client with default capabilities - let info = ClientInfo { - name: "goose".to_string(), - version: env!("CARGO_PKG_VERSION").to_string(), - }; - let capabilities = ClientCapabilities::default(); - - let init_result = client - .initialize(info, capabilities) - .await - .map_err(|e| ExtensionError::Initialization(Box::new(config.clone()), e))?; - - if let Some(instructions) = init_result.instructions { + let info = client.get_info(); + if let Some(instructions) = info.and_then(|info| info.instructions.as_ref()) { self.instructions - .insert(sanitized_name.clone(), instructions); + .insert(sanitized_name.clone(), instructions.clone()); } - if init_result.capabilities.resources.is_some() { + if let Some(_resources) = info.and_then(|info| info.capabilities.resources.as_ref()) { self.resource_capable_extensions .insert(sanitized_name.clone()); } @@ -431,18 +459,13 @@ impl ExtensionManager { let mut client_tools = client_guard.list_tools(None).await?; loop { - for client_tool in client_tools.tools { - let mut tool = Tool::new( - format!("{}__{}", name, client_tool.name), - client_tool.description.unwrap_or_default(), - client_tool.input_schema, - ); - - if tool.annotations.is_some() { - tool = tool.annotate(client_tool.annotations.unwrap()) - } - - tools.push(tool); + for tool in client_tools.tools { + tools.push(Tool { + name: format!("{}__{}", name, tool.name).into(), + description: tool.description, + input_schema: tool.input_schema, + annotations: tool.annotations, + }); } // Exit loop when there are no more pages @@ -885,11 +908,14 @@ mod tests { use super::*; use mcp_client::client::Error; use mcp_client::client::McpClientTrait; - use mcp_core::protocol::{ - CallToolResult, InitializeResult, ListPromptsResult, ListResourcesResult, ListToolsResult, - ReadResourceResult, - }; - use rmcp::model::{GetPromptResult, ServerNotification}; + use rmcp::model::CallToolResult; + use rmcp::model::InitializeResult; + + use rmcp::model::ListPromptsResult; + use rmcp::model::ListResourcesResult; + use rmcp::model::ListToolsResult; + use rmcp::model::ReadResourceResult; + use rmcp::model::ServerNotification; use serde_json::json; use tokio::sync::mpsc; @@ -897,27 +923,23 @@ mod tests { #[async_trait::async_trait] impl McpClientTrait for MockClient { - async fn initialize( - &mut self, - _info: ClientInfo, - _capabilities: ClientCapabilities, - ) -> Result { - Err(Error::NotInitialized) + fn get_info(&self) -> Option<&InitializeResult> { + None } async fn list_resources( &self, _next_cursor: Option, ) -> Result { - Err(Error::NotInitialized) + Err(Error::TransportClosed) } async fn read_resource(&self, _uri: &str) -> Result { - Err(Error::NotInitialized) + Err(Error::TransportClosed) } async fn list_tools(&self, _next_cursor: Option) -> Result { - Err(Error::NotInitialized) + Err(Error::TransportClosed) } async fn call_tool(&self, name: &str, _arguments: Value) -> Result { @@ -926,7 +948,7 @@ mod tests { content: vec![], is_error: None, }), - _ => Err(Error::NotInitialized), + _ => Err(Error::TransportClosed), } } @@ -934,7 +956,7 @@ mod tests { &self, _next_cursor: Option, ) -> Result { - Err(Error::NotInitialized) + Err(Error::TransportClosed) } async fn get_prompt( @@ -942,7 +964,7 @@ mod tests { _name: &str, _arguments: Value, ) -> Result { - Err(Error::NotInitialized) + Err(Error::TransportClosed) } async fn subscribe(&self) -> mpsc::Receiver { diff --git a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs index 18157eec73ea..da8ea0ca121b 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -97,11 +97,7 @@ impl TaskExecutionTracker { is_token_cancelled(&self.cancellation_token) } - fn log_notification_error( - &self, - error: &mpsc::error::TrySendError, - context: &str, - ) { + fn log_notification_error(&self, error: &mpsc::error::TrySendError, context: &str) { if !self.is_cancelled() { tracing::warn!("Failed to send {} notification: {}", context, error); } diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index 9be0d9b82a7d..f80a03ca355c 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use async_stream::try_stream; use futures::stream::{self, BoxStream}; use futures::{Stream, StreamExt}; -use rmcp::model::ServerNotification; use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; @@ -13,7 +12,7 @@ use crate::config::PermissionManager; use crate::message::{Message, ToolRequest}; use crate::permission::Permission; use mcp_core::ToolResult; -use rmcp::model::Content; +use rmcp::model::{Content, ServerNotification}; // ToolCallResult combines the result of a tool call with an optional notification stream that // can be used to receive notifications from the tool. diff --git a/crates/goose/src/providers/formats/google.rs b/crates/goose/src/providers/formats/google.rs index 50c45ac3c513..a487084f06b5 100644 --- a/crates/goose/src/providers/formats/google.rs +++ b/crates/goose/src/providers/formats/google.rs @@ -335,8 +335,7 @@ pub fn create_request( #[cfg(test)] mod tests { use super::*; - use rmcp::model::Content; - use rmcp::object; + use rmcp::{model::Content, object}; use serde_json::json; fn set_up_text_message(text: &str, role: Role) -> Message { @@ -680,18 +679,12 @@ mod tests { #[test] fn test_tools_to_google_spec_with_empty_properties() { - use rmcp::model::object; - use std::borrow::Cow; - use std::sync::Arc; - - let schema = json!({ - "properties": {} - }); - let tools = vec![Tool::new( - Cow::Borrowed("tool1"), - Cow::Borrowed("description1"), - Arc::new(object(schema)), + "tool1".to_string(), + "description1".to_string(), + object!({ + "properties": {} + }), )]; let result = format_tools(&tools); assert_eq!(result.len(), 1); diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 94cf802fd629..993bfef48b06 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -11,7 +11,7 @@ mcp-core = { path = "../mcp-core" } tokio = { version = "1", features = ["full"] } tokio-util = { version = "0.7", features = ["io"] } reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "rustls-tls-native-roots"] } -rmcp = { workspace = true } +rmcp = { workspace = true, features = ["client", "transport-child-process"]} eventsource-client = "0.12.0" futures = "0.3" serde = { version = "1.0", features = ["derive"] } diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs deleted file mode 100644 index e36abeba1ac4..000000000000 --- a/crates/mcp-client/examples/clients.rs +++ /dev/null @@ -1,127 +0,0 @@ -use mcp_client::{ - client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}, - transport::{SseTransport, StdioTransport, Transport}, -}; -use rand::Rng; -use rand::SeedableRng; -use std::time::Duration; -use std::{collections::HashMap, sync::Arc}; -use tracing_subscriber::EnvFilter; - -#[tokio::main] -async fn main() -> Result<(), Box> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env().add_directive("mcp_client=debug".parse().unwrap()), - ) - .init(); - - let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); - let handle1 = transport1.start().await?; - let client1 = McpClient::connect(handle1, Duration::from_secs(30)).await?; - - let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); - let handle2 = transport2.start().await?; - let client2 = McpClient::connect(handle2, Duration::from_secs(30)).await?; - - let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new()); - let handle3 = transport3.start().await?; - let client3 = McpClient::connect(handle3, Duration::from_secs(10)).await?; - - // Initialize both clients - let mut clients: Vec> = - vec![Box::new(client1), Box::new(client2), Box::new(client3)]; - - // Initialize all clients - for (i, client) in clients.iter_mut().enumerate() { - let info = ClientInfo { - name: format!("example-client-{}", i + 1), - version: "1.0.0".to_string(), - }; - let capabilities = ClientCapabilities::default(); - - println!("\nInitializing client {}", i + 1); - let init_result = client.initialize(info, capabilities).await?; - println!("Client {} initialized: {:?}", i + 1, init_result); - } - - // List tools for all clients - for (i, client) in clients.iter_mut().enumerate() { - let tools = client.list_tools(None).await?; - println!("\nClient {} tools: {:?}", i + 1, tools); - } - - println!("\n\n----------------------------------\n\n"); - - // Wrap clients in Arc before spawning tasks - let clients = Arc::new(clients); - let mut handles = vec![]; - - for i in 0..20 { - let clients = Arc::clone(&clients); - let handle = tokio::spawn(async move { - // let mut rng = rand::thread_rng(); - let mut rng = rand::rngs::StdRng::from_entropy(); - tokio::time::sleep(Duration::from_millis(rng.gen_range(5..50))).await; - - // Randomly select an operation - match rng.gen_range(0..4) { - 0 => { - println!("\n{i}: Listing tools for client 1 (stdio)"); - match clients[0].list_tools(None).await { - Ok(tools) => { - println!(" {i}: -> Got tools, first one: {:?}", tools.tools.first()) - } - Err(e) => println!(" {i}: -> Error: {}", e), - } - } - 1 => { - println!("\n{i}: Calling tool for client 2 (stdio)"); - match clients[1] - .call_tool("git_status", serde_json::json!({ "repo_path": "." })) - .await - { - Ok(result) => println!( - " {i}: -> Tool execution result, is_error: {:?}", - result.is_error - ), - Err(e) => println!(" {i}: -> Error: {}", e), - } - } - 2 => { - println!("\n{i}: Listing tools for client 3 (sse)"); - match clients[2].list_tools(None).await { - Ok(tools) => { - println!(" {i}: -> Got tools, first one: {:?}", tools.tools.first()) - } - Err(e) => println!(" {i}: -> Error: {}", e), - } - } - 3 => { - println!("\n{i}: Calling tool for client 3 (sse)"); - match clients[2] - .call_tool( - "echo_tool", - serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), - ) - .await - { - Ok(result) => println!(" {i}: -> Tool execution result, is_error: {:?}", result.is_error), - Err(e) => println!(" {i}: -> Error: {}", e), - } - } - _ => unreachable!(), - } - Ok::<(), Box>(()) - }); - handles.push(handle); - } - - // Wait for all tasks to complete - for handle in handles { - handle.await.unwrap().unwrap(); - } - - Ok(()) -} diff --git a/crates/mcp-client/examples/integration_test.rs b/crates/mcp-client/examples/integration_test.rs deleted file mode 100644 index d5de80abfc4f..000000000000 --- a/crates/mcp-client/examples/integration_test.rs +++ /dev/null @@ -1,167 +0,0 @@ -use anyhow::Result; -use futures::lock::Mutex; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -use mcp_client::transport::{SseTransport, StreamableHttpTransport, Transport}; -use mcp_client::StdioTransport; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; -use tracing_subscriber::EnvFilter; - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("eventsource_client=info".parse().unwrap()), - ) - .init(); - - test_transport(sse_transport().await?).await?; - test_transport(streamable_http_transport().await?).await?; - test_transport(stdio_transport().await?).await?; - - // Test broken transport - match test_transport(broken_stdio_transport().await?).await { - Ok(_) => panic!("Expected an error but got success"), - Err(e) => { - assert!(e - .to_string() - .contains("error: package(s) `thispackagedoesnotexist` not found in workspace")); - println!("Expected error occurred: {e}"); - } - } - - Ok(()) -} - -async fn sse_transport() -> Result { - let port = "60053"; - - tokio::process::Command::new("npx") - .env("PORT", port) - .arg("@modelcontextprotocol/server-everything") - .arg("sse") - .spawn()?; - tokio::time::sleep(Duration::from_secs(1)).await; - - Ok(SseTransport::new( - format!("http://localhost:{}/sse", port), - HashMap::new(), - )) -} - -async fn streamable_http_transport() -> Result { - let port = "60054"; - - tokio::process::Command::new("npx") - .env("PORT", port) - .arg("@modelcontextprotocol/server-everything") - .arg("streamable-http") - .spawn()?; - tokio::time::sleep(Duration::from_secs(1)).await; - - Ok(StreamableHttpTransport::new( - format!("http://localhost:{}/mcp", port), - HashMap::new(), - )) -} - -async fn stdio_transport() -> Result { - Ok(StdioTransport::new( - "npx", - vec!["@modelcontextprotocol/server-everything"] - .into_iter() - .map(|s| s.to_string()) - .collect(), - HashMap::new(), - )) -} - -async fn broken_stdio_transport() -> Result { - Ok(StdioTransport::new( - "cargo", - vec!["run", "-p", "thispackagedoesnotexist"] - .into_iter() - .map(|s| s.to_string()) - .collect(), - HashMap::new(), - )) -} - -async fn test_transport(transport: T) -> Result<()> -where - T: Transport + Send + 'static, -{ - // Start transport - let handle = transport.start().await?; - - // Create client - let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?; - println!("Client created\n"); - - let mut receiver = client.subscribe().await; - let events = Arc::new(Mutex::new(Vec::new())); - let events_clone = events.clone(); - tokio::spawn(async move { - while let Some(event) = receiver.recv().await { - println!("Received event: {event:?}"); - events_clone.lock().await.push(event); - } - }); - - // Initialize - let server_info = client - .initialize( - ClientInfo { - name: "test-client".into(), - version: "1.0.0".into(), - }, - ClientCapabilities::default(), - ) - .await?; - println!("Connected to server: {server_info:?}\n"); - - // Sleep for 100ms to allow the server to start - surprisingly this is required! - tokio::time::sleep(Duration::from_millis(500)).await; - - // List tools - let tools = client.list_tools(None).await?; - println!("Available tools: {tools:#?}\n"); - - // Call tool - let tool_result = client - .call_tool("echo", serde_json::json!({ "message": "honk" })) - .await?; - println!("Tool result: {tool_result:#?}\n"); - - let collected_eventes_before = events.lock().await.len(); - let n_steps = 5; - let long_op = client - .call_tool( - "longRunningOperation", - serde_json::json!({ "duration": 3, "steps": n_steps }), - ) - .await?; - println!("Long op result: {long_op:#?}\n"); - let collected_events_after = events.lock().await.len(); - assert_eq!(collected_events_after - collected_eventes_before, n_steps); - - let error_result = client - .call_tool("add", serde_json::json!({ "a": "foo", "b": "bar" })) - .await; - assert!(error_result.is_err()); - println!("Error result: {error_result:#?}\n"); - - // List resources - let resources = client.list_resources(None).await?; - println!("Resources: {resources:#?}\n"); - - // Read resource - let resource = client.read_resource("test://static/resource/1").await?; - println!("Resource: {resource:#?}\n"); - - Ok(()) -} diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs deleted file mode 100644 index 6e97a0a69f7d..000000000000 --- a/crates/mcp-client/examples/sse.rs +++ /dev/null @@ -1,66 +0,0 @@ -use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -use mcp_client::transport::{SseTransport, Transport}; -use std::collections::HashMap; -use std::time::Duration; -use tracing_subscriber::EnvFilter; - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("eventsource_client=info".parse().unwrap()), - ) - .init(); - - // Create the base transport - let transport = SseTransport::new("http://localhost:8000/sse", HashMap::new()); - - // Start transport - let handle = transport.start().await?; - - // Create client - let mut client = McpClient::connect(handle, Duration::from_secs(3)).await?; - println!("Client created\n"); - - // Initialize - let server_info = client - .initialize( - ClientInfo { - name: "test-client".into(), - version: "1.0.0".into(), - }, - ClientCapabilities::default(), - ) - .await?; - println!("Connected to server: {server_info:?}\n"); - - // Sleep for 100ms to allow the server to start - surprisingly this is required! - tokio::time::sleep(Duration::from_millis(500)).await; - - // List tools - let tools = client.list_tools(None).await?; - println!("Available tools: {tools:?}\n"); - - // Call tool - let tool_result = client - .call_tool( - "echo_tool", - serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), - ) - .await?; - println!("Tool result: {tool_result:?}\n"); - - // List resources - let resources = client.list_resources(None).await?; - println!("Resources: {resources:?}\n"); - - // Read resource - let resource = client.read_resource("echo://fixedresource").await?; - println!("Resource: {resource:?}\n"); - - Ok(()) -} diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs deleted file mode 100644 index 9879359791e9..000000000000 --- a/crates/mcp-client/examples/stdio.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::collections::HashMap; - -use anyhow::Result; -use mcp_client::{ - ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, - StdioTransport, Transport, -}; -use std::time::Duration; -use tracing_subscriber::EnvFilter; - -#[tokio::main] -async fn main() -> Result<(), ClientError> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("eventsource_client=debug".parse().unwrap()), - ) - .init(); - - // 1) Create the transport - let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); - - // 2) Start the transport to get a handle - let transport_handle = transport.start().await?; - - // 3) Create the client with the middleware-wrapped service - let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?; - - // Initialize - let server_info = client - .initialize( - ClientInfo { - name: "test-client".into(), - version: "1.0.0".into(), - }, - ClientCapabilities::default(), - ) - .await?; - println!("Connected to server: {server_info:?}\n"); - - // List tools - let tools = client.list_tools(None).await?; - println!("Available tools: {tools:?}\n"); - - // Call tool 'git_status' with arguments = {"repo_path": "."} - let tool_result = client - .call_tool("git_status", serde_json::json!({ "repo_path": "." })) - .await?; - println!("Tool result: {tool_result:?}\n"); - - // List resources - let resources = client.list_resources(None).await?; - println!("Available resources: {resources:?}\n"); - - Ok(()) -} diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs deleted file mode 100644 index 9b367d25c15b..000000000000 --- a/crates/mcp-client/examples/stdio_integration.rs +++ /dev/null @@ -1,93 +0,0 @@ -// This example shows how to use the mcp-client crate to interact with a server that has a simple counter tool. -// The server is started by running `cargo run -p mcp-server` in the root of the mcp-server crate. -use anyhow::Result; -use mcp_client::client::{ - ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, -}; -use mcp_client::transport::{StdioTransport, Transport}; -use std::collections::HashMap; -use std::time::Duration; -use tracing_subscriber::EnvFilter; - -#[tokio::main] -async fn main() -> Result<(), ClientError> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("eventsource_client=debug".parse().unwrap()), - ) - .init(); - - // Create the transport - let transport = StdioTransport::new( - "cargo", - vec!["run", "-p", "mcp-server"] - .into_iter() - .map(|s| s.to_string()) - .collect(), - HashMap::new(), - ); - - // Start the transport to get a handle - let transport_handle = transport.start().await.unwrap(); - - // Create client - let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?; - - // Initialize - let server_info = client - .initialize( - ClientInfo { - name: "test-client".into(), - version: "1.0.0".into(), - }, - ClientCapabilities::default(), - ) - .await?; - println!("Connected to server: {server_info:?}\n"); - - // List tools - let tools = client.list_tools(None).await?; - println!("Available tools: {tools:?}\n"); - - // Call tool 'increment' tool 3 times - for _ in 0..3 { - let increment_result = client.call_tool("increment", serde_json::json!({})).await?; - println!("Tool result for 'increment': {increment_result:?}\n"); - } - - // Call tool 'get_value' - let get_value_result = client.call_tool("get_value", serde_json::json!({})).await?; - println!("Tool result for 'get_value': {get_value_result:?}\n"); - - // Call tool 'decrement' once - let decrement_result = client.call_tool("decrement", serde_json::json!({})).await?; - println!("Tool result for 'decrement': {decrement_result:?}\n"); - - // Call tool 'get_value' - let get_value_result = client.call_tool("get_value", serde_json::json!({})).await?; - println!("Tool result for 'get_value': {get_value_result:?}\n"); - - // List resources - let resources = client.list_resources(None).await?; - println!("Resources: {resources:?}\n"); - - // Read resource - let resource = client.read_resource("memo://insights").await?; - println!("Resource: {resource:?}\n"); - - let prompts = client.list_prompts(None).await?; - println!("Prompts: {prompts:?}\n"); - - let prompt = client - .get_prompt( - "example_prompt", - serde_json::json!({"message": "hello there!"}), - ) - .await?; - println!("Prompt: {prompt:?}\n"); - - Ok(()) -} diff --git a/crates/mcp-client/examples/streamable_http.rs b/crates/mcp-client/examples/streamable_http.rs deleted file mode 100644 index 0fd856ba661b..000000000000 --- a/crates/mcp-client/examples/streamable_http.rs +++ /dev/null @@ -1,93 +0,0 @@ -use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -use mcp_client::transport::{StreamableHttpTransport, Transport}; -use std::collections::HashMap; -use std::time::Duration; -use tracing_subscriber::EnvFilter; - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("eventsource_client=info".parse().unwrap()), - ) - .init(); - - // Create example headers - let mut headers = HashMap::new(); - headers.insert("X-Custom-Header".to_string(), "example-value".to_string()); - headers.insert( - "User-Agent".to_string(), - "MCP-StreamableHttp-Client/1.0".to_string(), - ); - - // Create the Streamable HTTP transport with headers - let transport = - StreamableHttpTransport::with_headers("http://localhost:8000/mcp", HashMap::new(), headers); - - // Start transport - let handle = transport.start().await?; - - // Create client - let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?; - println!("Client created with Streamable HTTP transport\n"); - - // Initialize - let server_info = client - .initialize( - ClientInfo { - name: "streamable-http-client".into(), - version: "1.0.0".into(), - }, - ClientCapabilities::default(), - ) - .await?; - println!("Connected to server: {server_info:?}\n"); - - // Give the server a moment to fully initialize - tokio::time::sleep(Duration::from_millis(500)).await; - - // List tools - let tools = client.list_tools(None).await?; - println!("Available tools: {tools:?}\n"); - - // Call tool if available - if !tools.tools.is_empty() { - let tool_result = client - .call_tool( - &tools.tools[0].name, - serde_json::json!({ "message": "Hello from Streamable HTTP transport!" }), - ) - .await?; - println!("Tool result: {tool_result:?}\n"); - } - - // List resources - let resources = client.list_resources(None).await?; - println!("Resources: {resources:?}\n"); - - // Read resource if available - if !resources.resources.is_empty() { - let resource = client.read_resource(&resources.resources[0].uri).await?; - println!("Resource content: {resource:?}\n"); - } - - // List prompts - let prompts = client.list_prompts(None).await?; - println!("Available prompts: {prompts:?}\n"); - - // Get prompt if available - if !prompts.prompts.is_empty() { - let prompt_result = client - .get_prompt(&prompts.prompts[0].name, serde_json::json!({})) - .await?; - println!("Prompt result: {prompt_result:?}\n"); - } - - println!("Streamable HTTP transport example completed successfully!"); - - Ok(()) -} diff --git a/crates/mcp-client/examples/test_auth.rs b/crates/mcp-client/examples/test_auth.rs deleted file mode 100644 index b4159d41224f..000000000000 --- a/crates/mcp-client/examples/test_auth.rs +++ /dev/null @@ -1,64 +0,0 @@ -use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -use mcp_client::transport::{StreamableHttpTransport, Transport}; -use std::collections::HashMap; -use std::time::Duration; -use tracing_subscriber::EnvFilter; - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("eventsource_client=info".parse().unwrap()), - ) - .init(); - - println!("Testing Streamable HTTP transport with OAuth 2.0 authentication..."); - - // Create the Streamable HTTP transport for any MCP service that supports OAuth - // This example uses a hypothetical MCP endpoint - replace with actual service - let mcp_endpoint = - std::env::var("MCP_ENDPOINT").unwrap_or_else(|_| "https://example.com/mcp".to_string()); - - println!("Using MCP endpoint: {}", mcp_endpoint); - - let transport = StreamableHttpTransport::new(&mcp_endpoint, HashMap::new()); - - // Start transport - let handle = transport.start().await?; - - // Create client - let mut client = McpClient::connect(handle, Duration::from_secs(30)).await?; - println!("Client created with Streamable HTTP transport\n"); - - // Initialize - this will trigger the OAuth flow if authentication is needed - // The implementation now includes: - // - RFC 8707 Resource Parameter support for proper token audience binding - // - Proper OAuth 2.0 discovery with multiple fallback paths - // - Dynamic client registration (RFC 7591) - // - PKCE for security (RFC 7636) - // - MCP-Protocol-Version header as required by the specification - let server_info = client - .initialize( - ClientInfo { - name: "streamable-http-auth-test".into(), - version: "1.0.0".into(), - }, - ClientCapabilities::default(), - ) - .await?; - - println!("Connected to server: {server_info:?}\n"); - println!("OAuth 2.0 authentication test completed successfully!"); - println!("\nKey improvements implemented:"); - println!("✓ RFC 8707 Resource Parameter implementation"); - println!("✓ MCP-Protocol-Version header support"); - println!("✓ Enhanced OAuth discovery with multiple fallback paths"); - println!("✓ Proper canonical resource URI generation"); - println!("✓ Full compliance with MCP Authorization specification"); - - Ok(()) -} diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 2c2de1f483f8..e1253855162c 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,97 +1,30 @@ -use mcp_core::protocol::{ - CallToolResult, Implementation, InitializeResult, ListPromptsResult, ListResourcesResult, - ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, -}; -use rmcp::model::{ - GetPromptResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, - JsonRpcResponse, JsonRpcVersion2_0, Notification, NumberOrString, Request, RequestId, - ServerNotification, +use rmcp::{ + model::{ + CallToolRequest, CallToolRequestParam, CallToolResult, ClientCapabilities, ClientInfo, + ClientRequest, GetPromptRequest, GetPromptRequestParam, GetPromptResult, Implementation, + InitializeResult, ListPromptsRequest, ListPromptsResult, ListResourcesRequest, + ListResourcesResult, ListToolsRequest, ListToolsResult, LoggingMessageNotification, + LoggingMessageNotificationMethod, PaginatedRequestParam, ProgressNotification, + ProgressNotificationMethod, ProtocolVersion, ReadResourceRequest, ReadResourceRequestParam, + ReadResourceResult, ServerNotification, ServerResult, + }, + service::{ClientInitializeError, PeerRequestOptions, RunningService}, + transport::IntoTransport, + ClientHandler, RoleClient, ServiceError, ServiceExt, }; -use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; -use std::sync::{ - atomic::{AtomicU64, Ordering}, - Arc, +use serde_json::Value; +use std::sync::Arc; +use tokio::sync::{ + mpsc::{self, Sender}, + Mutex, }; -use thiserror::Error; -use tokio::sync::{mpsc, Mutex}; -use tower::{timeout::TimeoutLayer, Layer, Service, ServiceExt}; - -use crate::{McpService, TransportHandle}; pub type BoxError = Box; -/// Error type for MCP client operations. -#[derive(Debug, Error)] -pub enum Error { - #[error("Transport error: {0}")] - Transport(#[from] super::transport::Error), - - #[error("RPC error: code={code}, message={message}")] - RpcError { code: i32, message: String }, - - #[error("Serialization error: {0}")] - Serialization(#[from] serde_json::Error), - - #[error("Unexpected response from server: {0}")] - UnexpectedResponse(String), - - #[error("Not initialized")] - NotInitialized, - - #[error("Timeout or service not ready")] - NotReady, - - #[error("Request timed out")] - Timeout(#[from] tower::timeout::error::Elapsed), - - #[error("Error from mcp-server: {0}")] - ServerBoxError(BoxError), - - #[error("Call to '{server}' failed for '{method}'. {source}")] - McpServerError { - method: String, - server: String, - #[source] - source: BoxError, - }, -} - -// BoxError from mcp-server gets converted to our Error type -impl From for Error { - fn from(err: BoxError) -> Self { - Error::ServerBoxError(err) - } -} - -#[derive(Serialize, Deserialize)] -pub struct ClientInfo { - pub name: String, - pub version: String, -} - -#[derive(Serialize, Deserialize, Default)] -pub struct ClientCapabilities { - // Add fields as needed. For now, empty capabilities are fine. -} - -#[derive(Serialize, Deserialize)] -pub struct InitializeParams { - #[serde(rename = "protocolVersion")] - pub protocol_version: String, - pub capabilities: ClientCapabilities, - #[serde(rename = "clientInfo")] - pub client_info: ClientInfo, -} +pub type Error = rmcp::ServiceError; #[async_trait::async_trait] pub trait McpClientTrait: Send + Sync { - async fn initialize( - &mut self, - info: ClientInfo, - capabilities: ClientCapabilities, - ) -> Result; - async fn list_resources( &self, next_cursor: Option, @@ -108,347 +41,268 @@ pub trait McpClientTrait: Send + Sync { async fn get_prompt(&self, name: &str, arguments: Value) -> Result; async fn subscribe(&self) -> mpsc::Receiver; -} -/// The MCP client is the interface for MCP operations. -pub struct McpClient -where - T: TransportHandle + Send + Sync + 'static, -{ - service: Mutex>>, - next_id_counter: AtomicU64, // Added for atomic ID generation - server_capabilities: Option, - server_info: Option, - notification_subscribers: Arc>>>, + fn get_info(&self) -> Option<&InitializeResult>; } -impl McpClient -where - T: TransportHandle + Send + Sync + 'static, -{ - pub async fn connect(transport: T, timeout: std::time::Duration) -> Result { - let service = McpService::new(transport.clone()); - let service_ptr = service.clone(); - let notification_subscribers = - Arc::new(Mutex::new(Vec::>::new())); - let subscribers_ptr = notification_subscribers.clone(); - - tokio::spawn(async move { - loop { - match transport.receive().await { - Ok(message) => { - tracing::info!("Received message: {:?}", message); - match message { - JsonRpcMessage::Response(JsonRpcResponse { - id: NumberOrString::Number(id), - .. - }) - | JsonRpcMessage::Error(JsonRpcError { - id: NumberOrString::Number(id), - .. - }) => { - service_ptr.respond(&id.to_string(), Ok(message)).await; - } - JsonRpcMessage::Notification(JsonRpcNotification { - notification, - .. - }) => { - let mut subs = subscribers_ptr.lock().await; - if let Some(server_notification) = notification.into() { - subs.retain(|sub| { - sub.try_send(server_notification.clone()).is_ok() - }); - } - } - _ => { - tracing::warn!( - "Received unexpected received message type: {:?}", - message - ); - } - } - } - Err(e) => { - service_ptr.hangup(e).await; - subscribers_ptr.lock().await.clear(); - break; - } - } - } - }); - - let middleware = TimeoutLayer::new(timeout); +pub struct GooseClient { + notification_handlers: Arc>>>, +} - Ok(Self { - service: Mutex::new(middleware.layer(service)), - next_id_counter: AtomicU64::new(1), - server_capabilities: None, - server_info: None, - notification_subscribers, - }) +impl GooseClient { + pub fn new(handlers: Arc>>>) -> Self { + GooseClient { + notification_handlers: handlers, + } } +} - /// Send a JSON-RPC request and check we don't get an error response. - async fn send_request(&self, method: &str, params: Value) -> Result - where - R: for<'de> Deserialize<'de>, - { - let mut service = self.service.lock().await; - service.ready().await.map_err(|_| Error::NotReady)?; - let id_num = self.next_id_counter.fetch_add(1, Ordering::SeqCst); - let id = RequestId::Number(id_num as u32); - - let mut params = params.clone(); - params["_meta"] = json!({ - "progressToken": format!("prog-{}", id), - }); - - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: JsonRpcVersion2_0, - id, - request: Request { - method: method.to_string(), - params: params.as_object().unwrap().clone(), - extensions: Default::default(), - }, - }); +impl ClientHandler for GooseClient { + async fn on_progress( + &self, + params: rmcp::model::ProgressNotificationParam, + context: rmcp::service::NotificationContext, + ) -> () { + self.notification_handlers + .lock() + .await + .iter() + .for_each(|handler| { + let _ = handler.try_send(ServerNotification::ProgressNotification( + ProgressNotification { + params: params.clone(), + method: ProgressNotificationMethod, + extensions: context.extensions.clone(), + }, + )); + }); + } - let response_msg = service - .call(request) + async fn on_logging_message( + &self, + params: rmcp::model::LoggingMessageNotificationParam, + context: rmcp::service::NotificationContext, + ) -> () { + self.notification_handlers + .lock() .await - .map_err(|e| Error::McpServerError { - server: self - .server_info - .as_ref() - .map(|s| s.name.clone()) - .unwrap_or("".to_string()), - method: method.to_string(), - // we don't need include params because it can be really large - source: Box::::new(e.into()), - })?; + .iter() + .for_each(|handler| { + let _ = handler.try_send(ServerNotification::LoggingMessageNotification( + LoggingMessageNotification { + params: params.clone(), + method: LoggingMessageNotificationMethod, + extensions: context.extensions.clone(), + }, + )); + }); + } - match response_msg { - JsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => { - // Verify id matches - convert current id to match expected format - let expected_id = RequestId::Number((id_num) as u32); - if id != expected_id { - return Err(Error::UnexpectedResponse( - "id mismatch for JsonRpcResponse".to_string(), - )); - } - Ok(serde_json::from_value(serde_json::to_value(result)?)?) - } - JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => { - let expected_id = RequestId::Number((id_num) as u32); - if id != expected_id { - return Err(Error::UnexpectedResponse( - "id mismatch for JsonRpcError".to_string(), - )); - } - Err(Error::RpcError { - code: error.code.0, // Extract the i32 from ErrorCode - message: error.message.to_string(), // Convert Cow to String - }) - } - _ => { - // Requests/notifications not expected as a response - Err(Error::UnexpectedResponse( - "unexpected message type".to_string(), - )) - } + fn get_info(&self) -> ClientInfo { + ClientInfo { + protocol_version: ProtocolVersion::V_2025_03_26, + capabilities: ClientCapabilities::builder().build(), + client_info: Implementation { + name: "goose".to_string(), + version: env!("CARGO_PKG_VERSION").to_owned(), + }, } } +} - /// Send a JSON-RPC notification. - async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { - let mut service = self.service.lock().await; - service.ready().await.map_err(|_| Error::NotReady)?; +/// The MCP client is the interface for MCP operations. +pub struct McpClient { + client: Mutex>, + notification_subscribers: Arc>>>, + server_info: Option, + timeout: std::time::Duration, +} - let notification = JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: JsonRpcVersion2_0, - notification: Notification { - method: method.to_string(), - params: params.as_object().unwrap().clone(), - extensions: Default::default(), - }, - }); +impl McpClient { + pub async fn connect( + transport: T, + timeout: std::time::Duration, + ) -> Result + where + T: IntoTransport, + E: std::error::Error + From + Send + Sync + 'static, + { + let notification_subscribers = + Arc::new(Mutex::new(Vec::>::new())); - service - .call(notification) - .await - .map_err(|e| Error::McpServerError { - server: self - .server_info - .as_ref() - .map(|s| s.name.clone()) - .unwrap_or("".to_string()), - method: method.to_string(), - // we don't need include params because it can be really large - source: Box::::new(e.into()), - })?; + let client = GooseClient::new(notification_subscribers.clone()); + let client: rmcp::service::RunningService = + client.serve(transport).await?; + let server_info = client.peer_info().cloned(); - Ok(()) + Ok(Self { + client: Mutex::new(client), + notification_subscribers, + server_info, + timeout, + }) } - // Check if the client has completed initialization - fn completed_initialization(&self) -> bool { - self.server_capabilities.is_some() + fn get_request_options(&self) -> PeerRequestOptions { + PeerRequestOptions { + timeout: Some(self.timeout), + meta: None, + } } } #[async_trait::async_trait] -impl McpClientTrait for McpClient -where - T: TransportHandle + Send + Sync + 'static, -{ - async fn initialize( - &mut self, - info: ClientInfo, - capabilities: ClientCapabilities, - ) -> Result { - let params = InitializeParams { - protocol_version: "2025-03-26".to_string(), - client_info: info, - capabilities, - }; - let result: InitializeResult = self - .send_request("initialize", serde_json::to_value(params)?) - .await?; - - self.send_notification("notifications/initialized", serde_json::json!({})) - .await?; - - self.server_capabilities = Some(result.capabilities.clone()); - - self.server_info = Some(result.server_info.clone()); - - Ok(result) +impl McpClientTrait for McpClient { + fn get_info(&self) -> Option<&InitializeResult> { + self.server_info.as_ref() } - async fn list_resources( - &self, - next_cursor: Option, - ) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - // If resources is not supported, return an empty list - if self - .server_capabilities - .as_ref() - .unwrap() - .resources - .is_none() - { - return Ok(ListResourcesResult { - resources: vec![], - next_cursor: None, - }); + async fn list_resources(&self, cursor: Option) -> Result { + let res = self + .client + .lock() + .await + .send_request_with_option( + ClientRequest::ListResourcesRequest(ListResourcesRequest { + params: Some(PaginatedRequestParam { cursor }), + method: Default::default(), + extensions: Default::default(), + }), + self.get_request_options(), + ) + .await? + .await_response() + .await?; + match res { + ServerResult::ListResourcesResult(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), } - - let payload = next_cursor - .map(|cursor| serde_json::json!({"cursor": cursor})) - .unwrap_or_else(|| serde_json::json!({})); - - self.send_request("resources/list", payload).await } async fn read_resource(&self, uri: &str) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - // If resources is not supported, return an error - if self - .server_capabilities - .as_ref() - .unwrap() - .resources - .is_none() - { - return Err(Error::RpcError { - code: METHOD_NOT_FOUND, - message: "Server does not support 'resources' capability".to_string(), - }); + let res = self + .client + .lock() + .await + .send_request_with_option( + ClientRequest::ReadResourceRequest(ReadResourceRequest { + params: ReadResourceRequestParam { + uri: uri.to_string(), + }, + method: Default::default(), + extensions: Default::default(), + }), + self.get_request_options(), + ) + .await? + .await_response() + .await?; + match res { + ServerResult::ReadResourceResult(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), } - - let params = serde_json::json!({ "uri": uri }); - self.send_request("resources/read", params).await } - async fn list_tools(&self, next_cursor: Option) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - // If tools is not supported, return an empty list - if self.server_capabilities.as_ref().unwrap().tools.is_none() { - return Ok(ListToolsResult { - tools: vec![], - next_cursor: None, - }); + async fn list_tools(&self, cursor: Option) -> Result { + let res = self + .client + .lock() + .await + .send_request_with_option( + ClientRequest::ListToolsRequest(ListToolsRequest { + params: Some(PaginatedRequestParam { cursor }), + method: Default::default(), + extensions: Default::default(), + }), + self.get_request_options(), + ) + .await? + .await_response() + .await?; + match res { + ServerResult::ListToolsResult(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), } - - let payload = next_cursor - .map(|cursor| serde_json::json!({"cursor": cursor})) - .unwrap_or_else(|| serde_json::json!({})); - - self.send_request("tools/list", payload).await } async fn call_tool(&self, name: &str, arguments: Value) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - // If tools is not supported, return an error - if self.server_capabilities.as_ref().unwrap().tools.is_none() { - return Err(Error::RpcError { - code: METHOD_NOT_FOUND, - message: "Server does not support 'tools' capability".to_string(), - }); + let arguments = match arguments { + Value::Object(map) => Some(map), + _ => None, + }; + let res = self + .client + .lock() + .await + .send_request_with_option( + ClientRequest::CallToolRequest(CallToolRequest { + params: CallToolRequestParam { + name: name.to_string().into(), + arguments, + }, + method: Default::default(), + extensions: Default::default(), + }), + self.get_request_options(), + ) + .await? + .await_response() + .await?; + match res { + ServerResult::CallToolResult(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), } - - let params = serde_json::json!({ "name": name, "arguments": arguments }); - - // TODO ERROR: check that if there is an error, we send back is_error: true with msg - // https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2 - self.send_request("tools/call", params).await } - async fn list_prompts(&self, next_cursor: Option) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - - // If prompts is not supported, return an error - if self.server_capabilities.as_ref().unwrap().prompts.is_none() { - return Err(Error::RpcError { - code: METHOD_NOT_FOUND, - message: "Server does not support 'prompts' capability".to_string(), - }); + async fn list_prompts(&self, cursor: Option) -> Result { + let res = self + .client + .lock() + .await + .send_request_with_option( + ClientRequest::ListPromptsRequest(ListPromptsRequest { + params: Some(PaginatedRequestParam { cursor }), + method: Default::default(), + extensions: Default::default(), + }), + self.get_request_options(), + ) + .await? + .await_response() + .await?; + match res { + ServerResult::ListPromptsResult(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), } - - let payload = next_cursor - .map(|cursor| serde_json::json!({"cursor": cursor})) - .unwrap_or_else(|| serde_json::json!({})); - - self.send_request("prompts/list", payload).await } async fn get_prompt(&self, name: &str, arguments: Value) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - - // If prompts is not supported, return an error - if self.server_capabilities.as_ref().unwrap().prompts.is_none() { - return Err(Error::RpcError { - code: METHOD_NOT_FOUND, - message: "Server does not support 'prompts' capability".to_string(), - }); + let arguments = match arguments { + Value::Object(map) => Some(map), + _ => None, + }; + let res = self + .client + .lock() + .await + .send_request_with_option( + ClientRequest::GetPromptRequest(GetPromptRequest { + params: GetPromptRequestParam { + name: name.to_string(), + arguments, + }, + method: Default::default(), + extensions: Default::default(), + }), + self.get_request_options(), + ) + .await? + .await_response() + .await?; + match res { + ServerResult::GetPromptResult(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), } - - let params = serde_json::json!({ "name": name, "arguments": arguments }); - - self.send_request("prompts/get", params).await } async fn subscribe(&self) -> mpsc::Receiver { diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index b659ac3753a1..6757b984be10 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -1,14 +1,8 @@ pub mod client; pub mod oauth; -pub mod service; -pub mod transport; #[cfg(test)] mod oauth_tests; -pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait}; +pub use client::{Error, McpClient, McpClientTrait}; pub use oauth::{authenticate_service, ServiceConfig}; -pub use service::McpService; -pub use transport::{ - SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle, -}; diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs deleted file mode 100644 index 0a995c5db17b..000000000000 --- a/crates/mcp-client/src/service.rs +++ /dev/null @@ -1,145 +0,0 @@ -use futures::future::BoxFuture; -use rmcp::model::{JsonRpcMessage, JsonRpcRequest}; -use std::collections::HashMap; -use std::sync::Arc; -use std::task::{Context, Poll}; -use tokio::sync::{oneshot, RwLock}; -use tower::{timeout::Timeout, Service, ServiceBuilder}; - -use crate::transport::{Error, TransportHandle, TransportMessageRecv}; - -/// A wrapper service that implements Tower's Service trait for MCP transport -#[derive(Clone)] -pub struct McpService { - inner: Arc, - pending_requests: Arc, -} - -impl McpService { - pub fn new(transport: T) -> Self { - Self { - inner: Arc::new(transport), - pending_requests: Arc::new(PendingRequests::default()), - } - } - - pub async fn respond(&self, id: &str, response: Result) { - self.pending_requests.respond(id, response).await - } - - pub async fn hangup(&self, error: Error) { - self.pending_requests.broadcast_close(error).await - } -} - -impl Service for McpService -where - T: TransportHandle + Send + Sync + 'static, -{ - type Response = TransportMessageRecv; - type Error = Error; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - // Most transports are always ready, but this could be customized if needed - Poll::Ready(Ok(())) - } - - fn call(&mut self, request: JsonRpcMessage) -> Self::Future { - let transport = self.inner.clone(); - let pending_requests = self.pending_requests.clone(); - - Box::pin(async move { - match &request { - JsonRpcMessage::Request(JsonRpcRequest { id, .. }) => { - // Create a channel to receive the response - let (sender, receiver) = oneshot::channel(); - pending_requests.insert(id.to_string(), sender).await; - - transport.send(request).await?; - receiver.await.map_err(|_| Error::ChannelClosed)? - } - JsonRpcMessage::Notification(_) => { - // Handle notifications without waiting for a response - transport.send(request).await?; - // Return a dummy response for notifications - let dummy_response: Self::Response = - JsonRpcMessage::Response(rmcp::model::JsonRpcResponse { - jsonrpc: rmcp::model::JsonRpcVersion2_0, - id: rmcp::model::RequestId::Number(0), - result: serde_json::Map::new(), - }); - Ok(dummy_response) - } - _ => Err(Error::UnsupportedMessage), - } - }) - } -} - -// Add a convenience constructor for creating a service with timeout -impl McpService -where - T: TransportHandle, -{ - pub fn with_timeout(transport: T, timeout: std::time::Duration) -> Timeout> { - ServiceBuilder::new() - .timeout(timeout) - .service(McpService::new(transport)) - } -} - -// A data structure to store pending requests and their response channels -pub struct PendingRequests { - requests: RwLock>>>, -} - -impl Default for PendingRequests { - fn default() -> Self { - Self::new() - } -} - -impl PendingRequests { - pub fn new() -> Self { - Self { - requests: RwLock::new(HashMap::new()), - } - } - - pub async fn insert( - &self, - id: String, - sender: oneshot::Sender>, - ) { - self.requests.write().await.insert(id, sender); - } - - pub async fn respond(&self, id: &str, response: Result) { - if let Some(tx) = self.requests.write().await.remove(id) { - let _ = tx.send(response); - } - } - - pub async fn broadcast_close(&self, error: Error) { - for (_, tx) in self.requests.write().await.drain() { - let err = match &error { - Error::StdioProcessError(s) => Error::StdioProcessError(s.clone()), - _ => Error::ChannelClosed, - }; - let _ = tx.send(Err(err)); - } - } - - pub async fn clear(&self) { - self.requests.write().await.clear(); - } - - pub async fn len(&self) -> usize { - self.requests.read().await.len() - } - - pub async fn is_empty(&self) -> bool { - self.len().await == 0 - } -} diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs deleted file mode 100644 index c36cb07912f4..000000000000 --- a/crates/mcp-client/src/transport/mod.rs +++ /dev/null @@ -1,82 +0,0 @@ -use async_trait::async_trait; -use rmcp::model::{JsonObject, JsonRpcMessage, Request, ServerNotification}; -use thiserror::Error; -use tokio::sync::mpsc; - -pub type BoxError = Box; -/// A generic error type for transport operations. -#[derive(Debug, Error)] -pub enum Error { - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - - #[error("Transport was not connected or is already closed")] - NotConnected, - - #[error("Channel closed")] - ChannelClosed, - - #[error("Serialization error: {0}")] - Serialization(#[from] serde_json::Error), - - #[error("Unsupported message type. JsonRpcMessage can only be Request or Notification.")] - UnsupportedMessage, - - #[error("Stdio process error: {0}")] - StdioProcessError(String), - - #[error("SSE connection error: {0}")] - SseConnection(String), - - #[error("HTTP error: {status} - {message}")] - HttpError { status: u16, message: String }, - - #[error("Streamable HTTP error: {0}")] - StreamableHttpError(String), - - #[error("Session error: {0}")] - SessionError(String), -} - -/// A generic asynchronous transport trait with channel-based communication -#[async_trait] -pub trait Transport { - type Handle: TransportHandle; - - /// Start the transport and establish the underlying connection. - /// Returns the transport handle for sending messages. - async fn start(&self) -> Result; - - /// Close the transport and free any resources. - async fn close(&self) -> Result<(), Error>; -} - -pub type TransportMessageRecv = JsonRpcMessage; - -#[async_trait] -pub trait TransportHandle: Send + Sync + Clone + 'static { - async fn send(&self, message: JsonRpcMessage) -> Result<(), Error>; - async fn receive(&self) -> Result; -} - -pub async fn serialize_and_send( - sender: &mpsc::Sender, - message: JsonRpcMessage, -) -> Result<(), Error> { - match serde_json::to_string(&message).map_err(Error::Serialization) { - Ok(msg) => sender.send(msg).await.map_err(|_| Error::ChannelClosed), - Err(e) => { - tracing::error!(error = ?e, "Error serializing message"); - Err(e) - } - } -} - -pub mod stdio; -pub use stdio::StdioTransport; - -pub mod sse; -pub use sse::SseTransport; - -pub mod streamable_http; -pub use streamable_http::StreamableHttpTransport; diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs deleted file mode 100644 index 6ef2f7d8f7d6..000000000000 --- a/crates/mcp-client/src/transport/sse.rs +++ /dev/null @@ -1,280 +0,0 @@ -use crate::transport::{Error, TransportMessageRecv}; -use async_trait::async_trait; -use eventsource_client::{Client, SSE}; -use futures::TryStreamExt; -use reqwest::Client as HttpClient; -use rmcp::model::JsonRpcMessage; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex, RwLock}; -use tokio::time::{timeout, Duration}; -use tracing::warn; -use url::Url; - -use super::{serialize_and_send, Transport, TransportHandle}; - -// Timeout for the endpoint discovery -const ENDPOINT_TIMEOUT_SECS: u64 = 5; - -/// The SSE-based actor that continuously: -/// - Reads incoming events from the SSE stream. -/// - Sends outgoing messages via HTTP POST (once the post endpoint is known). -pub struct SseActor { - /// Receives messages (requests/notifications) from the handle - receiver: mpsc::Receiver, - /// Sends messages (responses) back to the handle - sender: mpsc::Sender, - /// Base SSE URL - sse_url: String, - /// For sending HTTP POST requests - http_client: HttpClient, - /// The discovered endpoint for POST requests (once "endpoint" SSE event arrives) - post_endpoint: Arc>>, -} - -impl SseActor { - pub fn new( - receiver: mpsc::Receiver, - sender: mpsc::Sender, - sse_url: String, - post_endpoint: Arc>>, - ) -> Self { - Self { - receiver, - sender, - sse_url, - post_endpoint, - http_client: HttpClient::new(), - } - } - - /// The main entry point for the actor. Spawns two concurrent loops: - /// 1) handle_incoming_messages (SSE events) - /// 2) handle_outgoing_messages (sending messages via POST) - pub async fn run(self) { - tokio::join!( - Self::handle_incoming_messages( - self.sender, - self.sse_url.clone(), - Arc::clone(&self.post_endpoint) - ), - Self::handle_outgoing_messages( - self.receiver, - self.http_client.clone(), - Arc::clone(&self.post_endpoint), - ) - ); - } - - /// Continuously reads SSE events from `sse_url`. - /// - If an `endpoint` event is received, store it in `post_endpoint`. - /// - If a `message` event is received, parse it as `JsonRpcMessage` - /// and respond to pending requests if it's a `Response`. - async fn handle_incoming_messages( - sender: mpsc::Sender, - sse_url: String, - post_endpoint: Arc>>, - ) { - let client = match eventsource_client::ClientBuilder::for_url(&sse_url) { - Ok(builder) => builder.build(), - Err(e) => { - warn!("Failed to connect SSE client: {}", e); - return; - } - }; - let mut stream = client.stream(); - - // First, wait for the "endpoint" event - while let Ok(Some(event)) = stream.try_next().await { - match event { - SSE::Event(e) if e.event_type == "endpoint" => { - // SSE server uses the "endpoint" event to tell us the POST URL - let base_url = Url::parse(&sse_url).expect("Invalid base URL"); - let post_url = base_url - .join(&e.data) - .expect("Failed to resolve endpoint URL"); - - tracing::debug!("Discovered SSE POST endpoint: {}", post_url); - *post_endpoint.write().await = Some(post_url.to_string()); - break; - } - _ => continue, - } - } - - // Now handle subsequent events - loop { - match stream.try_next().await { - Ok(Some(event)) => { - match event { - SSE::Event(e) if e.event_type == "message" => { - // Attempt to parse the SSE data as a JsonRpcMessage - match serde_json::from_str::(&e.data) { - Ok(message) => { - let _ = sender.send(message).await; - } - Err(err) => { - warn!("Failed to parse SSE message: {err}"); - } - } - } - _ => { /* ignore other events */ } - } - } - Ok(None) => { - // Stream ended - tracing::info!("SSE stream ended."); - break; - } - Err(e) => { - warn!("Error reading SSE stream: {e}"); - break; - } - } - } - - tracing::error!("SSE stream ended or encountered an error."); - } - - async fn handle_outgoing_messages( - mut receiver: mpsc::Receiver, - http_client: HttpClient, - post_endpoint: Arc>>, - ) { - while let Some(message_str) = receiver.recv().await { - let post_url = match post_endpoint.read().await.as_ref() { - Some(url) => url.clone(), - None => { - // TODO: the endpoint isn't discovered yet. This shouldn't happen -- we only return the handle - // after the endpoint is set. - continue; - } - }; - - // Perform the HTTP POST - match http_client - .post(&post_url) - .header("Content-Type", "application/json") - .body(message_str) - .send() - .await - { - Ok(resp) => { - if !resp.status().is_success() { - let err = Error::HttpError { - status: resp.status().as_u16(), - message: resp.status().to_string(), - }; - warn!("HTTP request returned error: {err}"); - // This doesn't directly fail the request, - // because we rely on SSE to deliver the error response - } - } - Err(e) => { - warn!("HTTP POST failed: {e}"); - // Similarly, SSE might eventually reveal the error - } - } - } - - tracing::info!("SseActor shut down."); - } -} - -#[derive(Clone)] -pub struct SseTransportHandle { - sender: mpsc::Sender, - receiver: Arc>>, -} - -#[async_trait::async_trait] -impl TransportHandle for SseTransportHandle { - async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> { - serialize_and_send(&self.sender, message).await - } - - async fn receive(&self) -> Result { - let mut receiver = self.receiver.lock().await; - receiver.recv().await.ok_or(Error::ChannelClosed) - } -} - -#[derive(Clone)] -pub struct SseTransport { - sse_url: String, - env: HashMap, -} - -/// The SSE transport spawns an `SseActor` on `start()`. -impl SseTransport { - pub fn new>(sse_url: S, env: HashMap) -> Self { - Self { - sse_url: sse_url.into(), - env, - } - } - - /// Waits for the endpoint to be set, up to 10 attempts. - async fn wait_for_endpoint( - post_endpoint: Arc>>, - ) -> Result { - // Check every 100ms for the endpoint, for up to 10 attempts - let check_interval = Duration::from_millis(100); - let mut attempts = 0; - let max_attempts = 10; - - while attempts < max_attempts { - if let Some(url) = post_endpoint.read().await.clone() { - return Ok(url); - } - tokio::time::sleep(check_interval).await; - attempts += 1; - } - Err(Error::SseConnection("No endpoint discovered".to_string())) - } -} - -#[async_trait] -impl Transport for SseTransport { - type Handle = SseTransportHandle; - - async fn start(&self) -> Result { - // Set environment variables - for (key, value) in &self.env { - std::env::set_var(key, value); - } - - // Create a channel for outgoing TransportMessages - let (tx, rx) = mpsc::channel(32); - let (otx, orx) = mpsc::channel(32); - - let post_endpoint: Arc>> = Arc::new(RwLock::new(None)); - let post_endpoint_clone = Arc::clone(&post_endpoint); - - // Build the actor - let actor = SseActor::new(rx, otx, self.sse_url.clone(), post_endpoint); - - // Spawn the actor task - tokio::spawn(actor.run()); - - // Wait for the endpoint to be discovered before returning the handle - match timeout( - Duration::from_secs(ENDPOINT_TIMEOUT_SECS), - Self::wait_for_endpoint(post_endpoint_clone), - ) - .await - { - Ok(_) => Ok(SseTransportHandle { - sender: tx, - receiver: Arc::new(Mutex::new(orx)), - }), - Err(e) => Err(Error::SseConnection(e.to_string())), - } - } - - async fn close(&self) -> Result<(), Error> { - // For SSE, you might close the stream or send a shutdown signal to the actor. - // Here, we do nothing special. - Ok(()) - } -} diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs deleted file mode 100644 index 225f245bbcd6..000000000000 --- a/crates/mcp-client/src/transport/stdio.rs +++ /dev/null @@ -1,319 +0,0 @@ -use std::collections::HashMap; -use std::sync::atomic::{AtomicI32, Ordering}; -use std::sync::Arc; -use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; - -use async_trait::async_trait; -use rmcp::model::JsonRpcMessage; -use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; -use tokio::sync::{mpsc, Mutex}; - -// Import nix crate components instead of libc -#[cfg(unix)] -use nix::sys::signal::{kill, Signal}; -#[cfg(unix)] -use nix::unistd::{getpgid, Pid}; - -use crate::transport::TransportMessageRecv; - -use super::{serialize_and_send, Error, Transport, TransportHandle}; - -// Global to track process groups we've created -static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1); - -/// A `StdioTransport` uses a child process's stdin/stdout as a communication channel. -/// -/// It uses channels for message passing and handles responses asynchronously through a background task. -pub struct StdioActor { - receiver: Option>, - sender: Option>, - process: Child, // we store the process to keep it alive - error_sender: mpsc::Sender, - stdin: Option, - stdout: Option, - stderr: Option, -} - -impl Drop for StdioActor { - fn drop(&mut self) { - // Get the process group ID before attempting cleanup - #[cfg(unix)] - if let Some(pid) = self.process.id() { - if let Ok(pgid) = getpgid(Some(Pid::from_raw(pid as i32))) { - // Send SIGTERM to the entire process group - let _ = kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGTERM); - // Give processes a moment to cleanup - std::thread::sleep(std::time::Duration::from_millis(100)); - // Force kill if still running - let _ = kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGKILL); - } - } - } -} - -impl StdioActor { - pub async fn run(mut self) { - use tokio::pin; - - let stdout = self.stdout.take().expect("stdout should be available"); - let stdin = self.stdin.take().expect("stdin should be available"); - let msg_inbox = self.receiver.take().expect("receiver should be available"); - let msg_outbox = self.sender.take().expect("sender should be available"); - - let incoming = Self::handle_proc_output(stdout, msg_outbox); - let outgoing = Self::handle_proc_input(stdin, msg_inbox); - - // take ownership of futures for tokio::select - pin!(incoming); - pin!(outgoing); - - // Use select! to wait for either I/O completion or process exit - tokio::select! { - result = &mut incoming => { - tracing::debug!("Stdin handler completed: {:?}", result); - } - result = &mut outgoing => { - tracing::debug!("Stdout handler completed: {:?}", result); - } - // capture the status so we don't need to wait for a timeout - status = self.process.wait() => { - tracing::debug!("Process exited with status: {:?}", status); - } - } - - // Then always try to read stderr before cleaning up - let mut stderr_buffer = Vec::new(); - if let Some(mut stderr) = self.stderr.take() { - if let Ok(bytes) = stderr.read_to_end(&mut stderr_buffer).await { - let err_msg = if bytes > 0 { - String::from_utf8_lossy(&stderr_buffer).to_string() - } else { - "Process ended unexpectedly".to_string() - }; - - tracing::info!("Process stderr: {}", err_msg); - let _ = self - .error_sender - .send(Error::StdioProcessError(err_msg)) - .await; - } - } - } - - async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender) { - let mut reader = BufReader::new(stdout); - let mut line = String::new(); - loop { - match reader.read_line(&mut line).await { - Ok(0) => { - tracing::error!("Child process ended (EOF on stdout)"); - break; - } // EOF - Ok(_) => { - if let Ok(message) = serde_json::from_str::(&line) { - tracing::debug!( - message = ?message, - "Received incoming message" - ); - let _ = sender.send(message).await; - } else { - tracing::warn!( - message = ?line, - "Failed to parse incoming message" - ); - } - line.clear(); - } - Err(e) => { - tracing::error!(error = ?e, "Error reading line"); - break; - } - } - } - } - - async fn handle_proc_input(mut stdin: ChildStdin, mut receiver: mpsc::Receiver) { - while let Some(message_str) = receiver.recv().await { - tracing::debug!(message = ?message_str, "Sending outgoing message"); - - if let Err(e) = stdin.write_all(format!("{message_str}\n").as_bytes()).await { - tracing::error!(error = ?e, "Error writing message to child process"); - break; - } - - if let Err(e) = stdin.flush().await { - tracing::error!(error = ?e, "Error flushing message to child process"); - break; - } - } - } -} - -#[derive(Clone)] -pub struct StdioTransportHandle { - sender: mpsc::Sender, // to process - receiver: Arc>>, // from process - error_receiver: Arc>>, -} - -#[async_trait::async_trait] -impl TransportHandle for StdioTransportHandle { - async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> { - let result = serialize_and_send(&self.sender, message).await; - // Check for any pending errors even if send is successful - self.check_for_errors().await?; - result - } - - async fn receive(&self) -> Result { - let mut receiver = self.receiver.lock().await; - match receiver.recv().await { - Some(message) => Ok(message), - None => { - self.check_for_errors().await?; - Err(Error::ChannelClosed) - } - } - } -} - -impl StdioTransportHandle { - /// Check if there are any process errors - pub async fn check_for_errors(&self) -> Result<(), Error> { - match self.error_receiver.lock().await.try_recv() { - Ok(error) => { - tracing::debug!("Found error: {:?}", error); - Err(error) - } - Err(_) => Ok(()), - } - } -} - -pub struct StdioTransport { - command: String, - args: Vec, - env: HashMap, -} - -impl StdioTransport { - pub fn new>( - command: S, - args: Vec, - env: HashMap, - ) -> Self { - Self { - command: command.into(), - args, - env, - } - } - - async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr), Error> { - let mut command = Command::new(&self.command); - command - .envs(&self.env) - .args(&self.args) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .kill_on_drop(true); - - // Set process group and ensure signal handling on Unix systems - #[cfg(unix)] - command.process_group(0); - - // Hide console window on Windows - #[cfg(windows)] - command.creation_flags(0x08000000); // CREATE_NO_WINDOW flag - - let mut process = command.spawn().map_err(|e| { - let command = command.into_std(); - Error::StdioProcessError(format!( - "Could not run extension command (`{} {}`): {}", - command - .get_program() - .to_str() - .unwrap_or("[invalid command]"), - command - .get_args() - .map(|arg| arg.to_str().unwrap_or("[invalid arg]")) - .collect::>() - .join(" "), - e - )) - })?; - - let stdin = process - .stdin - .take() - .ok_or_else(|| Error::StdioProcessError("Failed to get stdin".into()))?; - - let stdout = process - .stdout - .take() - .ok_or_else(|| Error::StdioProcessError("Failed to get stdout".into()))?; - - let stderr = process - .stderr - .take() - .ok_or_else(|| Error::StdioProcessError("Failed to get stderr".into()))?; - - // Store the process group ID for cleanup - #[cfg(unix)] - if let Some(pid) = process.id() { - // Use nix instead of unsafe libc calls - if let Ok(pgid) = getpgid(Some(Pid::from_raw(pid as i32))) { - PROCESS_GROUP.store(pgid.as_raw(), Ordering::SeqCst); - } - } - - Ok((process, stdin, stdout, stderr)) - } -} - -#[async_trait] -impl Transport for StdioTransport { - type Handle = StdioTransportHandle; - - async fn start(&self) -> Result { - let (process, stdin, stdout, stderr) = self.spawn_process().await?; - let (outbox_tx, outbox_rx) = mpsc::channel(32); - let (inbox_tx, inbox_rx) = mpsc::channel(32); - let (error_tx, error_rx) = mpsc::channel(1); - - let actor = StdioActor { - receiver: Some(outbox_rx), // client to process - sender: Some(inbox_tx), // process to client - process, - error_sender: error_tx, - stdin: Some(stdin), - stdout: Some(stdout), - stderr: Some(stderr), - }; - - tokio::spawn(actor.run()); - - let handle = StdioTransportHandle { - sender: outbox_tx, // client to process - receiver: Arc::new(Mutex::new(inbox_rx)), // process to client - error_receiver: Arc::new(Mutex::new(error_rx)), - }; - Ok(handle) - } - - async fn close(&self) -> Result<(), Error> { - // Attempt to clean up the process group on close - #[cfg(unix)] - if let Some(pgid) = PROCESS_GROUP.load(Ordering::SeqCst).checked_abs() { - // Use nix instead of unsafe libc calls - // Try SIGTERM first - let _ = kill(Pid::from_raw(-pgid), Signal::SIGTERM); - // Give processes a moment to cleanup - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - // Force kill if still running - let _ = kill(Pid::from_raw(-pgid), Signal::SIGKILL); - } - Ok(()) - } -} diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs deleted file mode 100644 index 77c534fc0071..000000000000 --- a/crates/mcp-client/src/transport/streamable_http.rs +++ /dev/null @@ -1,1001 +0,0 @@ -use crate::oauth::{authenticate_service, ServiceConfig}; -use crate::transport::{Error, TransportMessageRecv}; -use async_trait::async_trait; -use eventsource_client::{Client, SSE}; -use futures::TryStreamExt; -use reqwest::Client as HttpClient; -use rmcp::model::{JsonRpcMessage, JsonRpcRequest, NumberOrString::Number}; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex, RwLock}; -use tokio::time::Duration; -use tracing::{debug, error, info, warn}; -use url::Url; - -use super::{serialize_and_send, Transport, TransportHandle}; - -// Default timeout for HTTP requests -const HTTP_TIMEOUT_SECS: u64 = 30; - -/// The Streamable HTTP transport actor that handles: -/// - HTTP POST requests to send messages to the server -/// - Optional streaming responses for receiving multiple responses and server-initiated messages -/// - Session management with session IDs -pub struct StreamableHttpActor { - /// Receives messages (requests/notifications) from the handle - receiver: mpsc::Receiver, - /// Sends messages (responses) back to the handle - sender: mpsc::Sender, - /// MCP endpoint URL - mcp_endpoint: String, - /// HTTP client for sending requests - http_client: HttpClient, - /// Optional session ID for stateful connections - session_id: Arc>>, - /// Environment variables to set - env: HashMap, - /// Custom headers to include in requests - headers: HashMap, -} - -impl StreamableHttpActor { - pub fn new( - receiver: mpsc::Receiver, - sender: mpsc::Sender, - mcp_endpoint: String, - session_id: Arc>>, - env: HashMap, - headers: HashMap, - ) -> Self { - Self { - receiver, - sender, - mcp_endpoint, - http_client: HttpClient::builder() - .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS)) - .build() - .unwrap(), - session_id, - env, - headers, - } - } - - /// Main entry point for the actor - pub async fn run(mut self) { - // Set environment variables - for (key, value) in &self.env { - std::env::set_var(key, value); - } - - // Handle outgoing messages - while let Some(message_str) = self.receiver.recv().await { - if let Err(e) = self.handle_outgoing_message(message_str).await { - error!("Error handling outgoing message: {}", e); - break; - } - } - - debug!("StreamableHttpActor shut down"); - } - - /// Handle an outgoing message by sending it via HTTP POST - async fn handle_outgoing_message(&mut self, message_str: String) -> Result<(), Error> { - debug!("Sending message to MCP endpoint: {}", message_str); - - // Parse the message to determine if it's a request that expects a response - let parsed_message = - serde_json::from_str::(&message_str).map_err(Error::Serialization)?; - - let expects_response = matches!( - parsed_message, - JsonRpcMessage::Request(JsonRpcRequest { id: Number(_), .. }) - ); - - // Try to send the request - match self.send_request(&message_str, expects_response).await { - Ok(()) => Ok(()), - Err(Error::HttpError { status, .. }) if status == 401 || status == 403 => { - // Authentication challenge - try to authenticate and retry - info!( - "Received authentication challenge ({}), attempting OAuth flow...", - status - ); - - if let Some(token) = self.attempt_authentication().await? { - info!("Authentication successful, retrying request..."); - self.headers - .insert("Authorization".to_string(), format!("Bearer {}", token)); - self.send_request(&message_str, expects_response).await - } else { - Err(Error::StreamableHttpError( - "Authentication failed - service not supported or OAuth flow failed" - .to_string(), - )) - } - } - Err(e) => Err(e), - } - } - - /// Send an HTTP request to the MCP endpoint - async fn send_request( - &mut self, - message_str: &str, - expects_response: bool, - ) -> Result<(), Error> { - // Build the HTTP request - let mut request = self - .http_client - .post(&self.mcp_endpoint) - .header("Content-Type", "application/json") - .header("Accept", "application/json, text/event-stream") - .header("MCP-Protocol-Version", "2025-06-18") // Required protocol version header - .body(message_str.to_string()); - - // Add session ID header if we have one - if let Some(session_id) = self.session_id.read().await.as_ref() { - request = request.header("Mcp-Session-Id", session_id); - } - - // Add custom headers - for (key, value) in &self.headers { - request = request.header(key, value); - } - - // Send the request - let response = request - .send() - .await - .map_err(|e| Error::StreamableHttpError(format!("HTTP request failed: {}", e)))?; - - // Handle HTTP error status codes - if !response.status().is_success() { - let status = response.status(); - if status.as_u16() == 404 { - // Session not found - clear our session ID - *self.session_id.write().await = None; - return Err(Error::SessionError( - "Session expired or not found".to_string(), - )); - } - let error_text = response - .text() - .await - .unwrap_or_else(|_| "Unknown error".to_string()); - return Err(Error::HttpError { - status: status.as_u16(), - message: error_text, - }); - } - - // Check for session ID in response headers - if let Some(session_id_header) = response.headers().get("Mcp-Session-Id") { - if let Ok(session_id) = session_id_header.to_str() { - debug!("Received session ID: {}", session_id); - *self.session_id.write().await = Some(session_id.to_string()); - } - } - - // Handle the response based on content type - let content_type = response - .headers() - .get("content-type") - .and_then(|h| h.to_str().ok()) - .unwrap_or(""); - - if content_type.starts_with("text/event-stream") { - // Handle streaming HTTP response (server chose to stream multiple messages back) - if expects_response { - self.handle_streaming_response(response).await?; - } - } else if content_type.starts_with("application/json") || expects_response { - // Handle single JSON response - let response_text = response.text().await.map_err(|e| { - Error::StreamableHttpError(format!("Failed to read response: {}", e)) - })?; - - if !response_text.is_empty() { - let json_message = serde_json::from_str::(&response_text) - .map_err(Error::Serialization)?; - - let _ = self.sender.send(json_message).await; - } - } - // For notifications and responses, we get 202 Accepted with no body - - Ok(()) - } - - /// Attempt to authenticate with the service - async fn attempt_authentication(&self) -> Result, Error> { - info!("Attempting to authenticate with service..."); - - // Create a generic OAuth configuration from the MCP endpoint - match ServiceConfig::from_mcp_endpoint(&self.mcp_endpoint) { - Ok(config) => { - info!("Created OAuth config for endpoint: {}", self.mcp_endpoint); - - match authenticate_service(config, &self.mcp_endpoint).await { - Ok(token) => { - info!("OAuth authentication successful!"); - Ok(Some(token)) - } - Err(e) => { - warn!("OAuth authentication failed: {}", e); - Err(Error::StreamableHttpError(format!("OAuth failed: {}", e))) - } - } - } - Err(e) => { - warn!( - "Could not create OAuth config from MCP endpoint {}: {}", - self.mcp_endpoint, e - ); - Ok(None) - } - } - } - - /// Handle streaming HTTP response that uses Server-Sent Events format - /// - /// This is called when the server responds to an HTTP POST with `text/event-stream` - /// content-type, indicating it wants to stream multiple JSON-RPC messages back - /// rather than sending a single response. This is part of the Streamable HTTP - /// specification, not a separate SSE transport. - async fn handle_streaming_response( - &mut self, - response: reqwest::Response, - ) -> Result<(), Error> { - use futures::StreamExt; - use tokio::io::AsyncBufReadExt; - use tokio_util::io::StreamReader; - - // Convert the response body to a stream reader - let stream = response - .bytes_stream() - .map(|result| result.map_err(std::io::Error::other)); - let reader = StreamReader::new(stream); - let mut lines = tokio::io::BufReader::new(reader).lines(); - - let mut event_type = String::new(); - let mut event_data = String::new(); - let mut event_id = String::new(); - - while let Ok(Some(line)) = lines.next_line().await { - if line.is_empty() { - // Empty line indicates end of event - if !event_data.is_empty() { - // Parse the streamed data as JSON-RPC message - match serde_json::from_str::(&event_data) { - Ok(message) => { - debug!("Received streaming HTTP response message: {:?}", message); - let _ = self.sender.send(message).await; - } - Err(err) => { - warn!("Failed to parse streaming HTTP response message: {}", err); - } - } - } - // Reset for next event - event_type.clear(); - event_data.clear(); - event_id.clear(); - } else if let Some(field_data) = line.strip_prefix("data: ") { - if !event_data.is_empty() { - event_data.push('\n'); - } - event_data.push_str(field_data); - } else if let Some(field_data) = line.strip_prefix("event: ") { - event_type = field_data.to_string(); - } else if let Some(field_data) = line.strip_prefix("id: ") { - event_id = field_data.to_string(); - } - // Ignore other fields (retry, etc.) - we only care about data - } - - Ok(()) - } -} - -#[derive(Clone)] -pub struct StreamableHttpTransportHandle { - sender: mpsc::Sender, - receiver: Arc>>, - session_id: Arc>>, - mcp_endpoint: String, - http_client: HttpClient, - headers: HashMap, -} - -#[async_trait::async_trait] -impl TransportHandle for StreamableHttpTransportHandle { - async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> { - serialize_and_send(&self.sender, message).await - } - - async fn receive(&self) -> Result { - let mut receiver = self.receiver.lock().await; - receiver.recv().await.ok_or(Error::ChannelClosed) - } -} - -impl StreamableHttpTransportHandle { - /// Manually terminate the session by sending HTTP DELETE - pub async fn terminate_session(&self) -> Result<(), Error> { - if let Some(session_id) = self.session_id.read().await.as_ref() { - let mut request = self - .http_client - .delete(&self.mcp_endpoint) - .header("Mcp-Session-Id", session_id) - .header("MCP-Protocol-Version", "2025-06-18"); // Required protocol version header - - // Add custom headers - for (key, value) in &self.headers { - request = request.header(key, value); - } - - match request.send().await { - Ok(response) => { - if response.status().as_u16() == 405 { - // Method not allowed - server doesn't support session termination - debug!("Server doesn't support session termination"); - } - } - Err(e) => { - warn!("Failed to terminate session: {}", e); - } - } - } - Ok(()) - } - - /// Create a GET request to establish a streaming connection for server-initiated messages - pub async fn listen_for_server_messages(&self) -> Result<(), Error> { - let mut request = self - .http_client - .get(&self.mcp_endpoint) - .header("Accept", "text/event-stream") - .header("MCP-Protocol-Version", "2025-06-18"); // Required protocol version header - - // Add session ID header if we have one - if let Some(session_id) = self.session_id.read().await.as_ref() { - request = request.header("Mcp-Session-Id", session_id); - } - - // Add custom headers - for (key, value) in &self.headers { - request = request.header(key, value); - } - - let response = request.send().await.map_err(|e| { - Error::StreamableHttpError(format!("Failed to start GET streaming connection: {}", e)) - })?; - - if !response.status().is_success() { - if response.status().as_u16() == 405 { - // Method not allowed - server doesn't support GET streaming connections - debug!("Server doesn't support GET streaming connections"); - return Ok(()); - } - return Err(Error::HttpError { - status: response.status().as_u16(), - message: "Failed to establish GET streaming connection".to_string(), - }); - } - - // Handle the streaming connection in a separate task - let receiver = self.receiver.clone(); - let url = response.url().clone(); - - tokio::spawn(async move { - let client = match eventsource_client::ClientBuilder::for_url(url.as_str()) { - Ok(builder) => builder.build(), - Err(e) => { - error!( - "Failed to create streaming client for GET connection: {}", - e - ); - return; - } - }; - - let mut stream = client.stream(); - while let Ok(Some(event)) = stream.try_next().await { - match event { - SSE::Event(e) if e.event_type == "message" || e.event_type.is_empty() => { - match serde_json::from_str::(&e.data) { - Ok(message) => { - debug!("Received GET streaming message: {:?}", message); - let receiver_guard = receiver.lock().await; - // We can't send through the receiver since it's for outbound messages - // This would need a different channel for server-initiated messages - drop(receiver_guard); - } - Err(err) => { - warn!("Failed to parse GET streaming message: {}", err); - } - } - } - _ => {} - } - } - }); - - Ok(()) - } -} - -#[derive(Clone)] -pub struct StreamableHttpTransport { - mcp_endpoint: String, - env: HashMap, - headers: HashMap, -} - -impl StreamableHttpTransport { - pub fn new>(mcp_endpoint: S, env: HashMap) -> Self { - Self { - mcp_endpoint: mcp_endpoint.into(), - env, - headers: HashMap::new(), - } - } - - pub fn with_headers>( - mcp_endpoint: S, - env: HashMap, - headers: HashMap, - ) -> Self { - Self { - mcp_endpoint: mcp_endpoint.into(), - env, - headers, - } - } - - /// Validate that the URL is a valid MCP endpoint - pub fn validate_endpoint(endpoint: &str) -> Result<(), Error> { - Url::parse(endpoint) - .map_err(|e| Error::StreamableHttpError(format!("Invalid MCP endpoint URL: {}", e)))?; - Ok(()) - } -} - -#[async_trait] -impl Transport for StreamableHttpTransport { - type Handle = StreamableHttpTransportHandle; - - async fn start(&self) -> Result { - // Validate the endpoint URL - Self::validate_endpoint(&self.mcp_endpoint)?; - - // Create channels for communication - let (tx, rx) = mpsc::channel(32); - let (otx, orx) = mpsc::channel(32); - - let session_id: Arc>> = Arc::new(RwLock::new(None)); - let session_id_clone = Arc::clone(&session_id); - - // Create and spawn the actor - let actor = StreamableHttpActor::new( - rx, - otx, - self.mcp_endpoint.clone(), - session_id, - self.env.clone(), - self.headers.clone(), - ); - - tokio::spawn(actor.run()); - - // Create the handle - let handle = StreamableHttpTransportHandle { - sender: tx, - receiver: Arc::new(Mutex::new(orx)), - session_id: session_id_clone, - mcp_endpoint: self.mcp_endpoint.clone(), - http_client: HttpClient::builder() - .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS)) - .build() - .unwrap(), - headers: self.headers.clone(), - }; - - Ok(handle) - } - - async fn close(&self) -> Result<(), Error> { - // The transport is closed when the actor task completes - // No additional cleanup needed - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use mockito::Server; - use serde_json::json; - use std::collections::HashMap; - use std::sync::Arc; - use tokio::sync::mpsc; - use tokio::sync::RwLock; - - #[test] - fn test_message_parsing_request() { - // Test that we can parse a JSON-RPC request message using the mcp-core types - let request_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": { - "capabilities": {} - } - }); - - let message_str = serde_json::to_string(&request_json).unwrap(); - let parsed_message = serde_json::from_str::(&message_str); - assert!( - parsed_message.is_ok(), - "Should be able to parse JSON-RPC request message" - ); - } - - #[test] - fn test_message_parsing_response() { - // Test that we can parse a JSON-RPC response message - let response_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "result": { - "capabilities": {} - } - }); - - let message_str = serde_json::to_string(&response_json).unwrap(); - let parsed_message = serde_json::from_str::(&message_str); - assert!( - parsed_message.is_ok(), - "Should be able to parse JSON-RPC response message" - ); - } - - #[test] - fn test_message_parsing_notification() { - // Test that we can parse a JSON-RPC notification message - let notification_json = json!({ - "jsonrpc": "2.0", - "method": "initialized", - "params": {} - }); - - let message_str = serde_json::to_string(¬ification_json).unwrap(); - let parsed_message = serde_json::from_str::(&message_str); - assert!( - parsed_message.is_ok(), - "Should be able to parse JSON-RPC notification message" - ); - } - - #[test] - fn test_message_parsing_error() { - // Test that we can parse a JSON-RPC error message - let error_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "error": { - "code": -32600, - "message": "Invalid Request" - } - }); - - let message_str = serde_json::to_string(&error_json).unwrap(); - let parsed_message = serde_json::from_str::(&message_str); - assert!( - parsed_message.is_ok(), - "Should be able to parse JSON-RPC error message" - ); - } - - #[test] - fn test_message_parsing_invalid_json() { - let invalid_json = "{ invalid json }"; - let parsed_message = serde_json::from_str::(invalid_json); - assert!(parsed_message.is_err(), "Invalid JSON should fail to parse"); - } - - #[test] - fn test_transport_message_recv_parsing() { - // Test that we can parse messages as TransportMessageRecv (the type used for incoming messages) - let response_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "result": { - "capabilities": {} - } - }); - - let message_str = serde_json::to_string(&response_json).unwrap(); - - // For incoming messages - let parsed_message = serde_json::from_str::(&message_str); - assert!( - parsed_message.is_ok(), - "Should be able to parse response as TransportMessageRecv" - ); - } - - #[test] - fn test_untagged_enum_serialization_issue() { - let request_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "tools/list", - "params": {} - }); - - let message_str = serde_json::to_string(&request_json).unwrap(); - - let parsed_as_jsonrpc = serde_json::from_str::(&message_str); - assert!( - parsed_as_jsonrpc.is_ok(), - "Should be able to parse request as JsonRpcMessage" - ); - } - - #[test] - fn test_expects_response_logic_with_number_id() { - // Check if a message expects a response - let request_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": {} - }); - - let message_str = serde_json::to_string(&request_json).unwrap(); - let parsed_message = serde_json::from_str::(&message_str).unwrap(); - - // This should match the logic in handle_outgoing_message after the fix - // The original code used: JsonRpcMessage::Request(JsonRpcRequest { id: Number(_), .. }) - let expects_response = match parsed_message { - JsonRpcMessage::Request(_) => true, - _ => false, - }; - - assert!(expects_response, "Request with ID should expect a response"); - - // Test notification (should not expect response) - let notification_json = json!({ - "jsonrpc": "2.0", - "method": "initialized", - "params": {} - }); - - let message_str = serde_json::to_string(¬ification_json).unwrap(); - let parsed_message = serde_json::from_str::(&message_str).unwrap(); - - let expects_response = match parsed_message { - JsonRpcMessage::Request(_) => true, - _ => false, - }; - - assert!( - !expects_response, - "Notification should not expect a response" - ); - } - - #[tokio::test] - async fn test_handle_outgoing_message_successful_request() { - // Set up a mock HTTP server - let mut server = Server::new_async().await; - let mock = server - .mock("POST", "/") - .with_status(200) - .with_header("content-type", "application/json") - .with_body(r#"{"jsonrpc":"2.0","id":1,"result":{"capabilities":{}}}"#) - .create_async() - .await; - - // Create channels for the actor - let (_tx, rx) = mpsc::channel(32); - let (otx, mut orx) = mpsc::channel(32); - - // Create the actor - let session_id = Arc::new(RwLock::new(None)); - let mut actor = StreamableHttpActor::new( - rx, - otx, - server.url(), - session_id, - HashMap::new(), - HashMap::new(), - ); - - // Create a JSON-RPC request message - let request_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": { - "capabilities": {} - } - }); - let message_str = serde_json::to_string(&request_json).unwrap(); - - // Test handle_outgoing_message - let result = actor.handle_outgoing_message(message_str).await; - assert!(result.is_ok(), "handle_outgoing_message should succeed"); - - // Verify the mock was called - mock.assert_async().await; - - // Check that a response was received - let response = - tokio::time::timeout(std::time::Duration::from_millis(100), orx.recv()).await; - assert!(response.is_ok(), "Should receive a response"); - assert!(response.unwrap().is_some(), "Response should not be None"); - } - - #[tokio::test] - async fn test_handle_outgoing_message_notification() { - // Set up a mock HTTP server for notifications (202 Accepted, no body) - let mut server = Server::new_async().await; - let mock = server - .mock("POST", "/") - .with_status(202) - .create_async() - .await; - - // Create channels for the actor - let (_tx, rx) = mpsc::channel(32); - let (otx, mut orx) = mpsc::channel(32); - - // Create the actor - let session_id = Arc::new(RwLock::new(None)); - let mut actor = StreamableHttpActor::new( - rx, - otx, - server.url(), - session_id, - HashMap::new(), - HashMap::new(), - ); - - // Create a JSON-RPC notification message (no id) - let notification_json = json!({ - "jsonrpc": "2.0", - "method": "initialized", - "params": {} - }); - let message_str = serde_json::to_string(¬ification_json).unwrap(); - - // Test handle_outgoing_message - let result = actor.handle_outgoing_message(message_str).await; - assert!( - result.is_ok(), - "handle_outgoing_message should succeed for notification" - ); - - // Verify the mock was called - mock.assert_async().await; - - // For notifications, we shouldn't receive a response - let response = - tokio::time::timeout(std::time::Duration::from_millis(100), orx.recv()).await; - assert!( - response.is_err(), - "Should not receive a response for notification" - ); - } - - #[tokio::test] - async fn test_handle_outgoing_message_http_error() { - // Set up a mock HTTP server that returns an error - let mut server = Server::new_async().await; - let mock = server - .mock("POST", "/") - .with_status(500) - .with_body("Internal Server Error") - .create_async() - .await; - - // Create channels for the actor - let (_tx, rx) = mpsc::channel(32); - let (otx, _orx) = mpsc::channel(32); - - // Create the actor - let session_id = Arc::new(RwLock::new(None)); - let mut actor = StreamableHttpActor::new( - rx, - otx, - server.url(), - session_id, - HashMap::new(), - HashMap::new(), - ); - - // Create a JSON-RPC request message - let request_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "test", - "params": {} - }); - let message_str = serde_json::to_string(&request_json).unwrap(); - - // Test handle_outgoing_message - let result = actor.handle_outgoing_message(message_str).await; - assert!( - result.is_err(), - "handle_outgoing_message should fail with HTTP error" - ); - - // Verify it's an HTTP error - match result.unwrap_err() { - Error::HttpError { status, .. } => { - assert_eq!(status, 500, "Should return HTTP 500 error"); - } - _ => panic!("Expected HttpError"), - } - - // Verify the mock was called - mock.assert_async().await; - } - - #[tokio::test] - async fn test_handle_outgoing_message_session_id_handling() { - // Set up a mock HTTP server that returns a session ID - let mut server = Server::new_async().await; - let mock = server - .mock("POST", "/") - .with_status(200) - .with_header("content-type", "application/json") - .with_header("Mcp-Session-Id", "test-session-123") - .with_body(r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) - .create_async() - .await; - - // Create channels for the actor - let (_tx, rx) = mpsc::channel(32); - let (otx, _orx) = mpsc::channel(32); - - // Create the actor - let session_id = Arc::new(RwLock::new(None)); - let session_id_clone = Arc::clone(&session_id); - let mut actor = StreamableHttpActor::new( - rx, - otx, - server.url(), - session_id, - HashMap::new(), - HashMap::new(), - ); - - // Create a JSON-RPC request message - let request_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": {} - }); - let message_str = serde_json::to_string(&request_json).unwrap(); - - // Test handle_outgoing_message - let result = actor.handle_outgoing_message(message_str).await; - assert!(result.is_ok(), "handle_outgoing_message should succeed"); - - // Verify the session ID was stored - let stored_session_id = session_id_clone.read().await; - assert_eq!( - stored_session_id.as_ref(), - Some(&"test-session-123".to_string()), - "Session ID should be stored" - ); - - // Verify the mock was called - mock.assert_async().await; - } - - #[tokio::test] - async fn test_handle_outgoing_message_invalid_json() { - // Create channels for the actor - let (_tx, rx) = mpsc::channel(32); - let (otx, _orx) = mpsc::channel(32); - - // Create the actor - let session_id = Arc::new(RwLock::new(None)); - let mut actor = StreamableHttpActor::new( - rx, - otx, - "http://localhost:8080".to_string(), - session_id, - HashMap::new(), - HashMap::new(), - ); - - // Test with invalid JSON - let invalid_json = "{ invalid json }"; - - // Test handle_outgoing_message - let result = actor - .handle_outgoing_message(invalid_json.to_string()) - .await; - assert!( - result.is_err(), - "handle_outgoing_message should fail with invalid JSON" - ); - - // Verify it's a serialization error - match result.unwrap_err() { - Error::Serialization(_) => { - // Expected error type - } - _ => panic!("Expected Serialization error"), - } - } - - #[tokio::test] - async fn test_handle_outgoing_message_session_not_found() { - // Set up a mock HTTP server that returns 404 (session not found) - let mut server = Server::new_async().await; - let mock = server - .mock("POST", "/") - .with_status(404) - .with_body("Session not found") - .create_async() - .await; - - // Create channels for the actor - let (_tx, rx) = mpsc::channel(32); - let (otx, _orx) = mpsc::channel(32); - - // Create the actor with an existing session ID - let session_id = Arc::new(RwLock::new(Some("old-session".to_string()))); - let session_id_clone = Arc::clone(&session_id); - let mut actor = StreamableHttpActor::new( - rx, - otx, - server.url(), - session_id, - HashMap::new(), - HashMap::new(), - ); - - // Create a JSON-RPC request message - let request_json = json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "test", - "params": {} - }); - let message_str = serde_json::to_string(&request_json).unwrap(); - - // Test handle_outgoing_message - let result = actor.handle_outgoing_message(message_str).await; - assert!( - result.is_err(), - "handle_outgoing_message should fail with 404" - ); - - // Verify it's a session error and the session ID was cleared - match result.unwrap_err() { - Error::SessionError(_) => { - // Expected error type - } - _ => panic!("Expected SessionError"), - } - - // Verify the session ID was cleared - let stored_session_id = session_id_clone.read().await; - assert!( - stored_session_id.is_none(), - "Session ID should be cleared on 404" - ); - - // Verify the mock was called - mock.assert_async().await; - } -} diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index 031a880f221c..fa6ca09eb4f5 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -1,124 +1,7 @@ /// The protocol messages exchanged between client and server use rmcp::model::Tool; -use rmcp::model::{Content, ErrorData, Prompt, PromptMessage, Resource, ResourceContents}; +use rmcp::model::{Content, Prompt, PromptMessage, Resource, ResourceContents}; use serde::{Deserialize, Serialize}; -use serde_json::Value; - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct JsonRpcRequest { - pub jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct JsonRpcResponse { - pub jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct JsonRpcNotification { - pub jsonrpc: String, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct JsonRpcError { - pub jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - pub error: ErrorData, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(untagged, try_from = "JsonRpcRaw")] -pub enum JsonRpcMessage { - Request(JsonRpcRequest), - Response(JsonRpcResponse), - Notification(JsonRpcNotification), - Error(JsonRpcError), - Nil, // used to respond to notifications -} - -#[derive(Debug, Serialize, Deserialize)] -struct JsonRpcRaw { - jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - method: Option, - #[serde(skip_serializing_if = "Option::is_none")] - params: Option, - #[serde(skip_serializing_if = "Option::is_none")] - result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - error: Option, -} - -impl TryFrom for JsonRpcMessage { - type Error = String; - - fn try_from(raw: JsonRpcRaw) -> Result>::Error> { - // If it has an error field, it's an error response - if raw.error.is_some() { - return Ok(JsonRpcMessage::Error(JsonRpcError { - jsonrpc: raw.jsonrpc, - id: raw.id, - error: raw.error.unwrap(), - })); - } - - // If it has a result field, it's a response - if raw.result.is_some() { - return Ok(JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: raw.jsonrpc, - id: raw.id, - result: raw.result, - error: None, - })); - } - - // If we have a method, it's either a notification or request - if let Some(method) = raw.method { - if raw.id.is_none() { - return Ok(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: raw.jsonrpc, - method, - params: raw.params, - })); - } - - return Ok(JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: raw.jsonrpc, - id: raw.id, - method, - params: raw.params, - })); - } - - // If we have no method and no result/error, it's a nil response - if raw.id.is_none() && raw.result.is_none() && raw.error.is_none() { - return Ok(JsonRpcMessage::Nil); - } - - // If we get here, something is wrong with the message - Err(format!( - "Invalid JSON-RPC message format: id={:?}, method={:?}, result={:?}, error={:?}", - raw.id, raw.method, raw.result, raw.error - )) - } -} // Standard JSON-RPC error codes pub const PARSE_ERROR: i32 = -32700; @@ -216,54 +99,3 @@ pub struct GetPromptResult { #[derive(Debug, Serialize, Deserialize)] pub struct EmptyResult {} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_notification_conversion() { - let raw = JsonRpcRaw { - jsonrpc: "2.0".to_string(), - id: None, - method: Some("notify".to_string()), - params: Some(json!({"key": "value"})), - result: None, - error: None, - }; - - let message = JsonRpcMessage::try_from(raw).unwrap(); - match message { - JsonRpcMessage::Notification(n) => { - assert_eq!(n.jsonrpc, "2.0"); - assert_eq!(n.method, "notify"); - assert_eq!(n.params.unwrap(), json!({"key": "value"})); - } - _ => panic!("Expected Notification"), - } - } - - #[test] - fn test_request_conversion() { - let raw = JsonRpcRaw { - jsonrpc: "2.0".to_string(), - id: Some(1), - method: Some("request".to_string()), - params: Some(json!({"key": "value"})), - result: None, - error: None, - }; - - let message = JsonRpcMessage::try_from(raw).unwrap(); - match message { - JsonRpcMessage::Request(r) => { - assert_eq!(r.jsonrpc, "2.0"); - assert_eq!(r.id, Some(1)); - assert_eq!(r.method, "request"); - assert_eq!(r.params.unwrap(), json!({"key": "value"})); - } - _ => panic!("Expected Request"), - } - } -} diff --git a/ui/desktop/src/bin/npx b/ui/desktop/src/bin/npx index 5c1d1bef0b65..92c361e55c65 100755 --- a/ui/desktop/src/bin/npx +++ b/ui/desktop/src/bin/npx @@ -12,7 +12,7 @@ LOG_FILE="/tmp/mcp.log" # Function for logging log() { local MESSAGE="$1" - echo "$(date +'%Y-%m-%d %H:%M:%S') - $MESSAGE" | tee -a "$LOG_FILE" + echo "$(date +'%Y-%m-%d %H:%M:%S') - $MESSAGE" | tee -a "$LOG_FILE" >&2 } # Trap errors and log them before exiting diff --git a/ui/desktop/src/bin/uvx b/ui/desktop/src/bin/uvx index 8a1eec121345..b0f1bdbed643 100755 --- a/ui/desktop/src/bin/uvx +++ b/ui/desktop/src/bin/uvx @@ -12,7 +12,7 @@ LOG_FILE="/tmp/mcp.log" # Function for logging log() { local MESSAGE="$1" - echo "$(date +'%Y-%m-%d %H:%M:%S') - $MESSAGE" | tee -a "$LOG_FILE" + echo "$(date +'%Y-%m-%d %H:%M:%S') - $MESSAGE" | tee -a "$LOG_FILE" >&2 } # Trap errors and log them before exiting