diff --git a/Cargo.lock b/Cargo.lock index 16253974ce18..c8ccc6d7c33c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,7 +38,7 @@ dependencies = [ "getrandom 0.2.15", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -1300,6 +1300,15 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "colored" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "combine" version = "4.6.7" @@ -1641,7 +1650,7 @@ dependencies = [ "futures-util", "num", "once_cell", - "rand", + "rand 0.8.5", ] [[package]] @@ -1916,7 +1925,7 @@ dependencies = [ "hyper-timeout", "log", "pin-project", - "rand", + "rand 0.8.5", "tokio", ] @@ -2313,10 +2322,11 @@ dependencies = [ "mcp-core", "minijinja", "mockall", + "mockito", "nanoid", "once_cell", "paste", - "rand", + "rand 0.8.5", "regex", "reqwest 0.12.12", "serde", @@ -2381,7 +2391,7 @@ dependencies = [ "mcp-core", "mcp-server", "once_cell", - "rand", + "rand 0.8.5", "reqwest 0.12.12", "rustyline", "serde", @@ -3527,7 +3537,7 @@ dependencies = [ "eventsource-client", "futures", "mcp-core", - "rand", + "rand 0.8.5", "reqwest 0.11.27", "serde", "serde_json", @@ -3678,6 +3688,30 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "mockito" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48" +dependencies = [ + "assert-json-diff", + "bytes", + "colored", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "log", + "rand 0.9.0", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "monostate" version = "0.1.14" @@ -3705,7 +3739,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8" dependencies = [ - "rand", + "rand 0.8.5", ] [[package]] @@ -3932,7 +3966,7 @@ dependencies = [ "chrono", "getrandom 0.2.15", "http 1.2.0", - "rand", + "rand 0.8.5", "reqwest 0.12.12", "serde", "serde_json", @@ -4243,7 +4277,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -4419,7 +4453,7 @@ checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", "getrandom 0.2.15", - "rand", + "rand 0.8.5", "ring", "rustc-hash 2.1.1", "rustls 0.23.23", @@ -4471,8 +4505,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", + "zerocopy 0.8.24", ] [[package]] @@ -4482,7 +4527,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", ] [[package]] @@ -4494,6 +4549,15 @@ dependencies = [ "getrandom 0.2.15", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.1", +] + [[package]] name = "rangemap" version = "1.5.1" @@ -4526,8 +4590,8 @@ dependencies = [ "once_cell", "paste", "profiling", - "rand", - "rand_chacha", + "rand 0.8.5", + "rand_chacha 0.3.1", "simd_helpers", "system-deps", "thiserror 1.0.69", @@ -4723,6 +4787,7 @@ dependencies = [ "cookie", "cookie_store", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2 0.4.8", @@ -5323,6 +5388,12 @@ dependencies = [ "quote", ] +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "simple_asn1" version = "0.6.3" @@ -5800,7 +5871,7 @@ dependencies = [ "monostate", "onig", "paste", - "rand", + "rand 0.8.5", "rayon", "rayon-cond", "regex", @@ -6109,7 +6180,7 @@ dependencies = [ "http 1.2.0", "httparse", "log", - "rand", + "rand 0.8.5", "sha1", "thiserror 1.0.69", "utf-8", @@ -7108,7 +7179,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879" +dependencies = [ + "zerocopy-derive 0.8.24", ] [[package]] @@ -7122,6 +7202,17 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + [[package]] name = "zerofrom" version = "0.1.6" diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index d7ab4b565fd7..f44376d53780 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -27,7 +27,8 @@ reqwest = { version = "0.12.9", features = [ "zstd", "charset", "http2", - "stream" + "stream", + "blocking" ], default-features = false } tokio = { version = "1.0", features = ["full"] } serde = { version = "1.0", features = ["derive"] } @@ -82,6 +83,7 @@ tempfile = "3.15.0" serial_test = "3.2.0" mockall = "0.13.1" wiremock = "0.6.0" +mockito = "1.4.0" tokio = { version = "1.0", features = ["full"] } [[example]] diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 682724288db7..248952bf6ad2 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -111,6 +111,9 @@ impl Capabilities { /// Add a new MCP extension based on the provided client type // TODO IMPORTANT need to ensure this times out if the extension command is broken! pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> { + // Validate the command before creating the transport + config.validate_command().map_err(|e| *e)?; + let mut client: Box = match &config { ExtensionConfig::Sse { uri, envs, timeout, .. diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index d44de5a36110..01c532c1a1d5 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -1,9 +1,12 @@ use std::collections::HashMap; +use std::fs; +use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs}; use mcp_client::client::Error as ClientError; +use reqwest; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tracing::warn; +use tracing::{info, warn}; use utoipa::ToSchema; use crate::config; @@ -22,6 +25,10 @@ pub enum ExtensionError { Transport(#[from] mcp_client::transport::Error), #[error("Environment variable `{0}` is not allowed to be overridden.")] InvalidEnvVar(String), + #[error("Command `{0}` is not in the allowed extensions list")] + UnauthorizedCommand(String), + #[error("Allowlist error: {0}")] + AllowlistError(String), } pub type ExtensionResult = Result; @@ -48,7 +55,7 @@ impl Envs { "LD_AUDIT", // Loads a monitoring library that can intercept execution "LD_DEBUG", // Enables verbose linker logging (information disclosure risk) "LD_BIND_NOW", // Forces immediate symbol resolution, affecting ASLR - "LD_ASSUME_KERNEL", // Tricks linker into thinking itโ€™s running on an older kernel + "LD_ASSUME_KERNEL", // Tricks linker into thinking it's running on an older kernel // ๐ŸŽ macOS dynamic linker variables "DYLD_LIBRARY_PATH", // Same as LD_LIBRARY_PATH but for macOS "DYLD_INSERT_LIBRARIES", // macOS equivalent of LD_PRELOAD @@ -161,6 +168,172 @@ impl Default for ExtensionConfig { } } +/// Check if a command is in the allowed extensions list +/// +/// This function checks if the command is allowed according to the allowlist. +/// If GOOSE_MCP_ALLOWLIST_URL is set, it will download the allowlist from that URL +/// and save it to ~/.config/goose/mcp_allowlist.yaml. +/// +/// The function will then check if the command is allowed according to the downloaded +/// allowlist file. +/// +/// If GOOSE_MCP_ALLOWLIST_URL is not set, all commands are allowed. +pub fn is_command_allowed(cmd: &str) -> Result<(), Box> { + // Check if GOOSE_MCP_ALLOWLIST_URL is set + if let Ok(url) = std::env::var("GOOSE_MCP_ALLOWLIST_URL") { + // Get the path where the allowlist would be stored + let app_strategy = AppStrategyArgs { + top_level_domain: "Block".to_string(), + author: "Block".to_string(), + app_name: "goose".to_string(), + }; + + let allowlist_path = match choose_app_strategy(app_strategy) { + Ok(strategy) => strategy.config_dir().join("mcp_allowlist.yaml"), + Err(e) => { + warn!("Failed to determine allowlist path: {}", e); + return Ok(()); // Allow the command if we can't even determine the path + } + }; + + let path_str = allowlist_path.to_string_lossy().to_string(); + + // Always try to download the allowlist if URL is set + let download_result = download_allowlist(&url); + + // Whether download succeeded or failed, try to use the file if it exists + if let Ok(content) = std::fs::read_to_string(&allowlist_path) { + // Parse the YAML file + if let Ok(yaml) = serde_yaml::from_str::(&content) { + // Extract the extensions list + if let Some(extensions) = yaml.get("extensions") { + if let Some(extensions_array) = extensions.as_sequence() { + // Create a list of allowed commands + let allowed_commands: Vec = extensions_array + .iter() + .filter_map(|v| { + v.get("command") + .and_then(|c| c.as_str()) + .map(|command| command.trim().to_string()) + }) + .collect(); + + // Require exact match for security + if !allowed_commands.contains(&cmd.to_string()) { + return Err(Box::new(ExtensionError::UnauthorizedCommand( + cmd.to_string(), + ))); + } + } + } + } + } else if download_result.is_err() { + // Only log a warning if both download failed AND we couldn't read the file + warn!( + "Failed to download allowlist AND couldn't read existing file at {}: {:?}", + path_str, + download_result.err() + ); + } + } + + // If no URL is set or everything passed, allow the command + Ok(()) +} + +/// Download the allowlist from a URL and save it to the config directory +/// +/// This function downloads the allowlist from the specified URL and saves it to +/// ~/.config/goose/mcp_allowlist.yaml. It will create the directory if it doesn't exist. +/// +/// Returns the path to the downloaded file. +pub fn download_allowlist(url: &str) -> Result> { + // Define app strategy for consistent config paths + let app_strategy = AppStrategyArgs { + top_level_domain: "Block".to_string(), + author: "Block".to_string(), + app_name: "goose".to_string(), + }; + + // Get the config directory (~/.config/goose/ on macOS/Linux) + let config_dir = choose_app_strategy(app_strategy) + .map_err(|e| { + Box::new(ExtensionError::AllowlistError(format!( + "Failed to get config directory: {}", + e + ))) + })? + .config_dir(); + + // Create the directory if it doesn't exist + fs::create_dir_all(&config_dir).map_err(|e| { + Box::new(ExtensionError::AllowlistError(format!( + "Failed to create directory: {}", + e + ))) + })?; + + // Define the path for the allowlist file + let allowlist_path = config_dir.join("mcp_allowlist.yaml"); + let path_str = allowlist_path.to_string_lossy().to_string(); + + // Download the allowlist file + info!("Downloading allowlist from {}", url); + + // Create a client with a timeout + let client = reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(10)) // 10 second timeout + .build() + .map_err(|e| { + Box::new(ExtensionError::AllowlistError(format!( + "Failed to create HTTP client: {}", + e + ))) + })?; + + // Make the request + let response = client.get(url).send().map_err(|e| { + Box::new(ExtensionError::AllowlistError(format!( + "HTTP request failed: {}", + e + ))) + })?; + + if !response.status().is_success() { + return Err(Box::new(ExtensionError::AllowlistError(format!( + "HTTP error: {}", + response.status() + )))); + } + + let content = response.text().map_err(|e| { + Box::new(ExtensionError::AllowlistError(format!( + "Failed to read response body: {}", + e + ))) + })?; + + // Validate the YAML format + serde_yaml::from_str::(&content).map_err(|e| { + Box::new(ExtensionError::AllowlistError(format!( + "Invalid YAML: {}", + e + ))) + })?; + + // Write the content to the file + fs::write(&allowlist_path, content).map_err(|e| { + Box::new(ExtensionError::AllowlistError(format!( + "Failed to write file: {}", + e + ))) + })?; + + info!("Allowlist downloaded and saved to {}", path_str); + + Ok(path_str) +} + impl ExtensionConfig { pub fn sse, T: Into>(name: S, uri: S, description: S, timeout: T) -> Self { Self::Sse { @@ -227,6 +400,14 @@ impl ExtensionConfig { } .to_string() } + + /// Check if this extension's command is allowed + pub fn validate_command(&self) -> Result<(), Box> { + if let Self::Stdio { cmd, .. } = self { + is_command_allowed(cmd)?; + } + Ok(()) + } } impl std::fmt::Display for ExtensionConfig { @@ -278,3 +459,146 @@ impl ToolInfo { } } } + +#[cfg(test)] +mod tests { + use super::*; + use mockito; + use std::env; + use std::fs::File; + use std::io::Write; + use tempfile::tempdir; + + #[test] + fn test_no_allowlist_url() { + // Make sure the environment variable is not set + env::remove_var("GOOSE_MCP_ALLOWLIST_URL"); + + // Without an allowlist URL, all commands should be allowed + assert!(is_command_allowed("any-command").is_ok()); + } + + #[test] + fn test_allowlist_with_new_format() { + // This test manually creates a file and checks the command validation logic + // Create a temporary directory + let temp_dir = tempdir().expect("Failed to create temp dir"); + let file_path = temp_dir.path().join("allowed_extensions.yaml"); + + // Create a whitelist file with the new format that includes id and command + let mut file = File::create(&file_path).expect("Failed to create allowlist file"); + writeln!(file, "extensions:").expect("Failed to write to allowlist file"); + writeln!(file, " - id: slack").expect("Failed to write to allowlist file"); + writeln!(file, " command: uvx mcp_slack").expect("Failed to write to allowlist file"); + writeln!(file, " - id: python").expect("Failed to write to allowlist file"); + writeln!(file, " command: python").expect("Failed to write to allowlist file"); + file.flush().expect("Failed to flush allowlist file"); + + // Test with allowed commands (using our mock function) + let allowed_commands = ["uvx mcp_slack", "python"]; + for cmd in allowed_commands.iter() { + // Read the file and check if command is allowed + let content = std::fs::read_to_string(&file_path).expect("Failed to read file"); + let yaml = + serde_yaml::from_str::(&content).expect("Failed to parse YAML"); + let extensions = yaml.get("extensions").expect("No extensions found"); + let extensions_array = extensions + .as_sequence() + .expect("Extensions is not an array"); + + let allowed_commands: Vec = extensions_array + .iter() + .filter_map(|v| { + if let Some(command) = v.get("command").and_then(|c| c.as_str()) { + Some(command.trim().to_string()) + } else { + None + } + }) + .collect(); + + assert!(allowed_commands.contains(&cmd.to_string())); + } + + // Test with a command not in the allowlist + let content = std::fs::read_to_string(&file_path).expect("Failed to read file"); + let yaml = + serde_yaml::from_str::(&content).expect("Failed to parse YAML"); + let extensions = yaml.get("extensions").expect("No extensions found"); + let extensions_array = extensions + .as_sequence() + .expect("Extensions is not an array"); + + let allowed_commands: Vec = extensions_array + .iter() + .filter_map(|v| { + if let Some(command) = v.get("command").and_then(|c| c.as_str()) { + Some(command.trim().to_string()) + } else { + None + } + }) + .collect(); + + assert!(!allowed_commands.contains(&"not-in-allowlist".to_string())); + } + + #[test] + #[ignore] // This test requires network access, so we ignore it by default + fn test_download_allowlist() { + // Setup a mock server + let mut server = mockito::Server::new(); + + // Mock with any number of calls + let _mock = server + .mock("GET", "/allowlist.yaml") + .with_status(200) + .with_body( + "extensions: + - id: slack + command: uvx mcp_slack + - id: python + command: python", + ) + .create(); + + // Set the URL environment variable to point to our mock server + env::set_var( + "GOOSE_MCP_ALLOWLIST_URL", + format!("{}/allowlist.yaml", server.url()), + ); + + // Test that a command is allowed after downloading the allowlist + assert!(is_command_allowed("uvx mcp_slack").is_ok()); + + // Test that a command not in the allowlist is rejected + assert!(is_command_allowed("not-in-allowlist").is_err()); + + // Clean up + env::remove_var("GOOSE_MCP_ALLOWLIST_URL"); + } + + #[test] + #[ignore] // This test requires network access, so we ignore it by default + fn test_download_allowlist_failure() { + // Setup a mock server that returns an error + let mut server = mockito::Server::new(); + let _mock = server + .mock("GET", "/allowlist.yaml") + .with_status(404) + .with_body("Not Found") + .create(); + + // Set the URL environment variable to point to our mock server + env::set_var( + "GOOSE_MCP_ALLOWLIST_URL", + format!("{}/allowlist.yaml", server.url()), + ); + + // Test that command validation fails when download fails + assert!(is_command_allowed("any-command").is_err()); + + // Clean up + env::remove_var("GOOSE_MCP_ALLOWLIST_URL"); + } +} diff --git a/crates/goose/src/config/extensions.rs b/crates/goose/src/config/extensions.rs index 9c31add82287..b7ba230b0cc4 100644 --- a/crates/goose/src/config/extensions.rs +++ b/crates/goose/src/config/extensions.rs @@ -1,5 +1,5 @@ use super::base::Config; -use crate::agents::ExtensionConfig; +use crate::agents::extension::{ExtensionConfig, ExtensionError}; use anyhow::Result; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -65,6 +65,14 @@ impl ExtensionManager { /// Set or update an extension configuration pub fn set(entry: ExtensionEntry) -> Result<()> { + // Validate the command before saving + entry.config.validate_command().map_err(|e| match *e { + ExtensionError::UnauthorizedCommand(cmd) => { + anyhow::anyhow!("Command '{}' is not in the allowed extensions list", cmd) + } + _ => anyhow::anyhow!("Failed to validate command: {}", e), + })?; + let config = Config::global(); let mut extensions: HashMap = config