diff --git a/Cargo.lock b/Cargo.lock index fe8397c39f60..9b1824f18dc5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4298,12 +4298,14 @@ dependencies = [ "shlex", "tar", "tempfile", + "test-case", "tokio", "tokio-util", "tower-http", "tracing", "tracing-appender", "tracing-subscriber", + "url", "urlencoding", "uuid", "webbrowser", diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs index 43715089f832..47908231727c 100644 --- a/crates/goose-acp/src/server.rs +++ b/crates/goose-acp/src/server.rs @@ -740,7 +740,7 @@ impl GooseAcpAgent { goose::model::ModelConfig::new(&model_id)? } }; - let provider = (self.provider_factory)(model_config).await?; + let provider = (self.provider_factory)(model_config, Vec::new()).await?; agent.update_provider(provider.clone(), &session.id).await?; Ok(provider) } @@ -958,9 +958,11 @@ impl GooseAcpAgent { let model_config = goose::model::ModelConfig::new(model_id).map_err(|e| { sacp::Error::invalid_params().data(format!("Invalid model config: {}", e)) })?; - let provider = (self.provider_factory)(model_config).await.map_err(|e| { - sacp::Error::internal_error().data(format!("Failed to create provider: {}", e)) - })?; + let provider = (self.provider_factory)(model_config, Vec::new()) + .await + .map_err(|e| { + sacp::Error::internal_error().data(format!("Failed to create provider: {}", e)) + })?; let agent = { let sessions = self.sessions.lock().await; diff --git a/crates/goose-acp/src/server_factory.rs b/crates/goose-acp/src/server_factory.rs index 5f7a86e4ba09..96b94559406a 100644 --- a/crates/goose-acp/src/server_factory.rs +++ b/crates/goose-acp/src/server_factory.rs @@ -33,7 +33,7 @@ impl AcpServer { let disable_session_naming = config.get_goose_disable_session_naming().unwrap_or(false); let config_dir = self.config.config_dir.clone(); - let provider_factory: ProviderConstructor = Arc::new(move |model_config| { + let provider_factory: ProviderConstructor = Arc::new(move |model_config, extensions| { let config_dir = config_dir.clone(); Box::pin(async move { let config_path = config_dir.join(goose::config::base::CONFIG_YAML_NAME); @@ -41,7 +41,7 @@ impl AcpServer { let provider_name = config .get_goose_provider() .map_err(|_| anyhow::anyhow!("No provider configured"))?; - goose::providers::create(&provider_name, model_config).await + goose::providers::create(&provider_name, model_config, extensions).await }) }); diff --git a/crates/goose-acp/tests/common_tests/mod.rs b/crates/goose-acp/tests/common_tests/mod.rs index 92ff4de11787..679d0f455ae1 100644 --- a/crates/goose-acp/tests/common_tests/mod.rs +++ b/crates/goose-acp/tests/common_tests/mod.rs @@ -63,7 +63,7 @@ pub async fn run_initialize_without_provider() { let temp_dir = tempfile::tempdir().unwrap(); let provider_factory: ProviderConstructor = - Arc::new(|_| Box::pin(async { Err(anyhow::anyhow!("no provider configured")) })); + Arc::new(|_, _| Box::pin(async { Err(anyhow::anyhow!("no provider configured")) })); let agent = Arc::new( GooseAcpAgent::new( diff --git a/crates/goose-acp/tests/fixtures/mod.rs b/crates/goose-acp/tests/fixtures/mod.rs index 2f35c42f050c..a4a295af8869 100644 --- a/crates/goose-acp/tests/fixtures/mod.rs +++ b/crates/goose-acp/tests/fixtures/mod.rs @@ -203,7 +203,7 @@ pub async fn spawn_acp_server_in_process( } let provider_factory = provider_factory.unwrap_or_else(|| { let base_url = openai_base_url.to_string(); - Arc::new(move |model_config| { + Arc::new(move |model_config, _extensions| { let base_url = base_url.clone(); Box::pin(async move { let api_client = diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index a1bdd1fb8d47..dbbdcacc2612 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -58,6 +58,7 @@ indicatif = "0.18.1" tokio-util = { workspace = true, features = ["compat", "rt"] } anstream = "0.6.18" open = "5.3.2" +url = { workspace = true } urlencoding = { workspace = true } clap_complete = "4.5.62" @@ -70,4 +71,5 @@ disable-update = [] [dev-dependencies] tempfile = { workspace = true } +test-case = { workspace = true } tokio = { workspace = true } diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 9bde5c451b42..aa13ccb99ac5 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -328,7 +328,7 @@ async fn handle_oauth_configuration(provider_name: &str, key_name: &str) -> anyh // Create a temporary provider instance to handle OAuth let temp_model = ModelConfig::new("temp")?; - match create(provider_name, temp_model).await { + match create(provider_name, temp_model, Vec::new()).await { Ok(provider) => match provider.configure_oauth().await { Ok(_) => { let _ = cliclack::log::success("OAuth authentication completed successfully!"); @@ -683,7 +683,7 @@ pub async fn configure_provider_dialog() -> anyhow::Result { spin.start("Attempting to fetch supported models..."); let models_res = { let temp_model_config = ModelConfig::new(&provider_meta.default_model)?; - let temp_provider = create(provider_name, temp_model_config).await?; + let temp_provider = create(provider_name, temp_model_config, Vec::new()).await?; retry_operation(&RetryConfig::default(), || async { temp_provider.fetch_recommended_models().await }) @@ -1445,7 +1445,6 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> { let model_config = ModelConfig::new(&model)?; let agent = Agent::new(); - let new_provider = create(&provider_name, model_config).await?; let session = agent .config @@ -1457,8 +1456,8 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> { ) .await?; - agent.update_provider(new_provider, &session.id).await?; - if let Some(config) = get_extension_by_name(&selected_extension_name) { + let extension_config = get_extension_by_name(&selected_extension_name); + if let Some(config) = extension_config.as_ref() { agent .add_extension(config.clone(), &session.id) .await @@ -1478,6 +1477,10 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> { return Ok(()); } + let extensions = extension_config.into_iter().collect::>(); + let new_provider = create(&provider_name, model_config, extensions).await?; + agent.update_provider(new_provider, &session.id).await?; + let permission_manager = PermissionManager::instance(); let selected_tools = agent .list_tools(&session.id, Some(selected_extension_name.clone())) @@ -1667,7 +1670,7 @@ pub async fn handle_openrouter_auth() -> anyhow::Result<()> { } }; - match create("openrouter", model_config).await { + match create("openrouter", model_config, Vec::new()).await { Ok(provider) => { let model_config = provider.get_model_config(); let test_result = provider @@ -1747,7 +1750,7 @@ pub async fn handle_tetrate_auth() -> anyhow::Result<()> { } }; - match create("tetrate", model_config).await { + match create("tetrate", model_config, Vec::new()).await { Ok(provider) => { let test_result = provider.fetch_supported_models().await; diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index ba75f808c180..c3dedac3af31 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -181,16 +181,16 @@ async fn create_agent(provider_name: &str, model: &str) -> Result { ) .await?; - let provider = goose::providers::create(provider_name, model_config).await?; - agent.update_provider(provider, &init_session.id).await?; - let enabled_configs = goose::config::get_enabled_extensions(); - for config in enabled_configs { + for config in &enabled_configs { if let Err(e) = agent.add_extension(config.clone(), &init_session.id).await { eprintln!("Warning: Failed to load extension {}: {}", config.name(), e); } } + let provider = goose::providers::create(provider_name, model_config, enabled_configs).await?; + agent.update_provider(provider, &init_session.id).await?; + Ok(agent) } diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index b8805187700a..45970afc792c 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -187,7 +187,12 @@ where let original_env = setup_environment(config)?; - let inner_provider = create(&factory_name, ModelConfig::new(config.model_name)?).await?; + let inner_provider = create( + &factory_name, + ModelConfig::new(config.model_name)?, + Vec::new(), + ) + .await?; let test_provider = Arc::new(TestProvider::new_recording(inner_provider, &file_path)); ( diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 447e4cbd780a..9807a4babf56 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -3,14 +3,13 @@ use crate::cli::StreamableHttpOptions; use super::output; use super::CliSession; use console::style; -use goose::agents::{Agent, Container}; -use goose::config::get_enabled_extensions; +use goose::agents::{Agent, Container, ExtensionError}; use goose::config::resolve_extensions_for_new_session; use goose::config::{get_all_extensions, Config, ExtensionConfig}; use goose::providers::create; use goose::recipe::Recipe; use goose::session::session_manager::SessionType; -use goose::session::{EnabledExtensionsState, ExtensionState}; +use goose::session::EnabledExtensionsState; use rustyline::EditMode; use std::collections::BTreeSet; use std::process; @@ -490,27 +489,19 @@ async fn handle_resumed_session_workdir(agent: &Agent, session_id: &str, interac } } -async fn resolve_and_load_extensions( - agent: Agent, +async fn collect_extension_configs( + agent: &Agent, session_config: &SessionBuilderConfig, recipe: Option<&Recipe>, session_id: &str, - provider_for_debug: Arc, -) -> Arc { - for warning in goose::config::get_warnings() { - eprintln!("{}", style(format!("Warning: {}", warning)).yellow()); - } - +) -> Result, ExtensionError> { let configured_extensions: Vec = if session_config.resume { - agent - .config - .session_manager - .get_session(session_id, false) - .await - .ok() - .and_then(|s| EnabledExtensionsState::from_extension_data(&s.extension_data)) - .map(|state| state.extensions) - .unwrap_or_else(get_enabled_extensions) + EnabledExtensionsState::for_session( + &agent.config.session_manager, + session_id, + Config::global(), + ) + .await } else if session_config.no_profile { Vec::new() } else { @@ -523,17 +514,33 @@ async fn resolve_and_load_extensions( &session_config.builtins, ); - let mut extensions_to_load: Vec<(String, ExtensionConfig)> = configured_extensions - .iter() - .map(|cfg| (cfg.name(), cfg.clone())) + let mut all: Vec = configured_extensions; + all.extend(cli_flag_extensions.into_iter().map(|(_, cfg)| cfg)); + + Ok(all) +} + +async fn resolve_and_load_extensions( + agent: Agent, + extensions: Vec, + provider_for_debug: Arc, + interactive: bool, + session_id: &str, +) -> Arc { + for warning in goose::config::get_warnings() { + eprintln!("{}", style(format!("Warning: {}", warning)).yellow()); + } + + let extensions_to_load: Vec<(String, ExtensionConfig)> = extensions + .into_iter() + .map(|cfg| (cfg.name(), cfg)) .collect(); - extensions_to_load.extend(cli_flag_extensions); load_extensions( agent, extensions_to_load, provider_for_debug, - session_config.interactive, + interactive, session_id, ) .await @@ -598,7 +605,28 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { .apply_recipe_components(recipe.and_then(|r| r.response.clone()), true) .await; - let new_provider = match create(&resolved.provider_name, resolved.model_config).await { + let session_id = resolve_session_id(&session_config, &session_manager).await; + + if session_config.resume { + handle_resumed_session_workdir(&agent, &session_id, session_config.interactive).await; + } + + let extensions_for_provider = + match collect_extension_configs(&agent, &session_config, recipe, &session_id).await { + Ok(exts) => exts, + Err(e) => { + output::render_error(&format!("Failed to collect extensions: {}", e)); + process::exit(1); + } + }; + + let new_provider = match create( + &resolved.provider_name, + resolved.model_config, + extensions_for_provider.clone(), + ) + .await + { Ok(provider) => provider, Err(e) => { output::render_error(&format!( @@ -624,8 +652,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { tracing::info!("🤖 Using model: {}", resolved.model_name); } - let session_id = resolve_session_id(&session_config, &session_manager).await; - agent .update_provider(new_provider, &session_id) .await @@ -645,17 +671,13 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { } } - if session_config.resume { - handle_resumed_session_workdir(&agent, &session_id, session_config.interactive).await; - } - // Extensions are loaded after session creation because we may change directory when resuming let agent_ptr = resolve_and_load_extensions( agent, - &session_config, - recipe, - &session_id, + extensions_for_provider, Arc::clone(&provider_for_display), + session_config.interactive, + &session_id, ) .await; diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 4472ae139b26..fec6ceadc1e0 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -286,9 +286,14 @@ impl CliSession { } let cmd = parts.remove(0).to_string(); + let name = std::path::Path::new(&cmd) + .file_name() + .and_then(|f| f.to_str()) + .unwrap_or("unnamed") + .to_string(); Ok(ExtensionConfig::Stdio { - name: String::new(), + name, cmd, args: parts.iter().map(|s| s.to_string()).collect(), envs: Envs::new(envs), @@ -301,8 +306,29 @@ impl CliSession { } pub fn parse_streamable_http_extension(extension_url: &str, timeout: u64) -> ExtensionConfig { + let name = url::Url::parse(extension_url) + .ok() + .map(|u| { + let mut s = String::new(); + if let Some(host) = u.host_str() { + s.push_str(host); + } + if let Some(port) = u.port() { + s.push('_'); + s.push_str(&port.to_string()); + } + let path = u.path().trim_matches('/'); + if !path.is_empty() { + s.push('_'); + s.push_str(path); + } + s + }) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| "unnamed".to_string()); + ExtensionConfig::StreamableHttp { - name: String::new(), + name, uri: extension_url.to_string(), envs: Envs::new(HashMap::new()), env_keys: Vec::new(), @@ -1841,7 +1867,8 @@ async fn get_reasoner() -> Result, anyhow::Error> { let model_config = ModelConfig::new_with_context_env(model, Some("GOOSE_PLANNER_CONTEXT_LIMIT"))?; - let reasoner = create(&provider, model_config).await?; + let extensions = goose::config::extensions::get_enabled_extensions_with_config(config); + let reasoner = create(&provider, model_config, extensions).await?; Ok(reasoner) } @@ -1862,7 +1889,10 @@ fn format_elapsed_time(duration: std::time::Duration) -> String { #[cfg(test)] mod tests { use super::*; + use goose::agents::extension::Envs; + use goose::config::ExtensionConfig; use std::time::Duration; + use test_case::test_case; #[test] fn test_format_elapsed_time_under_60_seconds() { @@ -1929,4 +1959,95 @@ mod tests { let duration = Duration::from_millis(60500); assert_eq!(format_elapsed_time(duration), "1m 00s"); } + + #[test_case( + "/usr/bin/my-server", + ExtensionConfig::Stdio { + name: "my-server".into(), + cmd: "/usr/bin/my-server".into(), + args: vec![], + envs: Envs::default(), + env_keys: vec![], + description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(), + timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), + bundled: None, + available_tools: vec![], + } + ; "name_from_cmd_basename" + )] + #[test_case( + "MY_SECRET=s3cret npx -y @modelcontextprotocol/server-everything", + ExtensionConfig::Stdio { + name: "npx".into(), + cmd: "npx".into(), + args: vec!["-y".into(), "@modelcontextprotocol/server-everything".into()], + envs: Envs::new([("MY_SECRET".into(), "s3cret".into())].into()), + env_keys: vec![], + description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(), + timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), + bundled: None, + available_tools: vec![], + } + ; "env_prefix_name_from_cmd" + )] + fn test_parse_stdio_extension(input: &str, expected: ExtensionConfig) { + assert_eq!(CliSession::parse_stdio_extension(input).unwrap(), expected); + } + + #[test] + fn test_parse_stdio_extension_no_command() { + assert!(CliSession::parse_stdio_extension("").is_err()); + } + + #[test_case( + "https://mcp.kiwi.com", 300, + ExtensionConfig::StreamableHttp { + name: "mcp.kiwi.com".into(), + uri: "https://mcp.kiwi.com".into(), + envs: Envs::default(), + env_keys: vec![], + headers: HashMap::new(), + description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(), + timeout: Some(300), + bundled: None, + available_tools: vec![], + } + ; "name_from_host" + )] + #[test_case( + "http://localhost:8080/api", 300, + ExtensionConfig::StreamableHttp { + name: "localhost_8080_api".into(), + uri: "http://localhost:8080/api".into(), + envs: Envs::default(), + env_keys: vec![], + headers: HashMap::new(), + description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(), + timeout: Some(300), + bundled: None, + available_tools: vec![], + } + ; "port_and_path" + )] + #[test_case( + "http://localhost:9090/other", 300, + ExtensionConfig::StreamableHttp { + name: "localhost_9090_other".into(), + uri: "http://localhost:9090/other".into(), + envs: Envs::default(), + env_keys: vec![], + headers: HashMap::new(), + description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(), + timeout: Some(300), + bundled: None, + available_tools: vec![], + } + ; "different_port_and_path" + )] + fn test_parse_streamable_http_extension(url: &str, timeout: u64, expected: ExtensionConfig) { + assert_eq!( + CliSession::parse_streamable_http_extension(url, timeout), + expected + ); + } } diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index c8ceb98dd4d4..1c5b9a8952e6 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -21,9 +21,8 @@ use goose::model::ModelConfig; use goose::providers::create; use goose::recipe::Recipe; use goose::recipe_deeplink; -use goose::session::extension_data::ExtensionState; use goose::session::session_manager::SessionType; -use goose::session::{EnabledExtensionsState, Session}; +use goose::session::{EnabledExtensionsState, ExtensionState, Session}; use goose::{ agents::{extension::ToolInfo, extension_manager::get_parameter_names}, config::permission::PermissionLevel, @@ -553,12 +552,18 @@ async fn update_agent_provider( .with_context_limit(payload.context_limit) .with_request_params(payload.request_params); - let new_provider = create(&payload.provider, model_config).await.map_err(|e| { - ( - StatusCode::BAD_REQUEST, - format!("Failed to create {} provider: {}", &payload.provider, e), - ) - })?; + let extensions = + EnabledExtensionsState::for_session(state.session_manager(), &payload.session_id, config) + .await; + + let new_provider = create(&payload.provider, model_config, extensions) + .await + .map_err(|e| { + ( + StatusCode::BAD_REQUEST, + format!("Failed to create {} provider: {}", &payload.provider, e), + ) + })?; agent .update_provider(new_provider, &payload.session_id) diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 79e62aa28f53..45562d79791f 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -387,7 +387,7 @@ pub async fn get_provider_models( } let model_config = ModelConfig::new(&metadata.default_model)?; - let provider = goose::providers::create(&name, model_config).await?; + let provider = goose::providers::create(&name, model_config, Vec::new()).await?; let models_result = provider.fetch_recommended_models().await; @@ -747,9 +747,12 @@ pub async fn update_custom_provider( pub async fn check_provider( Json(CheckProviderRequest { provider }): Json, ) -> Result<(), ErrorResponse> { - create_with_default_model(&provider).await.map_err(|err| { - ErrorResponse::bad_request(format!("Provider '{}' check failed: {}", provider, err)) - })?; + // Provider check does not use extensions. + create_with_default_model(&provider, Vec::new()) + .await + .map_err(|err| { + ErrorResponse::bad_request(format!("Provider '{}' check failed: {}", provider, err)) + })?; Ok(()) } @@ -761,7 +764,8 @@ pub async fn check_provider( pub async fn set_config_provider( Json(SetProviderRequest { provider, model }): Json, ) -> Result<(), ErrorResponse> { - create_with_default_model(&provider) + // Provider validation does not use extensions. + create_with_default_model(&provider, Vec::new()) .await .and_then(|_| { let config = Config::global(); @@ -807,12 +811,15 @@ pub async fn configure_provider_oauth( ErrorResponse::bad_request(format!("Failed to create temporary model config: {}", e)) })?; - let provider = create(&provider_name, temp_model).await.map_err(|e| { - ErrorResponse::bad_request(format!( - "Failed to create provider '{}': {}", - provider_name, e - )) - })?; + // OAuth configuration does not use extensions. + let provider = create(&provider_name, temp_model, Vec::new()) + .await + .map_err(|e| { + ErrorResponse::bad_request(format!( + "Failed to create provider '{}': {}", + provider_name, e + )) + })?; provider.configure_oauth().await.map_err(|e| { ErrorResponse::bad_request(format!( diff --git a/crates/goose-server/src/routes/session.rs b/crates/goose-server/src/routes/session.rs index e4ff70ae3dd6..b6a61c21fc2c 100644 --- a/crates/goose-server/src/routes/session.rs +++ b/crates/goose-server/src/routes/session.rs @@ -11,7 +11,6 @@ use axum::{ }; use goose::agents::ExtensionConfig; use goose::recipe::Recipe; -use goose::session::extension_data::ExtensionState; use goose::session::session_manager::SessionInsights; use goose::session::{EnabledExtensionsState, Session}; use serde::{Deserialize, Serialize}; @@ -481,10 +480,10 @@ async fn get_session_extensions( .await .map_err(|_| StatusCode::NOT_FOUND)?; - // Try to get session-specific extensions, fall back to global config - let extensions = EnabledExtensionsState::from_extension_data(&session.extension_data) - .map(|state| state.extensions) - .unwrap_or_else(goose::config::get_enabled_extensions); + let extensions = EnabledExtensionsState::extensions_or_default( + Some(&session.extension_data), + goose::config::Config::global(), + ); Ok(Json(SessionExtensionsResponse { extensions })) } diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index 41c61408c675..34e2a30ae824 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -12,7 +12,8 @@ use std::path::PathBuf; async fn main() -> anyhow::Result<()> { let _ = dotenv(); - let provider = create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL).await?; + let provider = + create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL, Vec::new()).await?; let agent = Agent::new(); diff --git a/crates/goose/examples/databricks_oauth.rs b/crates/goose/examples/databricks_oauth.rs index 9d602f77da96..a263fa8d1d69 100644 --- a/crates/goose/examples/databricks_oauth.rs +++ b/crates/goose/examples/databricks_oauth.rs @@ -12,7 +12,8 @@ async fn main() -> Result<()> { std::env::remove_var("DATABRICKS_TOKEN"); // Create the provider - let provider = create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL).await?; + let provider = + create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL, Vec::new()).await?; // Create a simple message let message = Message::user().with_text("Tell me a short joke about programming."); diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs index 96eb4c48078c..1fe66f92e576 100644 --- a/crates/goose/examples/image_tool.rs +++ b/crates/goose/examples/image_tool.rs @@ -18,9 +18,9 @@ async fn main() -> Result<()> { // Create providers let providers: Vec> = vec![ - create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL).await?, - create_with_named_model("openai", OPEN_AI_DEFAULT_MODEL).await?, - create_with_named_model("anthropic", ANTHROPIC_DEFAULT_MODEL).await?, + create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL, Vec::new()).await?, + create_with_named_model("openai", OPEN_AI_DEFAULT_MODEL, Vec::new()).await?, + create_with_named_model("anthropic", ANTHROPIC_DEFAULT_MODEL, Vec::new()).await?, ]; for provider in providers { // Read and encode test image diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 3d476ead469e..a862f830273d 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1584,7 +1584,10 @@ impl Agent { } }; - let provider = crate::providers::create(&provider_name, model_config) + let extensions = + EnabledExtensionsState::extensions_or_default(Some(&session.extension_data), config); + + let provider = crate::providers::create(&provider_name, model_config, extensions) .await .map_err(|e| anyhow!("Could not create provider: {}", e))?; diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index 3a34617d718d..ccd80324c339 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -11,6 +11,7 @@ use crate::agents::mcp_client::McpClientTrait; use crate::config; use crate::config::extensions::name_to_key; use crate::config::permission::PermissionLevel; +use crate::config::Config; use once_cell::sync::Lazy; use rmcp::model::Tool; use rmcp::service::ClientInitializeError; @@ -532,11 +533,9 @@ impl ExtensionConfig { } pub fn key(&self) -> String { - let name = self.name(); - name_to_key(&name) + name_to_key(&self.name()) } - /// Get the extension name regardless of variant pub fn name(&self) -> String { match self { Self::Sse { name, .. } => name, @@ -578,6 +577,69 @@ impl ExtensionConfig { // If tools are specified, only those tools are available available_tools.is_empty() || available_tools.contains(&tool_name.to_string()) } + + pub async fn resolve(self, config: &Config) -> ExtensionResult { + use crate::agents::extension_manager::{merge_environments, substitute_env_vars}; + + match self { + Self::Stdio { + name, + description, + cmd, + args, + envs, + env_keys, + timeout, + bundled, + available_tools, + } => { + let merged = merge_environments(&envs, &env_keys, &name, config).await?; + Ok(Self::Stdio { + name, + description, + cmd, + args, + envs: Envs::new(merged), + env_keys: vec![], + timeout, + bundled, + available_tools, + }) + } + Self::StreamableHttp { + name, + description, + uri, + envs, + env_keys, + headers, + timeout, + bundled, + available_tools, + } => { + let merged = merge_environments(&envs, &env_keys, &name, config).await?; + let headers = headers + .into_iter() + .map(|(k, v)| { + let v = substitute_env_vars(&v, &merged); + (k, v) + }) + .collect(); + Ok(Self::StreamableHttp { + name, + description, + uri, + envs: Envs::new(merged), + env_keys: vec![], + headers, + timeout, + bundled, + available_tools, + }) + } + other => Ok(other), + } + } } impl std::fmt::Display for ExtensionConfig { @@ -661,6 +723,8 @@ impl ToolInfo { #[cfg(test)] mod tests { use crate::agents::*; + use crate::config; + use test_case::test_case; #[test] fn test_deserialize_missing_description() { @@ -722,4 +786,201 @@ available_tools: [] panic!("unexpected result of deserialization: {}", config) } } + + #[test_case( + ExtensionConfig::Builtin { + name: "developer".into(), + description: "dev".into(), + display_name: None, + timeout: None, + bundled: None, + available_tools: vec![], + }, + ExtensionConfig::Builtin { + name: "developer".into(), + description: "dev".into(), + display_name: None, + timeout: None, + bundled: None, + available_tools: vec![], + } + ; "builtin_unchanged" + )] + #[test_case( + ExtensionConfig::StreamableHttp { + name: "test".into(), + description: String::new(), + uri: "https://example.com".into(), + envs: extension::Envs::new({ + let mut m = std::collections::HashMap::new(); + m.insert("AUTH_TOKEN".to_string(), "secret".to_string()); + m + }), + env_keys: vec![], + headers: [( + "Authorization".to_string(), + "Bearer $AUTH_TOKEN".to_string(), + )] + .into_iter() + .collect(), + timeout: None, + bundled: None, + available_tools: vec![], + }, + ExtensionConfig::StreamableHttp { + name: "test".into(), + description: String::new(), + uri: "https://example.com".into(), + envs: extension::Envs::new({ + let mut m = std::collections::HashMap::new(); + m.insert("AUTH_TOKEN".to_string(), "secret".to_string()); + m + }), + env_keys: vec![], + headers: [( + "Authorization".to_string(), + "Bearer secret".to_string(), + )] + .into_iter() + .collect(), + timeout: None, + bundled: None, + available_tools: vec![], + } + ; "header_substitution" + )] + #[test_case( + ExtensionConfig::Stdio { + name: "test".into(), + description: String::new(), + cmd: "echo".into(), + args: vec![], + envs: extension::Envs::default(), + env_keys: vec![], + timeout: None, + bundled: None, + available_tools: vec![], + }, + ExtensionConfig::Stdio { + name: "test".into(), + description: String::new(), + cmd: "echo".into(), + args: vec![], + envs: extension::Envs::default(), + env_keys: vec![], + timeout: None, + bundled: None, + available_tools: vec![], + } + ; "env_keys_cleared" + )] + #[test_case( + ExtensionConfig::Stdio { + name: "test".into(), + description: String::new(), + cmd: "echo".into(), + args: vec![], + envs: extension::Envs::default(), + env_keys: vec!["MY_SECRET".into()], + timeout: None, + bundled: None, + available_tools: vec![], + }, + ExtensionConfig::Stdio { + name: "test".into(), + description: String::new(), + cmd: "echo".into(), + args: vec![], + envs: extension::Envs::new({ + let mut m = std::collections::HashMap::new(); + m.insert("MY_SECRET".to_string(), "secret_value".to_string()); + m + }), + env_keys: vec![], + timeout: None, + bundled: None, + available_tools: vec![], + } + ; "env_key_resolved" + )] + #[test_case( + ExtensionConfig::StreamableHttp { + name: "test".into(), + description: String::new(), + uri: "https://example.com".into(), + envs: extension::Envs::default(), + env_keys: vec!["MY_SECRET".into()], + headers: [( + "Authorization".to_string(), + "Bearer $MY_SECRET".to_string(), + )] + .into_iter() + .collect(), + timeout: None, + bundled: None, + available_tools: vec![], + }, + ExtensionConfig::StreamableHttp { + name: "test".into(), + description: String::new(), + uri: "https://example.com".into(), + envs: extension::Envs::new({ + let mut m = std::collections::HashMap::new(); + m.insert("MY_SECRET".to_string(), "secret_value".to_string()); + m + }), + env_keys: vec![], + headers: [("Authorization".to_string(), "Bearer secret_value".to_string())] + .into_iter() + .collect(), + timeout: None, + bundled: None, + available_tools: vec![], + } + ; "http_env_key_and_header_substitution" + )] + #[test_case( + ExtensionConfig::Stdio { + name: "test".into(), + description: String::new(), + cmd: "echo".into(), + args: vec![], + envs: extension::Envs::new({ + let mut m = std::collections::HashMap::new(); + m.insert("MY_SECRET".to_string(), "original".to_string()); + m + }), + env_keys: vec!["MY_SECRET".into()], + timeout: None, + bundled: None, + available_tools: vec![], + }, + ExtensionConfig::Stdio { + name: "test".into(), + description: String::new(), + cmd: "echo".into(), + args: vec![], + envs: extension::Envs::new({ + let mut m = std::collections::HashMap::new(); + m.insert("MY_SECRET".to_string(), "original".to_string()); + m + }), + env_keys: vec![], + timeout: None, + bundled: None, + available_tools: vec![], + } + ; "env_key_skipped_when_already_in_envs" + )] + #[tokio::test] + async fn test_resolve(config: ExtensionConfig, expected: ExtensionConfig) { + let dir = tempfile::tempdir().unwrap(); + let cfg = config::Config::new_with_file_secrets( + dir.path().join("config.yaml"), + dir.path().join("secrets.yaml"), + ) + .unwrap(); + cfg.set("MY_SECRET", &"secret_value", true).unwrap(); + assert_eq!(config.resolve(&cfg).await.unwrap(), expected); + } } diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index dcdf098c0b13..de31d0131ffb 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -3,7 +3,6 @@ use axum::http::{HeaderMap, HeaderName}; use chrono::{DateTime, Utc}; use futures::stream::{FuturesUnordered, StreamExt}; use futures::{future, FutureExt}; -use rand::{distributions::Alphanumeric, Rng}; use rmcp::service::{ClientInitializeError, ServiceError}; use rmcp::transport::streamable_http_client::{ AuthRequiredError, StreamableHttpClientTransportConfig, StreamableHttpError, @@ -135,31 +134,6 @@ impl ResourceItem { } } -/// Generates extension name from server info; adds random suffix on collision. -fn generate_extension_name( - server_info: Option<&ServerInfo>, - name_exists: impl Fn(&str) -> bool, -) -> String { - let base = server_info - .and_then(|info| { - let name = info.server_info.name.as_str(); - (!name.is_empty()).then(|| name_to_key(name)) - }) - .unwrap_or_else(|| "unnamed".to_string()); - - if !name_exists(&base) { - return base; - } - - let suffix: String = rand::thread_rng() - .sample_iter(Alphanumeric) - .take(6) - .map(char::from) - .collect(); - - format!("{base}_{suffix}") -} - fn resolve_command(cmd: &str) -> PathBuf { SearchPaths::builder() .with_npm() @@ -315,20 +289,20 @@ fn extract_auth_error( } /// Merge environment variables from direct envs and keychain-stored env_keys -async fn merge_environments( +pub(crate) async fn merge_environments( envs: &Envs, env_keys: &[String], ext_name: &str, + config: &Config, ) -> Result, ExtensionError> { let mut all_envs = envs.get_env(); - let config_instance = Config::global(); for key in env_keys { if all_envs.contains_key(key) { continue; } - match config_instance.get(key, true) { + match config.get(key, true) { Ok(value) => { if value.is_null() { warn!( @@ -369,7 +343,7 @@ async fn merge_environments( } /// Substitute environment variables in a string. Supports both ${VAR} and $VAR syntax. -fn substitute_env_vars(value: &str, env_map: &HashMap) -> String { +pub(crate) fn substitute_env_vars(value: &str, env_map: &HashMap) -> String { let mut result = value.to_string(); let re_braces = @@ -404,7 +378,6 @@ async fn create_streamable_http_client( timeout: Option, headers: &HashMap, name: &str, - all_envs: &HashMap, provider: SharedProvider, ) -> ExtensionResult> { let mut default_headers = HeaderMap::new(); @@ -412,11 +385,10 @@ async fn create_streamable_http_client( default_headers.insert(reqwest::header::USER_AGENT, GOOSE_USER_AGENT); for (key, value) in headers { - let substituted_value = substitute_env_vars(value, all_envs); default_headers.insert( HeaderName::try_from(key) .map_err(|_| ExtensionError::ConfigError(format!("invalid header: {}", key)))?, - substituted_value.parse().map_err(|_| { + value.parse().map_err(|_| { ExtensionError::ConfigError(format!("invalid header value: {}", key)) })?, ); @@ -517,8 +489,7 @@ impl ExtensionManager { container: Option<&Container>, session_id: Option<&str>, ) -> ExtensionResult<()> { - let config_name = config.key().to_string(); - let sanitized_name = name_to_key(&config_name); + let sanitized_name = config.key(); if self.extensions.lock().await.contains_key(&sanitized_name) { return Ok(()); @@ -545,13 +516,17 @@ impl ExtensionManager { env_keys, .. } => { - let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; + let config = Config::global(); + let all_envs = merge_environments(envs, env_keys, &sanitized_name, config).await?; + let resolved_headers = headers + .iter() + .map(|(k, v)| (k.clone(), substitute_env_vars(v, &all_envs))) + .collect(); create_streamable_http_client( uri, *timeout, - headers, + &resolved_headers, name, - &all_envs, self.provider.clone(), ) .await? @@ -564,7 +539,9 @@ impl ExtensionManager { timeout, .. } => { - let mut all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; + let config = Config::global(); + let mut all_envs = + merge_environments(envs, env_keys, &sanitized_name, config).await?; if let Some(sid) = session_id { all_envs.insert("AGENT_SESSION_ID".to_string(), sid.to_string()); @@ -706,15 +683,9 @@ impl ExtensionManager { let server_info = client.get_info().cloned(); - // Only generate name from server info when config has no name (e.g., CLI --with-*-extension args) let mut extensions = self.extensions.lock().await; - let final_name = if sanitized_name.is_empty() { - generate_extension_name(server_info.as_ref(), |n| extensions.contains_key(n)) - } else { - sanitized_name - }; extensions.insert( - final_name, + sanitized_name, Extension::new(config, Arc::new(Mutex::new(client)), server_info, temp_dir), ); drop(extensions); @@ -2020,35 +1991,6 @@ mod tests { assert_eq!(result, "Authorization: Bearer secret123 and API key456"); } - mod generate_extension_name_tests { - use super::*; - use rmcp::model::Implementation; - use test_case::test_case; - - fn make_info(name: &str) -> ServerInfo { - ServerInfo { - server_info: Implementation { - name: name.into(), - ..Default::default() - }, - ..Default::default() - } - } - - #[test_case(Some("kiwi-mcp-server"), None, "^kiwi-mcp-server$" ; "already normalized server name")] - #[test_case(Some("Context7"), None, "^context7$" ; "mixed case normalized")] - #[test_case(Some("@huggingface/mcp-services"), None, "^_huggingface_mcp-services$" ; "special chars normalized")] - #[test_case(None, None, "^unnamed$" ; "no server info falls back")] - #[test_case(Some(""), None, "^unnamed$" ; "empty server name falls back")] - #[test_case(Some("github-mcp-server"), Some("github-mcp-server"), r"^github-mcp-server_[A-Za-z0-9]{6}$" ; "duplicate adds suffix")] - fn test_generate_name(server_name: Option<&str>, collision: Option<&str>, expected: &str) { - let info = server_name.map(make_info); - let result = generate_extension_name(info.as_ref(), |n| collision == Some(n)); - let re = regex::Regex::new(expected).unwrap(); - assert!(re.is_match(&result)); - } - } - #[tokio::test] async fn test_collect_moim_uses_minute_granularity() { let temp_dir = tempfile::tempdir().unwrap(); diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 0efc3ec22274..07cd370ce7a5 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -30,7 +30,7 @@ pub mod types; pub use agent::{Agent, AgentConfig, AgentEvent, ExtensionLoadResult}; pub use container::Container; pub use execute_commands::COMPACT_TRIGGERS; -pub use extension::ExtensionConfig; +pub use extension::{ExtensionConfig, ExtensionError}; pub use extension_manager::ExtensionManager; pub use prompt_manager::PromptManager; pub use subagent_handler::SUBAGENT_TOOL_REQUEST_TYPE; diff --git a/crates/goose/src/agents/summon_extension.rs b/crates/goose/src/agents/summon_extension.rs index 4e0439b28d01..05b207b59329 100644 --- a/crates/goose/src/agents/summon_extension.rs +++ b/crates/goose/src/agents/summon_extension.rs @@ -13,11 +13,12 @@ use crate::agents::subagent_handler::{ use crate::agents::subagent_task_config::{TaskConfig, DEFAULT_SUBAGENT_MAX_TURNS}; use crate::agents::AgentConfig; use crate::config::paths::Paths; +use crate::config::Config; use crate::providers; use crate::recipe::build_recipe::build_recipe_from_template; use crate::recipe::local_recipes::load_local_recipe_file; use crate::recipe::{Recipe, Settings, RECIPE_FILE_EXTENSIONS}; -use crate::session::extension_data::{EnabledExtensionsState, ExtensionState}; +use crate::session::extension_data::EnabledExtensionsState; use crate::session::SessionType; use anyhow::Result; use async_trait::async_trait; @@ -1290,7 +1291,10 @@ impl SummonClient { ) -> Result { let provider = self.resolve_provider(params, recipe, session).await?; - let mut extensions = self.resolve_extensions(session)?; + let mut extensions = EnabledExtensionsState::extensions_or_default( + Some(&session.extension_data), + Config::global(), + ); if let Some(filter) = ¶ms.extensions { if filter.is_empty() { @@ -1349,18 +1353,7 @@ impl SummonClient { model_config = model_config.with_temperature(Some(temp)); } - providers::create(&provider_name, model_config).await - } - - fn resolve_extensions( - &self, - session: &crate::session::Session, - ) -> Result, anyhow::Error> { - let extensions = EnabledExtensionsState::from_extension_data(&session.extension_data) - .map(|s| s.extensions) - .unwrap_or_else(crate::config::get_enabled_extensions); - - Ok(extensions) + providers::create(&provider_name, model_config, Vec::new()).await } fn resolve_max_turns(&self, session: &crate::session::Session) -> usize { diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 8a48bc218521..c6dfa1002bab 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -205,7 +205,10 @@ impl ProviderDef for AnthropicProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/auto_detect.rs b/crates/goose/src/providers/auto_detect.rs index 43f13ea21e4c..39a9e54118e6 100644 --- a/crates/goose/src/providers/auto_detect.rs +++ b/crates/goose/src/providers/auto_detect.rs @@ -22,6 +22,7 @@ pub async fn detect_provider_from_api_key(api_key: &str) -> Option<(String, Vec< let result = match crate::providers::create( provider_name, ModelConfig::new_or_fail("default"), + Vec::new(), ) .await { diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index cef3f20c176a..e9fdc2a105c2 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -63,7 +63,10 @@ impl ProviderDef for AzureProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(async move { let config = crate::config::Config::global(); let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?; diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 67ac0e5f9f2e..05447494bcd0 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use async_trait::async_trait; use futures::future::BoxFuture; use futures::Stream; use serde::{Deserialize, Serialize}; @@ -7,6 +8,7 @@ use super::canonical::{map_to_canonical_model, CanonicalModelRegistry}; use super::errors::ProviderError; use super::retry::RetryConfig; use crate::config::base::ConfigValue; +use crate::config::ExtensionConfig; use crate::conversation::message::Message; use crate::conversation::Conversation; use crate::model::ModelConfig; @@ -342,8 +344,6 @@ impl Usage { } } -use async_trait::async_trait; - pub trait ProviderDef: Send + Sync { type Provider: Provider + 'static; @@ -351,7 +351,10 @@ pub trait ProviderDef: Send + Sync { where Self: Sized; - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> + fn from_env( + model: ModelConfig, + extensions: Vec, + ) -> BoxFuture<'static, Result> where Self: Sized; } diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 898112eba658..e8889fbefc24 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -282,7 +282,10 @@ impl ProviderDef for BedrockProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/canonical/build_canonical_models.rs b/crates/goose/src/providers/canonical/build_canonical_models.rs index 0fe11dffec59..a955c1335315 100644 --- a/crates/goose/src/providers/canonical/build_canonical_models.rs +++ b/crates/goose/src/providers/canonical/build_canonical_models.rs @@ -483,7 +483,7 @@ async fn check_provider( ) -> Result<(Vec, Vec, Vec)> { println!("Checking provider: {}", provider_name); - let provider = match create_with_named_model(provider_name, model_for_init).await { + let provider = match create_with_named_model(provider_name, model_for_init, Vec::new()).await { Ok(p) => p, Err(e) => { println!(" âš  Failed to create provider: {}", e); diff --git a/crates/goose/src/providers/chatgpt_codex.rs b/crates/goose/src/providers/chatgpt_codex.rs index a1d47ff0b6c6..fcd63a59a67f 100644 --- a/crates/goose/src/providers/chatgpt_codex.rs +++ b/crates/goose/src/providers/chatgpt_codex.rs @@ -868,7 +868,10 @@ impl ProviderDef for ChatGptCodexProvider { .with_unlisted_models() } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 72a1de7b6d28..67c1d5d18f14 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -3,8 +3,10 @@ use async_trait::async_trait; use futures::future::BoxFuture; use rmcp::model::Role; use serde_json::{json, Value}; -use std::path::PathBuf; +use std::io::Write; +use std::path::{Path, PathBuf}; use std::process::Stdio; +use tempfile::NamedTempFile; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::process::Command; @@ -12,8 +14,9 @@ use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUs use super::errors::ProviderError; use super::utils::{filter_extensions_from_system_prompt, RequestLog}; use crate::config::base::ClaudeCodeCommand; +use crate::config::paths::Paths; use crate::config::search_path::SearchPaths; -use crate::config::{Config, GooseMode}; +use crate::config::{Config, ExtensionConfig, GooseMode}; use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::subprocess::configure_subprocess; @@ -145,6 +148,9 @@ pub struct ClaudeCodeProvider { model: ModelConfig, #[serde(skip)] name: String, + /// Temp file holding MCP config JSON (auto-deleted on drop). + #[serde(skip)] + mcp_config_file: Option, #[serde(skip)] cli_process: tokio::sync::OnceCell>, } @@ -374,6 +380,10 @@ impl ClaudeCodeProvider { .cli_process .get_or_try_init(|| async { let mut cmd = self.build_stream_json_command(); + if let Some(f) = &self.mcp_config_file { + cmd.arg("--mcp-config").arg(f.path()); + cmd.arg("--strict-mcp-config"); + } // System prompt is set once at process start and cannot be updated at runtime. cmd.arg("--system-prompt").arg(&filtered_system); @@ -577,7 +587,71 @@ fn build_stream_json_input(content_blocks: &[Value], session_id: &str) -> String serde_json::to_string(&msg).expect("serializing JSON content blocks cannot fail") } -#[async_trait] +fn claude_mcp_config_json(extensions: &[ExtensionConfig]) -> Option { + let mut mcp_servers = serde_json::Map::new(); + + for extension in extensions { + match extension { + ExtensionConfig::StreamableHttp { uri, headers, .. } => { + let key = extension.key(); + let mut config = serde_json::Map::new(); + config.insert("type".to_string(), json!("http")); + config.insert("url".to_string(), json!(uri)); + if !headers.is_empty() { + config.insert("headers".to_string(), json!(headers)); + } + mcp_servers.insert(key, Value::Object(config)); + } + ExtensionConfig::Stdio { + cmd, args, envs, .. + } => { + let key = extension.key(); + let mut config = serde_json::Map::new(); + config.insert("type".to_string(), json!("stdio")); + config.insert("command".to_string(), json!(cmd)); + if !args.is_empty() { + config.insert("args".to_string(), json!(args)); + } + let env_map = envs.get_env(); + if !env_map.is_empty() { + config.insert("env".to_string(), json!(env_map)); + } + mcp_servers.insert(key, Value::Object(config)); + } + ExtensionConfig::Sse { name, .. } => { + tracing::debug!(name, "skipping SSE extension, migrate to streamable_http"); + } + _ => {} + } + } + + if mcp_servers.is_empty() { + return None; + } + + serde_json::to_string(&json!({ "mcpServers": mcp_servers })).ok() +} + +/// Write the MCP config JSON to a temp file with restricted permissions +/// so secrets (headers, env vars) are not leaked via process argv. +fn write_mcp_config_file(state_dir: &Path, json: &str) -> Result { + let dir = state_dir.join("claude-code"); + std::fs::create_dir_all(&dir)?; + let prefix = format!("mcp-config-{}_", chrono::Utc::now().format("%Y%m%d")); + let mut tmp = tempfile::Builder::new() + .prefix(&prefix) + .suffix(".json") + .tempfile_in(&dir)?; + tmp.write_all(json.as_bytes())?; + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + tmp.as_file() + .set_permissions(std::fs::Permissions::from_mode(0o600))?; + } + Ok(tmp) +} + impl ProviderDef for ClaudeCodeProvider { type Provider = Self; @@ -598,16 +672,29 @@ impl ProviderDef for ClaudeCodeProvider { .with_unlisted_models() } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(async move { let config = crate::config::Config::global(); let command: String = config.get_claude_code_command().unwrap_or_default().into(); let resolved_command = SearchPaths::builder().with_npm().resolve(command)?; + let mut resolved = Vec::with_capacity(extensions.len()); + for ext in extensions { + resolved.push(ext.resolve(config).await?); + } + + let mcp_config_file = claude_mcp_config_json(&resolved) + .map(|json| write_mcp_config_file(&Paths::state_dir(), &json)) + .transpose()?; + Ok(Self { command: resolved_command, model, name: CLAUDE_CODE_PROVIDER_NAME.to_string(), + mcp_config_file, cli_process: tokio::sync::OnceCell::new(), }) }) @@ -737,8 +824,13 @@ impl Provider for ClaudeCodeProvider { #[cfg(test)] mod tests { use super::*; + use crate::agents::extension::Envs; + use chrono::Utc; use goose_test_support::session::TEST_SESSION_ID; use serde_json::json; + use std::collections::HashMap; + use std::fs; + use tempfile::tempdir; use test_case::test_case; /// (role, text, optional (image_data, mime_type)) @@ -940,11 +1032,115 @@ mod tests { assert_eq!(result, expected); } + #[test_case( + vec![], + None + ; "empty_extensions_returns_none" + )] + #[test_case( + vec![ExtensionConfig::Sse { + name: "legacy".into(), + description: String::new(), + uri: Some("http://localhost/sse".into()), + }], + None + ; "sse_only_returns_none" + )] + #[test_case( + vec![ExtensionConfig::Stdio { + name: "lookup".into(), + description: String::new(), + cmd: "node".into(), + args: vec!["server.js".into()], + envs: Envs::new([("API_KEY".into(), "secret".into())].into()), + env_keys: vec![], + timeout: None, + bundled: Some(false), + available_tools: vec![], + }], + Some(json!({ "mcpServers": { + "lookup": { + "type": "stdio", + "command": "node", + "args": ["server.js"], + "env": { "API_KEY": "secret" } + } + }})) + ; "stdio_converts_to_mcp_config_json" + )] + #[test_case( + vec![ExtensionConfig::StreamableHttp { + name: "lookup".into(), + description: String::new(), + uri: "http://localhost/mcp".into(), + envs: Envs::default(), + env_keys: vec![], + headers: HashMap::from([("Authorization".into(), "Bearer token".into())]), + timeout: None, + bundled: Some(false), + available_tools: vec![], + }], + Some(json!({ "mcpServers": { + "lookup": { + "type": "http", + "url": "http://localhost/mcp", + "headers": { "Authorization": "Bearer token" } + } + }})) + ; "streamable_http_converts_to_mcp_config_json" + )] + #[test_case( + vec![ExtensionConfig::StreamableHttp { + name: "mcp_kiwi_com".into(), + description: String::new(), + uri: "https://mcp.kiwi.com".into(), + envs: Envs::default(), + env_keys: vec![], + headers: HashMap::new(), + timeout: None, + bundled: None, + available_tools: vec![], + }], + Some(json!({ "mcpServers": { + "mcp_kiwi_com": { + "type": "http", + "url": "https://mcp.kiwi.com" + } + }})) + ; "resolved_name_used_as_key" + )] + fn test_claude_mcp_config_json(extensions: Vec, expected: Option) { + let result = claude_mcp_config_json(&extensions) + .map(|json| serde_json::from_str::(&json).unwrap()); + assert_eq!(result, expected); + } + + #[test] + fn test_write_mcp_config_file() { + let state_dir = tempdir().unwrap(); + let json = r#"{"mcpServers":{}}"#; + + let tmp = write_mcp_config_file(state_dir.path(), json).unwrap(); + + assert_eq!(fs::read_to_string(tmp.path()).unwrap(), json); + + let norm_path = tmp.path().to_string_lossy().replace('\\', "/"); + let expected_prefix = format!("claude-code/mcp-config-{}_", Utc::now().format("%Y%m%d")); + assert!(norm_path.contains(&expected_prefix)); + assert!(norm_path.ends_with(".json")); + } + + #[test] + fn test_write_mcp_config_file_invalid_state_dir() { + assert!(write_mcp_config_file(Path::new("/dev/null"), "{}").is_err()); + } + fn make_provider() -> ClaudeCodeProvider { ClaudeCodeProvider { command: PathBuf::from("claude"), model: ModelConfig::new(CLAUDE_CODE_DEFAULT_MODEL).unwrap(), name: "claude-code".to_string(), + mcp_config_file: None, cli_process: tokio::sync::OnceCell::new(), } } diff --git a/crates/goose/src/providers/codex.rs b/crates/goose/src/providers/codex.rs index 8f9a3e119477..453d99641a78 100644 --- a/crates/goose/src/providers/codex.rs +++ b/crates/goose/src/providers/codex.rs @@ -16,7 +16,7 @@ use super::utils::{filter_extensions_from_system_prompt, RequestLog}; use crate::config::base::{CodexCommand, CodexReasoningEffort, CodexSkipGitCheck}; use crate::config::paths::Paths; use crate::config::search_path::SearchPaths; -use crate::config::{Config, GooseMode}; +use crate::config::{Config, ExtensionConfig, GooseMode}; use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::subprocess::configure_subprocess; @@ -50,6 +50,8 @@ pub struct CodexProvider { reasoning_effort: String, /// Whether to skip git repo check skip_git_check: bool, + /// CLI config overrides for MCP servers + mcp_config_overrides: Vec, } impl CodexProvider { @@ -140,6 +142,10 @@ impl CodexProvider { self.reasoning_effort )); + for override_config in &self.mcp_config_overrides { + cmd.arg("-c").arg(override_config); + } + // JSON output format for structured parsing cmd.arg("--json"); @@ -539,7 +545,92 @@ fn prepare_input( Ok((prompt, temp_files)) } -#[async_trait] +fn toml_quote(s: &str) -> String { + let mut out = String::with_capacity(s.len() + 2); + out.push('"'); + for ch in s.chars() { + match ch { + '\\' => out.push_str("\\\\"), + '"' => out.push_str("\\\""), + '\n' => out.push_str("\\n"), + '\r' => out.push_str("\\r"), + '\t' => out.push_str("\\t"), + '\u{0008}' => out.push_str("\\b"), + '\u{000C}' => out.push_str("\\f"), + c if c.is_control() => { + // TOML \uXXXX for other control characters + for unit in c.encode_utf16(&mut [0; 2]) { + out.push_str(&format!("\\u{:04X}", unit)); + } + } + c => out.push(c), + } + } + out.push('"'); + out +} + +// Codex CLI only supports inline `-c key=value` TOML overrides — no file-based +// config merging. Resolved secrets (from env_keys/keystore) in envs/headers end +// up in process argv, visible via `ps`. Claude Code avoids this by writing to a +// temp file with 0o600 permissions. +// Tracking: https://github.com/openai/codex/issues/2628 +fn codex_mcp_config_overrides(extensions: &[ExtensionConfig]) -> Vec { + let mut overrides = Vec::new(); + for extension in extensions { + match extension { + ExtensionConfig::StreamableHttp { uri, headers, .. } => { + let key = extension.key(); + overrides.push(format!("mcp_servers.{}.url={}", key, toml_quote(uri))); + if !headers.is_empty() { + let mut hkeys: Vec<_> = headers.keys().collect(); + hkeys.sort(); + let entries: Vec<_> = hkeys + .iter() + .map(|k| format!("{} = {}", toml_quote(k), toml_quote(&headers[*k]))) + .collect(); + overrides.push(format!( + "mcp_servers.{}.http_headers={{{}}}", + key, + entries.join(", ") + )); + } + } + ExtensionConfig::Stdio { + cmd, args, envs, .. + } => { + let key = extension.key(); + overrides.push(format!("mcp_servers.{}.command={}", key, toml_quote(cmd))); + if !args.is_empty() { + let items: Vec<_> = args.iter().map(|a| toml_quote(a)).collect(); + overrides.push(format!("mcp_servers.{}.args=[{}]", key, items.join(", "))); + } + let env_map = envs.get_env(); + if !env_map.is_empty() { + let mut ekeys: Vec<_> = env_map.keys().collect(); + ekeys.sort(); + let entries: Vec<_> = ekeys + .iter() + .map(|k| { + format!("{} = {}", toml_quote(k), toml_quote(&env_map[k.as_str()])) + }) + .collect(); + overrides.push(format!( + "mcp_servers.{}.env={{{}}}", + key, + entries.join(", ") + )); + } + } + ExtensionConfig::Sse { name, .. } => { + tracing::debug!(name, "skipping SSE extension, migrate to streamable_http"); + } + _ => {} + } + } + overrides +} + impl ProviderDef for CodexProvider { type Provider = Self; @@ -560,7 +651,10 @@ impl ProviderDef for CodexProvider { .with_unlisted_models() } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(async move { let config = Config::global(); let command: String = config.get_codex_command().unwrap_or_default().into(); @@ -591,12 +685,18 @@ impl ProviderDef for CodexProvider { .map(|s| s.to_lowercase() == "true") .unwrap_or(false); + let mut resolved = Vec::with_capacity(extensions.len()); + for ext in extensions { + resolved.push(ext.resolve(config).await?); + } + Ok(Self { command: resolved_command, model, name: CODEX_PROVIDER_NAME.to_string(), reasoning_effort, skip_git_check, + mcp_config_overrides: codex_mcp_config_overrides(&resolved), }) }) } @@ -669,7 +769,9 @@ impl Provider for CodexProvider { #[cfg(test)] mod tests { use super::*; + use crate::agents::extension::Envs; use goose_test_support::TEST_IMAGE_B64; + use std::collections::HashMap; use test_case::test_case; #[test] @@ -685,6 +787,96 @@ mod tests { .any(|m| m.name == CODEX_DEFAULT_MODEL)); } + #[test_case( + ExtensionConfig::Stdio { + name: "lookup".into(), + cmd: "node".into(), + args: vec!["server.js".into()], + envs: Envs::new([("API_KEY".into(), "secret".into())].into()), + env_keys: vec![], + description: "Lookup".into(), + timeout: Some(30), + bundled: None, + available_tools: vec![], + }, + &[ + r#"mcp_servers.lookup.command="node""#, + r#"mcp_servers.lookup.args=["server.js"]"#, + r#"mcp_servers.lookup.env={"API_KEY" = "secret"}"#, + ] + ; "stdio_converts_to_mcp_overrides" + )] + #[test_case( + ExtensionConfig::StreamableHttp { + name: "lookup".into(), + description: String::new(), + uri: "http://localhost/mcp".into(), + envs: Envs::default(), + env_keys: vec![], + headers: HashMap::from([("Authorization".into(), "Bearer token".into())]), + timeout: None, + bundled: Some(false), + available_tools: vec![], + }, + &[ + r#"mcp_servers.lookup.url="http://localhost/mcp""#, + r#"mcp_servers.lookup.http_headers={"Authorization" = "Bearer token"}"#, + ] + ; "streamable_http_converts_to_mcp_overrides" + )] + #[test_case( + ExtensionConfig::StreamableHttp { + name: "mcp_kiwi_com".into(), + description: String::new(), + uri: "https://mcp.kiwi.com".into(), + envs: Envs::default(), + env_keys: vec![], + headers: HashMap::new(), + timeout: None, + bundled: None, + available_tools: vec![], + }, + &[ + r#"mcp_servers.mcp_kiwi_com.url="https://mcp.kiwi.com""#, + ] + ; "resolved_name_used_as_key_http" + )] + #[test_case( + ExtensionConfig::Stdio { + name: "my-server".into(), + cmd: "/usr/bin/my-server".into(), + args: vec![], + envs: Envs::default(), + env_keys: vec![], + description: String::new(), + timeout: None, + bundled: None, + available_tools: vec![], + }, + &[ + r#"mcp_servers.my-server.command="/usr/bin/my-server""#, + ] + ; "resolved_name_used_as_key_stdio" + )] + fn test_codex_mcp_overrides(config: ExtensionConfig, expected: &[&str]) { + let overrides = codex_mcp_config_overrides(&[config]); + let expected: Vec = expected.iter().map(|s| s.to_string()).collect(); + assert_eq!(overrides, expected); + } + + #[test_case("simple", r#""simple""# ; "no_special_chars")] + #[test_case(r#"back\slash"#, r#""back\\slash""# ; "backslash")] + #[test_case(r#"has"quote"#, r#""has\"quote""# ; "double_quote")] + #[test_case("line\nbreak", r#""line\nbreak""# ; "newline")] + #[test_case("tab\there", r#""tab\there""# ; "tab")] + #[test_case("cr\rhere", r#""cr\rhere""# ; "carriage_return")] + #[test_case("bell\u{0008}here", r#""bell\bhere""# ; "backspace")] + #[test_case("ff\u{000C}here", r#""ff\fhere""# ; "form_feed")] + #[test_case("null\u{0000}here", r#""null\u0000here""# ; "null_control_char")] + fn test_toml_quote(input: &str, expected: &str) { + assert_eq!(toml_quote(input), expected); + } + #[test_case("image/png", ".png" ; "png image")] #[test_case("image/jpeg", ".jpg" ; "jpeg image")] fn test_prepare_input_image(mime: &str, expected_ext: &str) { @@ -765,6 +957,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; let lines = vec!["Hello, world!".to_string()]; @@ -784,6 +977,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; // Test with actual Codex CLI output format @@ -816,6 +1010,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; let lines: Vec = vec![]; @@ -863,6 +1058,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; let lines = vec![ @@ -887,6 +1083,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; let lines = vec![ @@ -958,6 +1155,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; let lines: Vec = lines.iter().map(|s| s.to_string()).collect(); @@ -973,6 +1171,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; let lines = vec![ @@ -999,6 +1198,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; let messages = vec![Message::new( @@ -1030,6 +1230,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; let messages: Vec = vec![]; @@ -1072,6 +1273,7 @@ mod tests { name: "codex".to_string(), reasoning_effort: "high".to_string(), skip_git_check: false, + mcp_config_overrides: Vec::new(), }; let lines = vec![ diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs index 0f60ad4a9fcc..500904e0dd4b 100644 --- a/crates/goose/src/providers/cursor_agent.rs +++ b/crates/goose/src/providers/cursor_agent.rs @@ -340,7 +340,10 @@ impl ProviderDef for CursorAgentProvider { .with_unlisted_models() } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 5f6f2b823759..49709acbdb9a 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -254,7 +254,10 @@ impl ProviderDef for DatabricksProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index ac30055bc5e8..b7d2a626d106 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -593,7 +593,10 @@ impl ProviderDef for GcpVertexAIProvider { .with_unlisted_models() } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index e198ff2215b4..7d8a28797d45 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -252,7 +252,10 @@ impl ProviderDef for GeminiCliProvider { .with_unlisted_models() } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 9c91dcd8575c..d5a51d51b4a2 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -399,7 +399,10 @@ impl ProviderDef for GithubCopilotProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 68a9bd28a750..31fda4c22eee 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -139,7 +139,10 @@ impl ProviderDef for GoogleProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/init.rs b/crates/goose/src/providers/init.rs index 62344c3d5d56..333362ae05e2 100644 --- a/crates/goose/src/providers/init.rs +++ b/crates/goose/src/providers/init.rs @@ -26,6 +26,7 @@ use super::{ venice::VeniceProvider, xai::XaiProvider, }; +use crate::config::ExtensionConfig; use crate::model::ModelConfig; use crate::providers::base::ProviderType; use crate::{ @@ -109,37 +110,45 @@ async fn get_from_registry(name: &str) -> Result { .cloned() } -pub async fn create(name: &str, model: ModelConfig) -> Result> { +pub async fn create( + name: &str, + model: ModelConfig, + extensions: Vec, +) -> Result> { let config = crate::config::Config::global(); if let Ok(lead_model_name) = config.get_param::("GOOSE_LEAD_MODEL") { tracing::info!("Creating lead/worker provider from environment variables"); - return create_lead_worker_from_env(name, &model, &lead_model_name).await; + return create_lead_worker_from_env(name, &model, &lead_model_name, extensions).await; } let constructor = get_from_registry(name).await?.constructor.clone(); - constructor(model).await + constructor(model, extensions).await } -pub async fn create_with_default_model(name: impl AsRef) -> Result> { +pub async fn create_with_default_model( + name: impl AsRef, + extensions: Vec, +) -> Result> { get_from_registry(name.as_ref()) .await? - .create_with_default_model() + .create_with_default_model(extensions) .await } pub async fn create_with_named_model( provider_name: &str, model_name: &str, + extensions: Vec, ) -> Result> { - let config = ModelConfig::new(model_name)?; - create(provider_name, config).await + create(provider_name, ModelConfig::new(model_name)?, extensions).await } async fn create_lead_worker_from_env( default_provider_name: &str, default_model: &ModelConfig, lead_model_name: &str, + extensions: Vec, ) -> Result> { let config = crate::config::Config::global(); @@ -186,8 +195,8 @@ async fn create_lead_worker_from_env( .clone() }; - let lead_provider = lead_constructor(lead_model_config).await?; - let worker_provider = worker_constructor(worker_model_config).await?; + let lead_provider = lead_constructor(lead_model_config, extensions.clone()).await?; + let worker_provider = worker_constructor(worker_model_config, extensions).await?; Ok(Arc::new(LeadWorkerProvider::new_with_settings( lead_provider, @@ -242,9 +251,13 @@ mod tests { ("OPENAI_CUSTOM_HEADERS", Some("")), ]); - let provider = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")) - .await - .unwrap(); + let provider = create( + "openai", + ModelConfig::new_or_fail("gpt-4o-mini"), + Vec::new(), + ) + .await + .unwrap(); let lw = provider.as_lead_worker().unwrap(); let (lead, worker) = lw.get_model_info(); assert_eq!(lead, "gpt-4o"); @@ -267,9 +280,13 @@ mod tests { ("OPENAI_CUSTOM_HEADERS", Some("")), ]); - let provider = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")) - .await - .unwrap(); + let provider = create( + "openai", + ModelConfig::new_or_fail("gpt-4o-mini"), + Vec::new(), + ) + .await + .unwrap(); assert!(provider.as_lead_worker().is_none()); assert_eq!(provider.get_model_config().model_name, "gpt-4o-mini"); } diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index 7e72ddf0e819..ed1fddbd284c 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -335,7 +335,10 @@ impl ProviderDef for LeadWorkerProvider { ) } - fn from_env(_model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + _model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(async { Err(anyhow!("LeadWorkerProvider must be constructed explicitly")) }) } } diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index 32e9a93be817..1df8b5c290db 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -157,7 +157,10 @@ impl ProviderDef for LiteLLMProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 7dd87e4ac202..ff61091b3ce4 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -176,7 +176,10 @@ impl ProviderDef for OllamaProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 07c8e3fa7a19..04e8c8851d94 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -263,7 +263,10 @@ impl ProviderDef for OpenAiProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index a54fc46af8a0..52466c6a34b5 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -255,7 +255,10 @@ impl ProviderDef for OpenRouterProvider { .with_unlisted_models() } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs index 542b0a627345..45f94d594731 100644 --- a/crates/goose/src/providers/provider_registry.rs +++ b/crates/goose/src/providers/provider_registry.rs @@ -1,13 +1,16 @@ use super::base::{ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderType}; -use crate::config::DeclarativeProviderConfig; +use crate::config::{DeclarativeProviderConfig, ExtensionConfig}; use crate::model::ModelConfig; use anyhow::Result; use futures::future::BoxFuture; use std::collections::HashMap; use std::sync::Arc; -pub type ProviderConstructor = - Arc BoxFuture<'static, Result>> + Send + Sync>; +pub type ProviderConstructor = Arc< + dyn Fn(ModelConfig, Vec) -> BoxFuture<'static, Result>> + + Send + + Sync, +>; #[derive(Clone)] pub struct ProviderEntry { @@ -17,10 +20,13 @@ pub struct ProviderEntry { } impl ProviderEntry { - pub async fn create_with_default_model(&self) -> Result> { + pub async fn create_with_default_model( + &self, + extensions: Vec, + ) -> Result> { let default_model = &self.metadata.default_model; let model_config = ModelConfig::new(default_model.as_str())?; - (self.constructor)(model_config).await + (self.constructor)(model_config, extensions).await } } @@ -47,9 +53,9 @@ impl ProviderRegistry { name, ProviderEntry { metadata, - constructor: Arc::new(|model| { + constructor: Arc::new(|model, extensions| { Box::pin(async move { - let provider = F::from_env(model).await?; + let provider = F::from_env(model, extensions).await?; Ok(Arc::new(provider) as Arc) }) }), @@ -121,7 +127,7 @@ impl ProviderRegistry { config.name.clone(), ProviderEntry { metadata: custom_metadata, - constructor: Arc::new(move |model| { + constructor: Arc::new(move |model, _extensions| { let result = constructor(model); Box::pin(async move { let provider = result?; @@ -141,13 +147,18 @@ impl ProviderRegistry { self } - pub async fn create(&self, name: &str, model: ModelConfig) -> Result> { + pub async fn create( + &self, + name: &str, + model: ModelConfig, + extensions: Vec, + ) -> Result> { let entry = self .entries .get(name) .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", name))?; - (entry.constructor)(model).await + (entry.constructor)(model, extensions).await } pub fn all_metadata_with_types(&self) -> Vec<(ProviderMetadata, ProviderType)> { diff --git a/crates/goose/src/providers/provider_test.rs b/crates/goose/src/providers/provider_test.rs index 990e562e9f9b..b47b59d62bea 100644 --- a/crates/goose/src/providers/provider_test.rs +++ b/crates/goose/src/providers/provider_test.rs @@ -14,7 +14,7 @@ pub async fn test_provider_configuration( .with_toolshim(toolshim_enabled) .with_toolshim_model(toolshim_model); - let provider = create(provider_name, model_config).await?; + let provider = create(provider_name, model_config, Vec::new()).await?; let messages = vec![Message::user().with_text("What is the weather like in San Francisco today?")]; diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 5981d5f2fc41..e75d89f05248 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -289,7 +289,10 @@ impl ProviderDef for SageMakerTgiProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index 03c236904d69..16ca9a4ffcd8 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -313,7 +313,10 @@ impl ProviderDef for SnowflakeProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index 32a017eacc94..8ffdac628235 100644 --- a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -139,7 +139,10 @@ impl ProviderDef for TestProvider { ) } - fn from_env(_model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + _model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(async { Err(anyhow!("TestProvider must be constructed explicitly")) }) } } diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index cb0d0fad790d..b8ad1803c04f 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -153,7 +153,10 @@ impl ProviderDef for TetrateProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index b8187c76cb97..3b4530944b86 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -224,7 +224,10 @@ impl ProviderDef for VeniceProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(Self::from_env(model)) } } diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index 49a6389e3b13..011cedaac9bd 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -51,7 +51,10 @@ impl ProviderDef for XaiProvider { ) } - fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { Box::pin(async move { let config = crate::config::Config::global(); let api_key: String = config.get_secret("XAI_API_KEY")?; diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 31c41ad88a78..815f535b387c 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -741,8 +741,6 @@ async fn execute_job( let model_name = config.get_goose_model()?; let model_config = crate::model::ModelConfig::new(&model_name)?; - let agent_provider = create(&provider_name, model_config).await?; - let session = agent .config .session_manager @@ -753,13 +751,14 @@ async fn execute_job( ) .await?; - agent.update_provider(agent_provider, &session.id).await?; - let extensions = resolve_extensions_for_new_session(recipe.extensions.as_deref(), None); - for ext in extensions { + for ext in &extensions { agent.add_extension(ext.clone(), &session.id).await?; } + let agent_provider = create(&provider_name, model_config, extensions).await?; + agent.update_provider(agent_provider, &session.id).await?; + let mut jobs_guard = jobs.lock().await; if let Some((_, job_def)) = jobs_guard.get_mut(job_id.as_str()) { job_def.current_session_id = Some(session.id.clone()); diff --git a/crates/goose/src/session/extension_data.rs b/crates/goose/src/session/extension_data.rs index ff548ef6b196..243004c3c7f4 100644 --- a/crates/goose/src/session/extension_data.rs +++ b/crates/goose/src/session/extension_data.rs @@ -1,7 +1,9 @@ // Extension data management for sessions // Provides a simple way to store extension-specific data with versioned keys +use crate::config::base::Config; use crate::config::ExtensionConfig; +use crate::session::SessionManager; use anyhow::Result; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -111,12 +113,81 @@ impl EnabledExtensionsState { pub fn new(extensions: Vec) -> Self { Self { extensions } } + + pub fn extensions_or_default( + extension_data: Option<&ExtensionData>, + config: &Config, + ) -> Vec { + extension_data + .and_then(Self::from_extension_data) + .map(|state| state.extensions) + .unwrap_or_else(|| { + crate::config::extensions::get_enabled_extensions_with_config(config) + }) + } + + pub async fn for_session( + session_manager: &SessionManager, + session_id: &str, + config: &Config, + ) -> Vec { + let session = session_manager.get_session(session_id, false).await.ok(); + Self::extensions_or_default(session.as_ref().map(|s| &s.extension_data), config) + } } #[cfg(test)] mod tests { use super::*; use serde_json::json; + use tempfile::NamedTempFile; + use test_case::test_case; + + fn test_config() -> Config { + let config_file = NamedTempFile::new().unwrap(); + let secrets_file = NamedTempFile::new().unwrap(); + Config::new_with_file_secrets(config_file.path(), secrets_file.path()).unwrap() + } + + fn test_extension() -> ExtensionConfig { + ExtensionConfig::Builtin { + name: "developer".into(), + description: "dev".into(), + display_name: None, + timeout: None, + bundled: None, + available_tools: vec![], + } + } + + fn extension_data_with(extensions: Vec) -> ExtensionData { + let mut data = ExtensionData::new(); + EnabledExtensionsState::new(extensions) + .to_extension_data(&mut data) + .unwrap(); + data + } + + #[test_case( + Some(extension_data_with(vec![test_extension()])), + Some(vec![test_extension()]) + ; "prefers_session_data" + )] + #[test_case(None, None ; "no_session_falls_back_to_config")] + #[test_case(Some(ExtensionData::default()), None ; "empty_session_data_falls_back_to_config")] + fn test_extensions_or_default( + extension_data: Option, + expected: Option>, + ) { + let config = test_config(); + let expected = expected.unwrap_or_else(|| { + crate::config::extensions::get_enabled_extensions_with_config(&config) + }); + assert_eq!( + EnabledExtensionsState::extensions_or_default(extension_data.as_ref(), &config), + expected, + ); + } #[test] fn test_extension_data_basic_operations() { diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index de21e37f7bab..8db987786c18 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -374,6 +374,7 @@ mod tests { fn from_env( _model: ModelConfig, + _extensions: Vec, ) -> futures::future::BoxFuture<'static, anyhow::Result> { Box::pin(async { Ok(Self::new()) }) } diff --git a/crates/goose/tests/compaction.rs b/crates/goose/tests/compaction.rs index 49efc2c9dfd1..29755c81c0f1 100644 --- a/crates/goose/tests/compaction.rs +++ b/crates/goose/tests/compaction.rs @@ -191,7 +191,10 @@ impl ProviderDef for MockCompactionProvider { } } - fn from_env(_model: ModelConfig) -> futures::future::BoxFuture<'static, anyhow::Result> { + fn from_env( + _model: ModelConfig, + _extensions: Vec, + ) -> futures::future::BoxFuture<'static, anyhow::Result> { Box::pin(async { Ok(Self::new()) }) } } diff --git a/crates/goose/tests/mcp_integration_test.rs b/crates/goose/tests/mcp_integration_test.rs index 1f0b724d052a..b5343b6d818a 100644 --- a/crates/goose/tests/mcp_integration_test.rs +++ b/crates/goose/tests/mcp_integration_test.rs @@ -54,7 +54,10 @@ impl ProviderDef for MockProvider { ProviderMetadata::empty() } - fn from_env(model: ModelConfig) -> futures::future::BoxFuture<'static, anyhow::Result> { + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> futures::future::BoxFuture<'static, anyhow::Result> { Box::pin(async move { Ok(Self::new(model)) }) } } diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 9eb65818f524..bf515c6a31a1 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -351,13 +351,10 @@ impl ProviderTester { self.test_model_listing().await?; self.test_basic_response(&self.session_id_for_test("basic_response")) .await?; - // TODO: remove skip in https://github.com/block/goose/pull/6972 - if !self.is_cli_provider { - self.test_tool_usage(&self.session_id_for_test("tool_usage")) - .await?; - self.test_image_content_support(&self.session_id_for_test("image_content")) - .await?; - } + self.test_tool_usage(&self.session_id_for_test("tool_usage")) + .await?; + self.test_image_content_support(&self.session_id_for_test("image_content")) + .await?; if self.model_switch_name.is_some() { self.test_model_switch(&self.session_id_for_test("model_switch")) .await?; @@ -443,7 +440,13 @@ async fn test_provider( let mcp_extension = ExtensionConfig::streamable_http("mcp-fixture", &mcp.url, "MCP fixture", 30_u64); - let provider = match create_with_named_model(&provider_name, model_name).await { + let provider = match create_with_named_model( + &provider_name, + model_name, + vec![mcp_extension.clone()], + ) + .await + { Ok(p) => p, Err(e) => { println!("Skipping {} tests - failed to create provider: {}", name, e);