From 9e2cadc3723476985f1500570d6659b8859bb32d Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Thu, 31 Jul 2025 15:24:00 -0400 Subject: [PATCH 1/7] Add cancellation to mcp methods --- crates/mcp-client/src/client.rs | 167 +++++++++++++++++++++----------- 1 file changed, 109 insertions(+), 58 deletions(-) diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index e1253855162c..816562ff6a74 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -8,7 +8,9 @@ use rmcp::{ ProgressNotificationMethod, ProtocolVersion, ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult, ServerNotification, ServerResult, }, - service::{ClientInitializeError, PeerRequestOptions, RunningService}, + service::{ + ClientInitializeError, PeerRequestOptions, RequestHandle, RunningService, ServiceRole, + }, transport::IntoTransport, ClientHandler, RoleClient, ServiceError, ServiceExt, }; @@ -18,6 +20,7 @@ use tokio::sync::{ mpsc::{self, Sender}, Mutex, }; +use tokio_util::sync::CancellationToken; pub type BoxError = Box; @@ -28,17 +31,40 @@ pub trait McpClientTrait: Send + Sync { async fn list_resources( &self, next_cursor: Option, + cancel_token: CancellationToken, ) -> Result; - async fn read_resource(&self, uri: &str) -> Result; + async fn read_resource( + &self, + uri: &str, + cancel_token: CancellationToken, + ) -> Result; - async fn list_tools(&self, next_cursor: Option) -> Result; + async fn list_tools( + &self, + next_cursor: Option, + cancel_token: CancellationToken, + ) -> Result; - async fn call_tool(&self, name: &str, arguments: Value) -> Result; + async fn call_tool( + &self, + name: &str, + arguments: Value, + cancel_token: CancellationToken, + ) -> Result; - async fn list_prompts(&self, next_cursor: Option) -> Result; + async fn list_prompts( + &self, + next_cursor: Option, + cancel_token: CancellationToken, + ) -> Result; - async fn get_prompt(&self, name: &str, arguments: Value) -> Result; + async fn get_prompt( + &self, + name: &str, + arguments: Value, + cancel_token: CancellationToken, + ) -> Result; async fn subscribe(&self) -> mpsc::Receiver; @@ -143,10 +169,33 @@ impl McpClient { }) } - fn get_request_options(&self) -> PeerRequestOptions { - PeerRequestOptions { - timeout: Some(self.timeout), - meta: None, + async fn send_request( + &self, + request: ClientRequest, + cancel_token: CancellationToken, + ) -> Result { + let handle = self + .client + .lock() + .await + .send_request_with_option( + request, + PeerRequestOptions { + timeout: Some(self.timeout), + meta: None, + }, + ) + .await?; + + let cancel_token = cancel_token.clone(); + + tokio::select! { + res = handle.await_response() => { + Ok(res?) + } + _ = cancel_token.cancelled() => { + Err(Error::Cancelled{reason: None}) + } } } } @@ -157,34 +206,35 @@ impl McpClientTrait for McpClient { self.server_info.as_ref() } - async fn list_resources(&self, cursor: Option) -> Result { + async fn list_resources( + &self, + cursor: Option, + cancel_token: CancellationToken, + ) -> Result { let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::ListResourcesRequest(ListResourcesRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::ListResourcesResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn read_resource(&self, uri: &str) -> Result { + async fn read_resource( + &self, + uri: &str, + cancel_token: CancellationToken, + ) -> Result { let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::ReadResourceRequest(ReadResourceRequest { params: ReadResourceRequestParam { uri: uri.to_string(), @@ -192,49 +242,50 @@ impl McpClientTrait for McpClient { method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::ReadResourceResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn list_tools(&self, cursor: Option) -> Result { + async fn list_tools( + &self, + cursor: Option, + cancel_token: CancellationToken, + ) -> Result { let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::ListToolsRequest(ListToolsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::ListToolsResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn call_tool(&self, name: &str, arguments: Value) -> Result { + async fn call_tool( + &self, + name: &str, + arguments: Value, + cancel_token: CancellationToken, + ) -> Result { let arguments = match arguments { Value::Object(map) => Some(map), _ => None, }; let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::CallToolRequest(CallToolRequest { params: CallToolRequestParam { name: name.to_string().into(), @@ -243,49 +294,50 @@ impl McpClientTrait for McpClient { method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::CallToolResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn list_prompts(&self, cursor: Option) -> Result { + async fn list_prompts( + &self, + cursor: Option, + cancel_token: CancellationToken, + ) -> Result { let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::ListPromptsRequest(ListPromptsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::ListPromptsResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + async fn get_prompt( + &self, + name: &str, + arguments: Value, + cancel_token: CancellationToken, + ) -> Result { let arguments = match arguments { Value::Object(map) => Some(map), _ => None, }; let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::GetPromptRequest(GetPromptRequest { params: GetPromptRequestParam { name: name.to_string(), @@ -294,11 +346,10 @@ impl McpClientTrait for McpClient { method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::GetPromptResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), From 242f86253f47d421ff3127bb731e1a3203acbaad Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Thu, 31 Jul 2025 20:42:45 -0400 Subject: [PATCH 2/7] A WIP --- crates/goose-cli/src/session/mod.rs | 33 +++-- crates/goose-server/src/routes/reply.rs | 129 ++++++++----------- crates/goose/src/agents/agent.rs | 18 ++- crates/goose/src/agents/extension_manager.rs | 112 ++++++---------- crates/goose/src/agents/subagent.rs | 3 +- crates/mcp-client/src/client.rs | 2 + 6 files changed, 135 insertions(+), 162 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 106bc0a0744d..d30e5149e5b2 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -364,7 +364,12 @@ impl Session { } /// Process a single message and get the response - pub(crate) async fn process_message(&mut self, message: Message) -> Result<()> { + pub(crate) async fn process_message( + &mut self, + message: Message, + cancel_token: CancellationToken, + ) -> Result<()> { + let cancel_token = cancel_token.clone(); let message_text = message.as_concat_text(); self.push_message(message); @@ -405,7 +410,7 @@ impl Session { ); } - self.process_agent_response(false).await?; + self.process_agent_response(false, cancel_token).await?; Ok(()) } @@ -414,7 +419,8 @@ impl Session { // Process initial message if provided if let Some(prompt) = prompt { let msg = Message::user().with_text(&prompt); - self.process_message(msg).await?; + self.process_message(msg, CancellationToken::default()) + .await?; } // Initialize the completion cache @@ -514,7 +520,8 @@ impl Session { } output::show_thinking(); - self.process_agent_response(true).await?; + self.process_agent_response(true, CancellationToken::default()) + .await?; output::hide_thinking(); } RunMode::Plan => { @@ -814,7 +821,8 @@ impl Session { self.push_message(plan_message); // act on the plan output::show_thinking(); - self.process_agent_response(true).await?; + self.process_agent_response(true, CancellationToken::default()) + .await?; output::hide_thinking(); // Reset run & goose mode @@ -842,12 +850,15 @@ impl Session { /// Process a single message and exit pub async fn headless(&mut self, prompt: String) -> Result<()> { let message = Message::user().with_text(&prompt); - self.process_message(message).await + self.process_message(message, CancellationToken::default()) + .await } - async fn process_agent_response(&mut self, interactive: bool) -> Result<()> { - // Messages will be auto-compacted in agent.reply() if needed - let cancel_token = CancellationToken::new(); + async fn process_agent_response( + &mut self, + interactive: bool, + cancel_token: CancellationToken, + ) -> Result<()> { let cancel_token_clone = cancel_token.clone(); let session_config = self.session_file.as_ref().map(|s| { @@ -1191,6 +1202,7 @@ impl Session { } } _ = tokio::signal::ctrl_c() => { + eprintln!("caught ctrl-c"); cancel_token_clone.cancel(); drop(stream); if let Err(e) = self.handle_interrupted_messages(true).await { @@ -1511,7 +1523,8 @@ impl Session { if valid { output::show_thinking(); - self.process_agent_response(true).await?; + self.process_agent_response(true, CancellationToken::default()) + .await?; output::hide_thinking(); } } diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 05f5eaa8af15..8b5ec6a189dc 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -104,14 +104,17 @@ enum MessageEvent { async fn stream_event( event: MessageEvent, tx: &mpsc::Sender, -) -> Result<(), mpsc::error::SendError> { + cancel_token: &CancellationToken, +) { let json = serde_json::to_string(&event).unwrap_or_else(|e| { format!( r#"{{"type":"Error","error":"Failed to serialize event: {}"}}"#, e ) }); - tx.send(format!("data: {}\n\n", json)).await + if let Err(_) = tx.send(format!("data: {}\n\n", json)).await { + cancel_token.cancel(); + } } async fn reply_handler( @@ -144,6 +147,7 @@ async fn reply_handler( error: "No agent configured".to_string(), }, &task_tx, + &cancel_token, ) .await; return; @@ -173,11 +177,12 @@ async fn reply_handler( Ok(stream) => stream, Err(e) => { tracing::error!("Failed to start reply stream: {:?}", e); - let _ = stream_event( + stream_event( MessageEvent::Error { error: e.to_string(), }, &task_tx, + &cancel_token, ) .await; return; @@ -194,6 +199,7 @@ async fn reply_handler( error: format!("Failed to get session path: {}", e), }, &task_tx, + &cancel_token, ) .await; return; @@ -203,79 +209,55 @@ async fn reply_handler( loop { tokio::select! { - _ = task_cancel.cancelled() => { - tracing::info!("Agent task cancelled"); + _ = task_cancel.cancelled() => { + tracing::info!("Agent task cancelled"); + break; + } + response = timeout(Duration::from_millis(500), stream.next()) => { + match response { + Ok(Some(Ok(AgentEvent::Message(message)))) => { + push_message(&mut all_messages, message.clone()); + stream_event(MessageEvent::Message { message }, &tx, &cancel_token).await; + } + Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => { + // Replace the message history with the compacted messages + all_messages = new_messages; + // Note: We don't send this as a stream event since it's an internal operation + // The client will see the compaction notification message that was sent before this event + } + Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { + stream_event(MessageEvent::ModelChange { model, mode }, &tx, &cancel_token).await; + } + Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => { + stream_event(MessageEvent::Notification{ + request_id: request_id.clone(), + message: n, + }, &tx, &cancel_token).await; + } + + Ok(Some(Err(e))) => { + tracing::error!("Error processing message: {}", e); + stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + &cancel_token, + ).await; + break; + } + Ok(None) => { + break; + } + Err(_) => { + if tx.is_closed() { break; } - response = timeout(Duration::from_millis(500), stream.next()) => { - match response { - Ok(Some(Ok(AgentEvent::Message(message)))) => { - push_message(&mut all_messages, message.clone()); - if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await { - tracing::error!("Error sending message through channel: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - break; - } - } - Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => { - // Replace the message history with the compacted messages - all_messages = new_messages; - // Note: We don't send this as a stream event since it's an internal operation - // The client will see the compaction notification message that was sent before this event - } - Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { - if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await { - tracing::error!("Error sending model change through channel: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - } - } - Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => { - if let Err(e) = stream_event(MessageEvent::Notification{ - request_id: request_id.clone(), - message: n, - }, &tx).await { - tracing::error!("Error sending message through channel: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - } - } - - Ok(Some(Err(e))) => { - tracing::error!("Error processing message: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - break; - } - Ok(None) => { - break; - } - Err(_) => { - if tx.is_closed() { - break; - } - continue; - } - } - } + continue; } + } + } + } } if all_messages.len() > saved_message_count { @@ -301,6 +283,7 @@ async fn reply_handler( reason: "stop".to_string(), }, &task_tx, + &cancel_token, ) .await; })); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index c106f43f6be1..ba9212a3b80c 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -439,13 +439,19 @@ impl Agent { // Check if the tool is read_resource and handle it separately ToolCallResult::from( extension_manager - .read_resource(tool_call.arguments.clone()) + .read_resource( + tool_call.arguments.clone(), + cancellation_token.unwrap_or_default(), + ) .await, ) } else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME { ToolCallResult::from( extension_manager - .list_resources(tool_call.arguments.clone()) + .list_resources( + tool_call.arguments.clone(), + cancellation_token.unwrap_or_default(), + ) .await, ) } else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME { @@ -469,7 +475,7 @@ impl Agent { } else { // Clone the result to ensure no references to extension_manager are returned let result = extension_manager - .dispatch_tool_call(tool_call.clone()) + .dispatch_tool_call(tool_call.clone(), cancellation_token.unwrap_or_default()) .await; result.unwrap_or_else(|e| { ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))) @@ -1161,7 +1167,7 @@ impl Agent { pub async fn list_extension_prompts(&self) -> HashMap> { let extension_manager = self.extension_manager.read().await; extension_manager - .list_prompts() + .list_prompts(CancellationToken::default()) .await .expect("Failed to list prompts") } @@ -1171,7 +1177,7 @@ impl Agent { // First find which extension has this prompt let prompts = extension_manager - .list_prompts() + .list_prompts(CancellationToken::default()) .await .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; @@ -1181,7 +1187,7 @@ impl Agent { .map(|(extension, _)| extension) { return extension_manager - .get_prompt(extension, name, arguments) + .get_prompt(extension, name, arguments, CancellationToken::default()) .await .map_err(|e| anyhow!("Failed to get prompt: {}", e)); } diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 72328d2c86ed..3247125206f9 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -20,6 +20,7 @@ use tokio::process::Command; use tokio::sync::Mutex; use tokio::task; use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; use tracing::{error, warn}; use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo}; @@ -31,11 +32,6 @@ use mcp_client::client::{McpClient, McpClientTrait}; use rmcp::model::{Content, GetPromptResult, Prompt, Resource, ResourceContents, Tool}; use serde_json::Value; -// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp -// This is to ensure that the resource is considered less important than resources with a more recent timestamp -static DEFAULT_TIMESTAMP: LazyLock> = - LazyLock::new(|| Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap()); - type McpClientBox = Arc>>; /// Manages Goose extensions / MCP clients and their interactions @@ -456,7 +452,9 @@ impl ExtensionManager { task::spawn(async move { let mut tools = Vec::new(); let client_guard = client.lock().await; - let mut client_tools = client_guard.list_tools(None).await?; + let mut client_tools = client_guard + .list_tools(None, CancellationToken::default()) + .await?; loop { for tool in client_tools.tools { @@ -473,7 +471,9 @@ impl ExtensionManager { break; } - client_tools = client_guard.list_tools(client_tools.next_cursor).await?; + client_tools = client_guard + .list_tools(client_tools.next_cursor, CancellationToken::default()) + .await?; } Ok::, ExtensionError>(tools) @@ -496,43 +496,6 @@ impl ExtensionManager { Ok(tools) } - /// Get client resources and their contents - pub async fn get_resources(&self) -> ExtensionResult> { - let mut result: Vec = Vec::new(); - - for (name, client) in &self.clients { - let client_guard = client.lock().await; - let resources = client_guard.list_resources(None).await?; - - for resource in resources.resources { - // Skip reading the resource if it's not marked active - // This avoids blowing up the context with inactive resources - if !resource_is_active(&resource) { - continue; - } - - if let Ok(contents) = client_guard.read_resource(&resource.uri).await { - for content in contents.contents { - let (uri, content_str) = match content { - ResourceContents::TextResourceContents { uri, text, .. } => (uri, text), - ResourceContents::BlobResourceContents { uri, blob, .. } => (uri, blob), - }; - - result.push(ResourceItem::new( - name.clone(), - uri, - resource.name.clone(), - content_str, - resource.timestamp().unwrap_or(*DEFAULT_TIMESTAMP), - resource.priority().unwrap_or(0.0), - )); - } - } - } - } - Ok(result) - } - /// Get the extension prompt including client instructions pub async fn get_planning_prompt(&self, tools_info: Vec) -> String { let mut context: HashMap<&str, Value> = HashMap::new(); @@ -550,7 +513,7 @@ impl ExtensionManager { } // Function that gets executed for read_resource tool - pub async fn read_resource(&self, params: Value) -> Result, ToolError> { + pub async fn read_resource(&self, params: Value, cancellation_token: CancellationToken) -> Result, ToolError> { let uri = params .get("uri") .and_then(|v| v.as_str()) @@ -561,7 +524,7 @@ impl ExtensionManager { // If extension name is provided, we can just look it up if extension_name.is_some() { let result = self - .read_resource_from_extension(uri, extension_name.unwrap()) + .read_resource_from_extension(uri, extension_name.unwrap(), cancellation_token.clone()) .await?; return Ok(result); } @@ -571,7 +534,7 @@ impl ExtensionManager { // TODO: do we want to find if a provided uri is in multiple extensions? // currently it will return the first match and skip any others for extension_name in self.resource_capable_extensions.iter() { - let result = self.read_resource_from_extension(uri, extension_name).await; + let result = self.read_resource_from_extension(uri, extension_name, cancellation_token.clone()).await; match result { Ok(result) => return Ok(result), Err(_) => continue, @@ -597,6 +560,7 @@ impl ExtensionManager { &self, uri: &str, extension_name: &str, + cancellation_token: CancellationToken, ) -> Result, ToolError> { let available_extensions = self .clients @@ -615,7 +579,7 @@ impl ExtensionManager { .ok_or(ToolError::InvalidParameters(error_msg))?; let client_guard = client.lock().await; - let read_result = client_guard.read_resource(uri).await.map_err(|_| { + let read_result = client_guard.read_resource(uri, cancellation_token).await.map_err(|_| { ToolError::ExecutionError(format!("Could not read resource with uri: {}", uri)) })?; @@ -634,6 +598,7 @@ impl ExtensionManager { async fn list_resources_from_extension( &self, extension_name: &str, + cancellation_token: CancellationToken, ) -> Result, ToolError> { let client = self.clients.get(extension_name).ok_or_else(|| { ToolError::InvalidParameters(format!("Extension {} is not valid", extension_name)) @@ -641,7 +606,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .list_resources(None) + .list_resources(None, cancellation_token) .await .map_err(|e| { ToolError::ExecutionError(format!( @@ -661,13 +626,13 @@ impl ExtensionManager { }) } - pub async fn list_resources(&self, params: Value) -> Result, ToolError> { + pub async fn list_resources(&self, params: Value, cancellation_token: CancellationToken) -> Result, ToolError> { let extension = params.get("extension").and_then(|v| v.as_str()); match extension { Some(extension_name) => { // Handle single extension case - self.list_resources_from_extension(extension_name).await + self.list_resources_from_extension(extension_name, cancellation_token).await } None => { // Handle all extensions case using FuturesUnordered @@ -675,8 +640,9 @@ impl ExtensionManager { // Create futures for each resource_capable_extension for extension_name in &self.resource_capable_extensions { + let token = cancellation_token.clone(); futures.push(async move { - self.list_resources_from_extension(extension_name).await + self.list_resources_from_extension(extension_name, token).await }); } @@ -711,7 +677,7 @@ impl ExtensionManager { } } - pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> Result { + pub async fn dispatch_tool_call(&self, tool_call: ToolCall, cancellation_token: CancellationToken) -> Result { // Dispatch tool call based on the prefix naming convention let (client_name, client) = self .get_client_for_tool(&tool_call.name) @@ -732,7 +698,7 @@ impl ExtensionManager { let fut = async move { let client_guard = client.lock().await; client_guard - .call_tool(&tool_name, arguments) + .call_tool(&tool_name, arguments, cancellation_token) .await .map(|call| call.content) .map_err(|e| ToolError::ExecutionError(e.to_string())) @@ -747,6 +713,7 @@ impl ExtensionManager { pub async fn list_prompts_from_extension( &self, extension_name: &str, + cancellation_token: CancellationToken, ) -> Result, ToolError> { let client = self.clients.get(extension_name).ok_or_else(|| { ToolError::InvalidParameters(format!("Extension {} is not valid", extension_name)) @@ -754,7 +721,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .list_prompts(None) + .list_prompts(None, cancellation_token) .await .map_err(|e| { ToolError::ExecutionError(format!( @@ -765,14 +732,15 @@ impl ExtensionManager { .map(|lp| lp.prompts) } - pub async fn list_prompts(&self) -> Result>, ToolError> { + pub async fn list_prompts(&self, cancellation_token: CancellationToken) -> Result>, ToolError> { let mut futures = FuturesUnordered::new(); for extension_name in self.clients.keys() { + let token = cancellation_token.clone(); futures.push(async move { ( extension_name, - self.list_prompts_from_extension(extension_name).await, + self.list_prompts_from_extension(extension_name, token).await, ) }); } @@ -812,6 +780,7 @@ impl ExtensionManager { extension_name: &str, name: &str, arguments: Value, + cancellation_token: CancellationToken, ) -> Result { let client = self .clients @@ -820,7 +789,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .get_prompt(name, arguments) + .get_prompt(name, arguments, cancellation_token) .await .map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e)) } @@ -899,10 +868,6 @@ impl ExtensionManager { } } -fn resource_is_active(resource: &Resource) -> bool { - resource.priority().is_some_and(|p| (p - 1.0).abs() < 1e-6) -} - #[cfg(test)] mod tests { use super::*; @@ -930,19 +895,20 @@ mod tests { async fn list_resources( &self, _next_cursor: Option, + _cancellation_token: CancellationToken, ) -> Result { Err(Error::TransportClosed) } - async fn read_resource(&self, _uri: &str) -> Result { + async fn read_resource(&self, _uri: &str, _cancellation_token: CancellationToken) -> Result { Err(Error::TransportClosed) } - async fn list_tools(&self, _next_cursor: Option) -> Result { + async fn list_tools(&self, _next_cursor: Option, _cancellation_token: CancellationToken) -> Result { Err(Error::TransportClosed) } - async fn call_tool(&self, name: &str, _arguments: Value) -> Result { + async fn call_tool(&self, name: &str, _arguments: Value, _cancellation_token: CancellationToken) -> Result { match name { "tool" | "test__tool" => Ok(CallToolResult { content: vec![], @@ -955,6 +921,7 @@ mod tests { async fn list_prompts( &self, _next_cursor: Option, + _cancellation_token: CancellationToken, ) -> Result { Err(Error::TransportClosed) } @@ -963,6 +930,7 @@ mod tests { &self, _name: &str, _arguments: Value, + _cancellation_token: CancellationToken, ) -> Result { Err(Error::TransportClosed) } @@ -1046,7 +1014,7 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; assert!(result.is_ok()); let tool_call = ToolCall { @@ -1054,7 +1022,7 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; assert!(result.is_ok()); // verify a multiple underscores dispatch @@ -1063,7 +1031,7 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; assert!(result.is_ok()); // Test unicode in tool name, "client 🚀" should become "client_" @@ -1072,7 +1040,7 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; assert!(result.is_ok()); let tool_call = ToolCall { @@ -1080,7 +1048,7 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; assert!(result.is_ok()); // this should error out, specifically for an ToolError::ExecutionError @@ -1090,7 +1058,7 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(invalid_tool_call) + .dispatch_tool_call(invalid_tool_call, CancellationToken::default()) .await .unwrap() .result @@ -1108,7 +1076,7 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(invalid_tool_call) + .dispatch_tool_call(invalid_tool_call, CancellationToken::default()) .await; if let Err(err) = result { let tool_err = err.downcast_ref::().expect("Expected ToolError"); diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 939916e0afeb..701d4aac60c0 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -15,6 +15,7 @@ use serde::{Deserialize, Serialize}; // use serde_json::{self}; use std::{collections::HashMap, sync::Arc}; use tokio::sync::{Mutex, RwLock}; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, instrument}; /// Status of a subagent @@ -197,7 +198,7 @@ impl SubAgent { .extension_manager .read() .await - .dispatch_tool_call(tool_call.clone()) + .dispatch_tool_call(tool_call.clone(), CancellationToken::default()) .await { Ok(result) => result.result.await, diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 816562ff6a74..2248f579a20d 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -174,6 +174,7 @@ impl McpClient { request: ClientRequest, cancel_token: CancellationToken, ) -> Result { + eprintln!("request: {:?}", request); let handle = self .client .lock() @@ -187,6 +188,7 @@ impl McpClient { ) .await?; + eprintln!("request handle: {:?}", handle); let cancel_token = cancel_token.clone(); tokio::select! { From 407539661a2060f157d52585d2c74a2664eb75a9 Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Fri, 1 Aug 2025 09:06:08 -0400 Subject: [PATCH 3/7] detect disconnects --- crates/goose-server/src/routes/reply.rs | 6 ++++++ crates/goose/src/agents/extension_manager.rs | 5 ++--- crates/mcp-client/src/client.rs | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 8b5ec6a189dc..ffc3755fdb1f 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -99,6 +99,7 @@ enum MessageEvent { request_id: String, message: ServerNotification, }, + Ping, } async fn stream_event( @@ -113,6 +114,7 @@ async fn stream_event( ) }); if let Err(_) = tx.send(format!("data: {}\n\n", json)).await { + tracing::info!("client hung up"); cancel_token.cancel(); } } @@ -207,12 +209,16 @@ async fn reply_handler( }; let saved_message_count = all_messages.len(); + let mut heartbeat_interval = tokio::time::interval(Duration::from_millis(500)); loop { tokio::select! { _ = task_cancel.cancelled() => { tracing::info!("Agent task cancelled"); break; } + _ = heartbeat_interval.tick() => { + stream_event(MessageEvent::Ping, &tx, &cancel_token).await; + } response = timeout(Duration::from_millis(500), stream.next()) => { match response { Ok(Some(Ok(AgentEvent::Message(message)))) => { diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 3247125206f9..a3c6502e5e28 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -1,6 +1,6 @@ use anyhow::Result; use axum::http::{HeaderMap, HeaderName}; -use chrono::{DateTime, TimeZone, Utc}; +use chrono::{DateTime, Utc}; use futures::stream::{FuturesUnordered, StreamExt}; use futures::{future, FutureExt}; use mcp_core::{ToolCall, ToolError}; @@ -12,7 +12,6 @@ use rmcp::transport::{ use std::collections::{HashMap, HashSet}; use std::process::Stdio; use std::sync::Arc; -use std::sync::LazyLock; use std::time::Duration; use tempfile::tempdir; use tokio::io::AsyncReadExt; @@ -29,7 +28,7 @@ use crate::agents::extension::{Envs, ProcessExit}; use crate::config::{Config, ExtensionConfigManager}; use crate::prompt_template; use mcp_client::client::{McpClient, McpClientTrait}; -use rmcp::model::{Content, GetPromptResult, Prompt, Resource, ResourceContents, Tool}; +use rmcp::model::{Content, GetPromptResult, Prompt, ResourceContents, Tool}; use serde_json::Value; type McpClientBox = Arc>>; diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 2248f579a20d..b6c7f74ae406 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -9,7 +9,7 @@ use rmcp::{ ReadResourceResult, ServerNotification, ServerResult, }, service::{ - ClientInitializeError, PeerRequestOptions, RequestHandle, RunningService, ServiceRole, + ClientInitializeError, PeerRequestOptions, RunningService }, transport::IntoTransport, ClientHandler, RoleClient, ServiceError, ServiceExt, From 08dde09352599faf9832b24e0309a59faaa62b7b Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Fri, 1 Aug 2025 09:14:14 -0400 Subject: [PATCH 4/7] fmt --- crates/goose/src/agents/extension_manager.rs | 90 +++++++++++++++----- crates/mcp-client/src/client.rs | 4 +- 2 files changed, 71 insertions(+), 23 deletions(-) diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index a3c6502e5e28..b4b329a4751b 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -512,7 +512,11 @@ impl ExtensionManager { } // Function that gets executed for read_resource tool - pub async fn read_resource(&self, params: Value, cancellation_token: CancellationToken) -> Result, ToolError> { + pub async fn read_resource( + &self, + params: Value, + cancellation_token: CancellationToken, + ) -> Result, ToolError> { let uri = params .get("uri") .and_then(|v| v.as_str()) @@ -523,7 +527,11 @@ impl ExtensionManager { // If extension name is provided, we can just look it up if extension_name.is_some() { let result = self - .read_resource_from_extension(uri, extension_name.unwrap(), cancellation_token.clone()) + .read_resource_from_extension( + uri, + extension_name.unwrap(), + cancellation_token.clone(), + ) .await?; return Ok(result); } @@ -533,7 +541,9 @@ impl ExtensionManager { // TODO: do we want to find if a provided uri is in multiple extensions? // currently it will return the first match and skip any others for extension_name in self.resource_capable_extensions.iter() { - let result = self.read_resource_from_extension(uri, extension_name, cancellation_token.clone()).await; + let result = self + .read_resource_from_extension(uri, extension_name, cancellation_token.clone()) + .await; match result { Ok(result) => return Ok(result), Err(_) => continue, @@ -578,9 +588,12 @@ impl ExtensionManager { .ok_or(ToolError::InvalidParameters(error_msg))?; let client_guard = client.lock().await; - let read_result = client_guard.read_resource(uri, cancellation_token).await.map_err(|_| { - ToolError::ExecutionError(format!("Could not read resource with uri: {}", uri)) - })?; + let read_result = client_guard + .read_resource(uri, cancellation_token) + .await + .map_err(|_| { + ToolError::ExecutionError(format!("Could not read resource with uri: {}", uri)) + })?; let mut result = Vec::new(); for content in read_result.contents { @@ -625,13 +638,18 @@ impl ExtensionManager { }) } - pub async fn list_resources(&self, params: Value, cancellation_token: CancellationToken) -> Result, ToolError> { + pub async fn list_resources( + &self, + params: Value, + cancellation_token: CancellationToken, + ) -> Result, ToolError> { let extension = params.get("extension").and_then(|v| v.as_str()); match extension { Some(extension_name) => { // Handle single extension case - self.list_resources_from_extension(extension_name, cancellation_token).await + self.list_resources_from_extension(extension_name, cancellation_token) + .await } None => { // Handle all extensions case using FuturesUnordered @@ -641,7 +659,8 @@ impl ExtensionManager { for extension_name in &self.resource_capable_extensions { let token = cancellation_token.clone(); futures.push(async move { - self.list_resources_from_extension(extension_name, token).await + self.list_resources_from_extension(extension_name, token) + .await }); } @@ -676,7 +695,11 @@ impl ExtensionManager { } } - pub async fn dispatch_tool_call(&self, tool_call: ToolCall, cancellation_token: CancellationToken) -> Result { + pub async fn dispatch_tool_call( + &self, + tool_call: ToolCall, + cancellation_token: CancellationToken, + ) -> Result { // Dispatch tool call based on the prefix naming convention let (client_name, client) = self .get_client_for_tool(&tool_call.name) @@ -731,7 +754,10 @@ impl ExtensionManager { .map(|lp| lp.prompts) } - pub async fn list_prompts(&self, cancellation_token: CancellationToken) -> Result>, ToolError> { + pub async fn list_prompts( + &self, + cancellation_token: CancellationToken, + ) -> Result>, ToolError> { let mut futures = FuturesUnordered::new(); for extension_name in self.clients.keys() { @@ -739,7 +765,8 @@ impl ExtensionManager { futures.push(async move { ( extension_name, - self.list_prompts_from_extension(extension_name, token).await, + self.list_prompts_from_extension(extension_name, token) + .await, ) }); } @@ -899,15 +926,28 @@ mod tests { Err(Error::TransportClosed) } - async fn read_resource(&self, _uri: &str, _cancellation_token: CancellationToken) -> Result { + async fn read_resource( + &self, + _uri: &str, + _cancellation_token: CancellationToken, + ) -> Result { Err(Error::TransportClosed) } - async fn list_tools(&self, _next_cursor: Option, _cancellation_token: CancellationToken) -> Result { + async fn list_tools( + &self, + _next_cursor: Option, + _cancellation_token: CancellationToken, + ) -> Result { Err(Error::TransportClosed) } - async fn call_tool(&self, name: &str, _arguments: Value, _cancellation_token: CancellationToken) -> Result { + async fn call_tool( + &self, + name: &str, + _arguments: Value, + _cancellation_token: CancellationToken, + ) -> Result { match name { "tool" | "test__tool" => Ok(CallToolResult { content: vec![], @@ -1013,7 +1053,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); let tool_call = ToolCall { @@ -1021,7 +1063,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); // verify a multiple underscores dispatch @@ -1030,7 +1074,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); // Test unicode in tool name, "client 🚀" should become "client_" @@ -1039,7 +1085,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); let tool_call = ToolCall { @@ -1047,7 +1095,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call, CancellationToken::default()).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); // this should error out, specifically for an ToolError::ExecutionError diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index b6c7f74ae406..1f0f796c7a0e 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -8,9 +8,7 @@ use rmcp::{ ProgressNotificationMethod, ProtocolVersion, ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult, ServerNotification, ServerResult, }, - service::{ - ClientInitializeError, PeerRequestOptions, RunningService - }, + service::{ClientInitializeError, PeerRequestOptions, RunningService}, transport::IntoTransport, ClientHandler, RoleClient, ServiceError, ServiceExt, }; From 3dc08467d8c75422bb8cae9987bba63c9a6749bc Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Fri, 1 Aug 2025 09:35:05 -0400 Subject: [PATCH 5/7] test fixes --- .../src/scenario_tests/mock_client.rs | 34 ++++++++++++++++--- .../src/scenario_tests/scenario_runner.rs | 6 +++- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/crates/goose-cli/src/scenario_tests/mock_client.rs b/crates/goose-cli/src/scenario_tests/mock_client.rs index 006795ce9d94..38e74ca44b01 100644 --- a/crates/goose-cli/src/scenario_tests/mock_client.rs +++ b/crates/goose-cli/src/scenario_tests/mock_client.rs @@ -13,6 +13,7 @@ use rmcp::{ use serde_json::Value; use std::collections::HashMap; use tokio::sync::mpsc::{self, Receiver}; +use tokio_util::sync::CancellationToken; pub struct MockClient { tools: HashMap, @@ -43,6 +44,7 @@ impl McpClientTrait for MockClient { async fn list_resources( &self, _next_cursor: Option, + _cancel_token: CancellationToken, ) -> Result { Ok(ListResourcesResult { resources: vec![], @@ -54,11 +56,19 @@ impl McpClientTrait for MockClient { todo!() } - async fn read_resource(&self, _uri: &str) -> Result { + async fn read_resource( + &self, + _uri: &str, + _cancel_token: CancellationToken, + ) -> Result { Err(Error::UnexpectedResponse) } - async fn list_tools(&self, _: Option) -> Result { + async fn list_tools( + &self, + _: Option, + _cancel_token: CancellationToken, + ) -> Result { let rmcp_tools: Vec = self .tools .values() @@ -77,7 +87,12 @@ impl McpClientTrait for MockClient { }) } - async fn call_tool(&self, name: &str, arguments: Value) -> Result { + async fn call_tool( + &self, + name: &str, + arguments: Value, + _cancel_token: CancellationToken, + ) -> Result { if let Some(handler) = self.handlers.get(name) { match handler(&arguments) { Ok(content) => Ok(CallToolResult { @@ -91,14 +106,23 @@ impl McpClientTrait for MockClient { } } - async fn list_prompts(&self, _next_cursor: Option) -> Result { + async fn list_prompts( + &self, + _next_cursor: Option, + _cancel_token: CancellationToken, + ) -> Result { Ok(ListPromptsResult { prompts: vec![], next_cursor: None, }) } - async fn get_prompt(&self, _name: &str, _arguments: Value) -> Result { + async fn get_prompt( + &self, + _name: &str, + _arguments: Value, + _cancel_token: CancellationToken, + ) -> Result { Err(Error::UnexpectedResponse) } diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index 687d6e413a80..5ab1cd97a4df 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -12,6 +12,7 @@ use goose::providers::{create, testprovider::TestProvider}; use std::collections::{HashMap, HashSet}; use std::path::Path; use std::sync::Arc; +use tokio_util::sync::CancellationToken; pub const SCENARIO_TESTS_DIR: &str = "src/scenario_tests"; @@ -205,7 +206,10 @@ where let mut error = None; for message in &messages { - if let Err(e) = session.process_message(message.clone()).await { + if let Err(e) = session + .process_message(message.clone(), CancellationToken::default()) + .await + { error = Some(e.to_string()); break; } From d330d80125f7cf15b02f56cd16fe91ee60bb4e96 Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Fri, 1 Aug 2025 13:02:55 -0400 Subject: [PATCH 6/7] rm logs --- crates/goose-cli/src/session/mod.rs | 1 - crates/mcp-client/src/client.rs | 3 --- 2 files changed, 4 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index d30e5149e5b2..0cdda04bcf31 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -1202,7 +1202,6 @@ impl Session { } } _ = tokio::signal::ctrl_c() => { - eprintln!("caught ctrl-c"); cancel_token_clone.cancel(); drop(stream); if let Err(e) = self.handle_interrupted_messages(true).await { diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 1f0f796c7a0e..0da233abf95b 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -172,7 +172,6 @@ impl McpClient { request: ClientRequest, cancel_token: CancellationToken, ) -> Result { - eprintln!("request: {:?}", request); let handle = self .client .lock() @@ -186,9 +185,7 @@ impl McpClient { ) .await?; - eprintln!("request handle: {:?}", handle); let cancel_token = cancel_token.clone(); - tokio::select! { res = handle.await_response() => { Ok(res?) From a7b8acfd661587e5dd445b12122f142c8e3c5044 Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Fri, 1 Aug 2025 22:20:08 -0400 Subject: [PATCH 7/7] lint fix --- crates/goose-server/src/routes/reply.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index ffc3755fdb1f..6a3b91ca8c42 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -113,7 +113,7 @@ async fn stream_event( e ) }); - if let Err(_) = tx.send(format!("data: {}\n\n", json)).await { + if tx.send(format!("data: {}\n\n", json)).await.is_err() { tracing::info!("client hung up"); cancel_token.cancel(); }