Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions crates/goose-cli/src/scenario_tests/scenario_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ where
F: Fn(&ScenarioResult) -> Result<()>,
{
use goose::config::ExtensionConfig;
use tokio::sync::Mutex;

goose::agents::moim::SKIP.with(|f| f.set(true));

Expand Down Expand Up @@ -229,7 +228,7 @@ where
bundled: None,
available_tools: vec![],
},
Arc::new(Mutex::new(Box::new(mock_client))),
Arc::new(mock_client),
None,
None,
)
Expand Down
94 changes: 26 additions & 68 deletions crates/goose/src/agents/extension_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ use rmcp::transport::auth::AuthClient;
use schemars::_private::NoSerialize;
use serde_json::Value;

type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
type McpClientBox = Arc<dyn McpClientTrait>;

struct Extension {
pub config: ExtensionConfig,
Expand Down Expand Up @@ -686,7 +686,7 @@ impl ExtensionManager {
let mut extensions = self.extensions.lock().await;
extensions.insert(
sanitized_name,
Extension::new(config, Arc::new(Mutex::new(client)), server_info, temp_dir),
Extension::new(config, Arc::from(client), server_info, temp_dir),
);
drop(extensions);
self.invalidate_tools_cache_and_bump_version().await;
Expand Down Expand Up @@ -857,8 +857,7 @@ impl ExtensionManager {
let ext_name = name.clone();
async move {
let mut tools = Vec::new();
let client_guard = client.lock().await;
let mut client_tools = match client_guard
let mut client_tools = match client
.list_tools(session_id, None, cancel_token.clone())
.await
{
Expand Down Expand Up @@ -908,7 +907,7 @@ impl ExtensionManager {
break;
}

client_tools = match client_guard
client_tools = match client
.list_tools(session_id, client_tools.next_cursor, cancel_token.clone())
.await
{
Expand Down Expand Up @@ -1060,8 +1059,7 @@ impl ExtensionManager {
.await
.ok_or(ErrorData::new(ErrorCode::INVALID_PARAMS, error_msg, None))?;

let client_guard = client.lock().await;
client_guard
client
.read_resource(session_id, uri, cancellation_token)
.await
.map_err(|_| {
Expand All @@ -1088,9 +1086,7 @@ impl ExtensionManager {
};

for (extension_name, client) in extensions_to_check {
let client_guard = client.lock().await;

match client_guard
match client
.list_resources(session_id, None, CancellationToken::default())
.await
{
Expand Down Expand Up @@ -1127,8 +1123,7 @@ impl ExtensionManager {
)
})?;

let client_guard = client.lock().await;
client_guard
client
.list_resources(session_id, None, cancellation_token)
.await
.map_err(|e| {
Expand Down Expand Up @@ -1302,7 +1297,7 @@ impl ExtensionManager {

let arguments = tool_call.arguments.clone();
let client = resolved.client.clone();
let notifications_receiver = client.lock().await.subscribe().await;
let notifications_receiver = client.subscribe().await;
let session_id = session_id.to_string();
let actual_tool_name = resolved.actual_tool_name;
let working_dir_str = working_dir.map(|p| p.to_string_lossy().to_string());
Expand All @@ -1314,8 +1309,7 @@ impl ExtensionManager {
session_id,
working_dir_str
);
let client_guard = client.lock().await;
client_guard
client
.call_tool(
&session_id,
&actual_tool_name,
Expand Down Expand Up @@ -1355,8 +1349,7 @@ impl ExtensionManager {
)
})?;

let client_guard = client.lock().await;
client_guard
client
.list_prompts(session_id, None, cancellation_token)
.await
.map_err(|e| {
Expand Down Expand Up @@ -1430,8 +1423,7 @@ impl ExtensionManager {
.await
.ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?;

let client_guard = client.lock().await;
client_guard
client
.get_prompt(session_id, name, arguments, cancellation_token)
.await
.map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e))
Expand Down Expand Up @@ -1535,8 +1527,7 @@ impl ExtensionManager {
};

for (name, client) in platform_clients {
let client_guard = client.lock().await;
if let Some(moim_content) = client_guard.get_moim(session_id).await {
if let Some(moim_content) = client.get_moim(session_id).await {
tracing::debug!("MOIM content from {}: {} chars", name, moim_content.len());
content.push('\n');
content.push_str(&moim_content);
Expand Down Expand Up @@ -1702,24 +1693,15 @@ mod tests {

// Add some mock clients using the helper method
extension_manager
.add_mock_extension(
"test_client".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("test_client".to_string(), Arc::new(MockClient {}))
.await;

extension_manager
.add_mock_extension(
"__cli__ent__".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("__cli__ent__".to_string(), Arc::new(MockClient {}))
.await;

extension_manager
.add_mock_extension(
"client 🚀".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("client 🚀".to_string(), Arc::new(MockClient {}))
.await;

let tool_call = CallToolRequestParams {
Expand Down Expand Up @@ -1847,7 +1829,7 @@ mod tests {
extension_manager
.add_mock_extension_with_tools(
"test_extension".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
Arc::new(MockClient {}),
available_tools,
)
.await;
Expand Down Expand Up @@ -1877,7 +1859,7 @@ mod tests {
extension_manager
.add_mock_extension_with_tools(
"test_extension".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
Arc::new(MockClient {}),
vec![], // Empty available_tools means all tools are available by default
)
.await;
Expand Down Expand Up @@ -1909,7 +1891,7 @@ mod tests {
extension_manager
.add_mock_extension_with_tools(
"test_extension".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
Arc::new(MockClient {}),
available_tools,
)
.await;
Expand Down Expand Up @@ -2013,10 +1995,7 @@ mod tests {
ExtensionManager::new_without_provider(temp_dir.path().to_path_buf());

extension_manager
.add_mock_extension(
"ext_a".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("ext_a".to_string(), Arc::new(MockClient {}))
.await;

let tools_after_first = extension_manager
Expand All @@ -2031,10 +2010,7 @@ mod tests {
assert!(!tool_names.iter().any(|n| n.starts_with("ext_b__")));

extension_manager
.add_mock_extension(
"ext_b".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("ext_b".to_string(), Arc::new(MockClient {}))
.await;

let tools_after_second = extension_manager
Expand All @@ -2056,16 +2032,10 @@ mod tests {
ExtensionManager::new_without_provider(temp_dir.path().to_path_buf());

extension_manager
.add_mock_extension(
"ext_a".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("ext_a".to_string(), Arc::new(MockClient {}))
.await;
extension_manager
.add_mock_extension(
"ext_b".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("ext_b".to_string(), Arc::new(MockClient {}))
.await;

let tools_before = extension_manager
Expand Down Expand Up @@ -2094,16 +2064,10 @@ mod tests {
ExtensionManager::new_without_provider(temp_dir.path().to_path_buf());

extension_manager
.add_mock_extension(
"ext_a".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("ext_a".to_string(), Arc::new(MockClient {}))
.await;
extension_manager
.add_mock_extension(
"ext_b".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("ext_b".to_string(), Arc::new(MockClient {}))
.await;

let tools = extension_manager
Expand All @@ -2123,16 +2087,10 @@ mod tests {
ExtensionManager::new_without_provider(temp_dir.path().to_path_buf());

extension_manager
.add_mock_extension(
"ext_a".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("ext_a".to_string(), Arc::new(MockClient {}))
.await;
extension_manager
.add_mock_extension(
"ext_b".to_string(),
Arc::new(Mutex::new(Box::new(MockClient {}))),
)
.add_mock_extension("ext_b".to_string(), Arc::new(MockClient {}))
.await;

let tools = extension_manager
Expand Down
51 changes: 18 additions & 33 deletions crates/goose/src/agents/mcp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ pub trait McpClientTrait: Send + Sync {
pub struct GooseClient {
notification_handlers: Arc<Mutex<Vec<Sender<ServerNotification>>>>,
provider: SharedProvider,
// Single-slot because calls are serialized per MCP client.
current_session_id: Arc<Mutex<Option<String>>>,
/// Fallback session_id for server-initiated callbacks (e.g. sampling/createMessage)
/// that don't include the session_id in their MCP extensions metadata.
/// Set once on first request; never cleared (the id is invariant per McpClient).
session_id: Mutex<Option<String>>,
}

impl GooseClient {
Expand All @@ -118,23 +120,21 @@ impl GooseClient {
GooseClient {
notification_handlers: handlers,
provider,
current_session_id: Arc::new(Mutex::new(None)),
session_id: Mutex::new(None),
}
}

async fn set_current_session_id(&self, session_id: &str) {
let mut slot = self.current_session_id.lock().await;
async fn set_session_id(&self, session_id: &str) {
let mut slot = self.session_id.lock().await;
assert!(
slot.as_deref().is_none_or(|s| s == session_id),
"McpClient received requests from different sessions"
);
*slot = Some(session_id.to_string());
}

async fn clear_current_session_id(&self) {
let mut slot = self.current_session_id.lock().await;
*slot = None;
}

async fn current_session_id(&self) -> Option<String> {
let slot = self.current_session_id.lock().await;
slot.clone()
self.session_id.lock().await.clone()
}

async fn resolve_session_id(&self, extensions: &Extensions) -> Option<String> {
Expand Down Expand Up @@ -416,31 +416,17 @@ impl McpClient {
cancel_token: CancellationToken,
) -> Result<ServerResult, Error> {
let request = inject_session_context_into_request(request, Some(session_id), working_dir);
// ExtensionManager serializes calls per MCP connection, so one current_session_id slot
// is sufficient for mapping callbacks to the active request session.
// The inner mutex is held only for the send; the actual response wait
// happens outside the lock so concurrent calls can overlap.
let handle = {
let client = self.client.lock().await;
client.service().set_current_session_id(session_id).await;
client.service().set_session_id(session_id).await;
client
.send_cancellable_request(request, PeerRequestOptions::no_options())
.await
};

let handle = match handle {
Ok(handle) => handle,
Err(err) => {
let client = self.client.lock().await;
client.service().clear_current_session_id().await;
return Err(err);
}
};

let result = await_response(handle, self.timeout, &cancel_token).await;

let client = self.client.lock().await;
client.service().clear_current_session_id().await;
}?;

result
await_response(handle, self.timeout, &cancel_token).await
}
}

Expand Down Expand Up @@ -857,8 +843,7 @@ mod tests {
runtime.block_on(async {
let client = new_client();
if let Some(session_id) = current_session {
let mut slot = client.current_session_id.lock().await;
*slot = Some(session_id.to_string());
client.set_session_id(session_id).await;
}

let extensions =
Expand Down