diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index 45970afc792c..431626a4303c 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -142,7 +142,6 @@ where F: Fn(&ScenarioResult) -> Result<()>, { use goose::config::ExtensionConfig; - use tokio::sync::Mutex; goose::agents::moim::SKIP.with(|f| f.set(true)); @@ -229,7 +228,7 @@ where bundled: None, available_tools: vec![], }, - Arc::new(Mutex::new(Box::new(mock_client))), + Arc::new(mock_client), None, None, ) diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index de31d0131ffb..5cd53bb36e15 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -49,7 +49,7 @@ use rmcp::transport::auth::AuthClient; use schemars::_private::NoSerialize; use serde_json::Value; -type McpClientBox = Arc>>; +type McpClientBox = Arc; struct Extension { pub config: ExtensionConfig, @@ -686,7 +686,7 @@ impl ExtensionManager { let mut extensions = self.extensions.lock().await; extensions.insert( sanitized_name, - Extension::new(config, Arc::new(Mutex::new(client)), server_info, temp_dir), + Extension::new(config, Arc::from(client), server_info, temp_dir), ); drop(extensions); self.invalidate_tools_cache_and_bump_version().await; @@ -857,8 +857,7 @@ impl ExtensionManager { let ext_name = name.clone(); async move { let mut tools = Vec::new(); - let client_guard = client.lock().await; - let mut client_tools = match client_guard + let mut client_tools = match client .list_tools(session_id, None, cancel_token.clone()) .await { @@ -908,7 +907,7 @@ impl ExtensionManager { break; } - client_tools = match client_guard + client_tools = match client .list_tools(session_id, client_tools.next_cursor, cancel_token.clone()) .await { @@ -1060,8 +1059,7 @@ impl ExtensionManager { .await .ok_or(ErrorData::new(ErrorCode::INVALID_PARAMS, error_msg, None))?; - let client_guard = client.lock().await; - client_guard + client .read_resource(session_id, uri, cancellation_token) .await .map_err(|_| { @@ -1088,9 +1086,7 @@ impl ExtensionManager { }; for (extension_name, client) in extensions_to_check { - let client_guard = client.lock().await; - - match client_guard + match client .list_resources(session_id, None, CancellationToken::default()) .await { @@ -1127,8 +1123,7 @@ impl ExtensionManager { ) })?; - let client_guard = client.lock().await; - client_guard + client .list_resources(session_id, None, cancellation_token) .await .map_err(|e| { @@ -1302,7 +1297,7 @@ impl ExtensionManager { let arguments = tool_call.arguments.clone(); let client = resolved.client.clone(); - let notifications_receiver = client.lock().await.subscribe().await; + let notifications_receiver = client.subscribe().await; let session_id = session_id.to_string(); let actual_tool_name = resolved.actual_tool_name; let working_dir_str = working_dir.map(|p| p.to_string_lossy().to_string()); @@ -1314,8 +1309,7 @@ impl ExtensionManager { session_id, working_dir_str ); - let client_guard = client.lock().await; - client_guard + client .call_tool( &session_id, &actual_tool_name, @@ -1355,8 +1349,7 @@ impl ExtensionManager { ) })?; - let client_guard = client.lock().await; - client_guard + client .list_prompts(session_id, None, cancellation_token) .await .map_err(|e| { @@ -1430,8 +1423,7 @@ impl ExtensionManager { .await .ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?; - let client_guard = client.lock().await; - client_guard + client .get_prompt(session_id, name, arguments, cancellation_token) .await .map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e)) @@ -1535,8 +1527,7 @@ impl ExtensionManager { }; for (name, client) in platform_clients { - let client_guard = client.lock().await; - if let Some(moim_content) = client_guard.get_moim(session_id).await { + if let Some(moim_content) = client.get_moim(session_id).await { tracing::debug!("MOIM content from {}: {} chars", name, moim_content.len()); content.push('\n'); content.push_str(&moim_content); @@ -1702,24 +1693,15 @@ mod tests { // Add some mock clients using the helper method extension_manager - .add_mock_extension( - "test_client".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("test_client".to_string(), Arc::new(MockClient {})) .await; extension_manager - .add_mock_extension( - "__cli__ent__".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("__cli__ent__".to_string(), Arc::new(MockClient {})) .await; extension_manager - .add_mock_extension( - "client 🚀".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("client 🚀".to_string(), Arc::new(MockClient {})) .await; let tool_call = CallToolRequestParams { @@ -1847,7 +1829,7 @@ mod tests { extension_manager .add_mock_extension_with_tools( "test_extension".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), + Arc::new(MockClient {}), available_tools, ) .await; @@ -1877,7 +1859,7 @@ mod tests { extension_manager .add_mock_extension_with_tools( "test_extension".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), + Arc::new(MockClient {}), vec![], // Empty available_tools means all tools are available by default ) .await; @@ -1909,7 +1891,7 @@ mod tests { extension_manager .add_mock_extension_with_tools( "test_extension".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), + Arc::new(MockClient {}), available_tools, ) .await; @@ -2013,10 +1995,7 @@ mod tests { ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); extension_manager - .add_mock_extension( - "ext_a".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("ext_a".to_string(), Arc::new(MockClient {})) .await; let tools_after_first = extension_manager @@ -2031,10 +2010,7 @@ mod tests { assert!(!tool_names.iter().any(|n| n.starts_with("ext_b__"))); extension_manager - .add_mock_extension( - "ext_b".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("ext_b".to_string(), Arc::new(MockClient {})) .await; let tools_after_second = extension_manager @@ -2056,16 +2032,10 @@ mod tests { ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); extension_manager - .add_mock_extension( - "ext_a".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("ext_a".to_string(), Arc::new(MockClient {})) .await; extension_manager - .add_mock_extension( - "ext_b".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("ext_b".to_string(), Arc::new(MockClient {})) .await; let tools_before = extension_manager @@ -2094,16 +2064,10 @@ mod tests { ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); extension_manager - .add_mock_extension( - "ext_a".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("ext_a".to_string(), Arc::new(MockClient {})) .await; extension_manager - .add_mock_extension( - "ext_b".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("ext_b".to_string(), Arc::new(MockClient {})) .await; let tools = extension_manager @@ -2123,16 +2087,10 @@ mod tests { ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); extension_manager - .add_mock_extension( - "ext_a".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("ext_a".to_string(), Arc::new(MockClient {})) .await; extension_manager - .add_mock_extension( - "ext_b".to_string(), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ) + .add_mock_extension("ext_b".to_string(), Arc::new(MockClient {})) .await; let tools = extension_manager diff --git a/crates/goose/src/agents/mcp_client.rs b/crates/goose/src/agents/mcp_client.rs index 4c43c531263e..89c963da88c3 100644 --- a/crates/goose/src/agents/mcp_client.rs +++ b/crates/goose/src/agents/mcp_client.rs @@ -106,8 +106,10 @@ pub trait McpClientTrait: Send + Sync { pub struct GooseClient { notification_handlers: Arc>>>, provider: SharedProvider, - // Single-slot because calls are serialized per MCP client. - current_session_id: Arc>>, + /// Fallback session_id for server-initiated callbacks (e.g. sampling/createMessage) + /// that don't include the session_id in their MCP extensions metadata. + /// Set once on first request; never cleared (the id is invariant per McpClient). + session_id: Mutex>, } impl GooseClient { @@ -118,23 +120,21 @@ impl GooseClient { GooseClient { notification_handlers: handlers, provider, - current_session_id: Arc::new(Mutex::new(None)), + session_id: Mutex::new(None), } } - async fn set_current_session_id(&self, session_id: &str) { - let mut slot = self.current_session_id.lock().await; + async fn set_session_id(&self, session_id: &str) { + let mut slot = self.session_id.lock().await; + assert!( + slot.as_deref().is_none_or(|s| s == session_id), + "McpClient received requests from different sessions" + ); *slot = Some(session_id.to_string()); } - async fn clear_current_session_id(&self) { - let mut slot = self.current_session_id.lock().await; - *slot = None; - } - async fn current_session_id(&self) -> Option { - let slot = self.current_session_id.lock().await; - slot.clone() + self.session_id.lock().await.clone() } async fn resolve_session_id(&self, extensions: &Extensions) -> Option { @@ -416,31 +416,17 @@ impl McpClient { cancel_token: CancellationToken, ) -> Result { let request = inject_session_context_into_request(request, Some(session_id), working_dir); - // ExtensionManager serializes calls per MCP connection, so one current_session_id slot - // is sufficient for mapping callbacks to the active request session. + // The inner mutex is held only for the send; the actual response wait + // happens outside the lock so concurrent calls can overlap. let handle = { let client = self.client.lock().await; - client.service().set_current_session_id(session_id).await; + client.service().set_session_id(session_id).await; client .send_cancellable_request(request, PeerRequestOptions::no_options()) .await - }; - - let handle = match handle { - Ok(handle) => handle, - Err(err) => { - let client = self.client.lock().await; - client.service().clear_current_session_id().await; - return Err(err); - } - }; - - let result = await_response(handle, self.timeout, &cancel_token).await; - - let client = self.client.lock().await; - client.service().clear_current_session_id().await; + }?; - result + await_response(handle, self.timeout, &cancel_token).await } } @@ -857,8 +843,7 @@ mod tests { runtime.block_on(async { let client = new_client(); if let Some(session_id) = current_session { - let mut slot = client.current_session_id.lock().await; - *slot = Some(session_id.to_string()); + client.set_session_id(session_id).await; } let extensions =