diff --git a/Cargo.lock b/Cargo.lock index b4f3c4aae229..e174d8c0c438 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1151,6 +1151,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" + [[package]] name = "bat" version = "0.24.0" @@ -1499,7 +1505,7 @@ version = "0.15.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d067ad48b8650848b989a59a86c6c36a995d02d2bf778d45c3c5d57bc2718f02" dependencies = [ - "smallvec", + "smallvec 1.14.0", "target-lexicon", ] @@ -2511,6 +2517,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.3.11" @@ -2532,6 +2548,37 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling 0.20.10", + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.99", +] + [[package]] name = "digest" version = "0.10.7" @@ -2705,6 +2752,12 @@ version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" + [[package]] name = "etcetera" version = "0.8.0" @@ -2760,7 +2813,7 @@ dependencies = [ "lebe", "miniz_oxide", "rayon-core", - "smallvec", + "smallvec 1.14.0", "zune-inflate", ] @@ -3321,6 +3374,7 @@ dependencies = [ "opentelemetry", "opentelemetry-otlp", "opentelemetry_sdk", + "ort", "rand 0.8.5", "regex", "reqwest 0.12.12", @@ -3336,6 +3390,7 @@ dependencies = [ "test-case", "thiserror 1.0.69", "tiktoken-rs", + "tokenizers", "tokio", "tokio-cron-scheduler", "tokio-stream", @@ -3813,7 +3868,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "smallvec", + "smallvec 1.14.0", "tokio", "want", ] @@ -3985,7 +4040,7 @@ dependencies = [ "icu_normalizer_data", "icu_properties", "icu_provider", - "smallvec", + "smallvec 1.14.0", "utf16_iter", "utf8_iter", "write16", @@ -4060,7 +4115,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ "idna_adapter", - "smallvec", + "smallvec 1.14.0", "utf8_iter", ] @@ -4283,6 +4338,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.12.1" @@ -5126,6 +5190,22 @@ dependencies = [ "libc", ] +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + [[package]] name = "malloc_buf" version = "0.0.6" @@ -5156,6 +5236,16 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -5401,13 +5491,34 @@ dependencies = [ "rustc_version", "scheduled-thread-pool", "skeptic", - "smallvec", + "smallvec 1.14.0", "tagptr", "thiserror 1.0.69", "triomphe", "uuid", ] +[[package]] +name = "monostate" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aafe1be9d0c75642e3e50fedc7ecadf1ef1cbce6eb66462153fc44245343fbee" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c402a4092d5e204f32c9e155431046831fa712637043c58cb73bc6bc6c9663b5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + [[package]] name = "multimap" version = "0.10.1" @@ -5429,6 +5540,38 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk-context" version = "0.1.1" @@ -5447,7 +5590,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" dependencies = [ - "smallvec", + "smallvec 1.14.0", ] [[package]] @@ -5936,6 +6079,31 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "ort" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" +dependencies = [ + "ndarray", + "ort-sys", + "smallvec 2.0.0-alpha.10", + "tracing", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890" +dependencies = [ + "flate2", + "pkg-config", + "sha2", + "tar", + "ureq", +] + [[package]] name = "outref" version = "0.5.2" @@ -5982,7 +6150,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "smallvec", + "smallvec 1.14.0", "windows-targets 0.52.6", ] @@ -6029,6 +6197,15 @@ dependencies = [ "serde", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -6248,6 +6425,15 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -6754,6 +6940,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -6764,6 +6956,17 @@ dependencies = [ "rayon-core", ] +[[package]] +name = "rayon-cond" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +dependencies = [ + "either", + "itertools 0.11.0", + "rayon", +] + [[package]] name = "rayon-core" version = "1.12.1" @@ -7714,6 +7917,12 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +[[package]] +name = "smallvec" +version = "2.0.0-alpha.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b" + [[package]] name = "smawk" version = "0.3.2" @@ -7762,6 +7971,29 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "sqlparser" version = "0.49.0" @@ -8028,7 +8260,7 @@ dependencies = [ "serde", "serde_json", "sketches-ddsketch", - "smallvec", + "smallvec 1.14.0", "tantivy-bitpacker", "tantivy-columnar", "tantivy-common", @@ -8415,6 +8647,37 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.20.4" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom 0.2.15", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.8.5", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.8.5", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 1.0.69", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.43.1" @@ -8728,7 +8991,7 @@ dependencies = [ "once_cell", "opentelemetry", "opentelemetry_sdk", - "smallvec", + "smallvec 1.14.0", "tracing", "tracing-core", "tracing-log", @@ -8759,7 +9022,7 @@ dependencies = [ "serde", "serde_json", "sharded-slab", - "smallvec", + "smallvec 1.14.0", "thread_local", "time", "tracing", @@ -8875,6 +9138,15 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec 1.14.0", +] + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -8893,6 +9165,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "unsafe-libyaml" version = "0.2.11" @@ -8905,6 +9183,37 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f0fde9bc91026e381155f8c67cb354bcd35260b2f4a29bcc84639f762760c39" +dependencies = [ + "base64 0.22.1", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pemfile 2.2.0", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf-8", + "webpki-root-certs 0.26.11", +] + +[[package]] +name = "ureq-proto" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59db78ad1923f2b1be62b6da81fe80b173605ca0d57f85da2e005382adf693f7" +dependencies = [ + "base64 0.22.1", + "http 1.2.0", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.4" @@ -9218,6 +9527,24 @@ dependencies = [ "web-sys", ] +[[package]] +name = "webpki-root-certs" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75c7f0ef91146ebfb530314f5f1d24528d7f0767efbfd31dce919275413e393e" +dependencies = [ + "webpki-root-certs 1.0.2", +] + +[[package]] +name = "webpki-root-certs" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4ffd8df1c57e87c325000a3d6ef93db75279dc3a231125aac571650f22b12a" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webpki-roots" version = "0.26.8" diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index dbc74d009c32..f6ba27319a28 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -101,6 +101,10 @@ unicode-normalization = "0.1" lancedb = "0.13" arrow = "52.2" +# ML inference backends for security scanning +ort = "2.0.0-rc.10" # 
ONNX Runtime - use latest RC +tokenizers = { version = "0.20.4", default-features = false, features = ["onig"] } # HuggingFace tokenizers + [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 927e011e3f01..0f7bc8eed065 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -62,6 +62,7 @@ use crate::agents::todo_tools::{ todo_read_tool, todo_write_tool, TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME, }; use crate::conversation::message::{Message, ToolRequest}; +use crate::security::SecurityManager; const DEFAULT_MAX_TURNS: u32 = 1000; @@ -102,6 +103,7 @@ pub struct Agent { pub(super) tool_route_manager: ToolRouteManager, pub(super) scheduler_service: Mutex>>, pub(super) retry_manager: RetryManager, + pub(super) security_manager: SecurityManager, pub(super) todo_list: Arc>, } @@ -188,6 +190,7 @@ impl Agent { tool_route_manager: ToolRouteManager::new(), scheduler_service: Mutex::new(None), retry_manager, + security_manager: SecurityManager::new(), todo_list: Arc::new(Mutex::new(String::new())), } } @@ -1078,6 +1081,29 @@ impl Agent { ); } } else { + // Check if we need to show model download status before security scanning + if let Some(download_message) = self.security_manager.check_model_download_status().await { + yield AgentEvent::Message(Message::assistant().with_text(download_message)); + } + + // SECURITY FIX: Scan tools for prompt injection BEFORE permission checking + // This ensures security results can override auto-mode approvals + let initial_permission_result = PermissionCheckResult { + approved: remaining_requests.clone(), + needs_approval: vec![], + denied: vec![], + }; + + println!("πŸ” DEBUG: About to call security manager with {} total tools", remaining_requests.len()); + let security_results = self.security_manager + .filter_malicious_tool_calls(messages.messages(), &initial_permission_result, Some(&system_prompt)) + .await + .unwrap_or_else(|e| { + tracing::warn!("Security scanning failed: {}", e); + vec![] + }); + + // Now run permission checking with security context let mut permission_manager = PermissionManager::default(); let (permission_check_result, enable_extension_request_ids) = check_tool_permissions( @@ -1089,21 +1115,45 @@ impl Agent { self.provider().await?, ).await; + // Apply security results to override permission decisions + let final_permission_result = self.apply_security_results_to_permissions( + permission_check_result, + &security_results + ).await; + + println!("πŸ” DEBUG: After security integration - {} approved, {} need approval, {} denied", + final_permission_result.approved.len(), + final_permission_result.needs_approval.len(), + final_permission_result.denied.len()); + let mut tool_futures = self.handle_approved_and_denied_tools( - &permission_check_result, + &final_permission_result, message_tool_response.clone(), cancel_token.clone() ).await?; let tool_futures_arc = Arc::new(Mutex::new(tool_futures)); - // Process tools requiring approval - let mut tool_approval_stream = self.handle_approval_tool_requests( - &permission_check_result.needs_approval, + // Process tools requiring approval (including security-flagged tools) + // Create a mapping of security results for tools that need approval + let mut security_results_for_approval: Vec> = Vec::new(); + for _approval_request in &final_permission_result.needs_approval { + // Find the corresponding security result for this tool 
request + let security_result = security_results.iter().find(|result| { + // Match by checking if this tool was flagged as malicious + // This is a simplified matching - ideally we'd have better tool request tracking + result.is_malicious + }); + security_results_for_approval.push(security_result); + } + + let mut tool_approval_stream = self.handle_approval_tool_requests_with_security( + &final_permission_result.needs_approval, tool_futures_arc.clone(), &mut permission_manager, message_tool_response.clone(), cancel_token.clone(), + Some(&security_results_for_approval), ); while let Some(msg) = tool_approval_stream.try_next().await? { @@ -1230,6 +1280,98 @@ impl Agent { } } + /// Apply security scan results to permission check results + /// This integrates security scanning with the existing tool approval system + async fn apply_security_results_to_permissions( + &self, + mut permission_result: PermissionCheckResult, + security_results: &[crate::security::SecurityResult], + ) -> PermissionCheckResult { + if security_results.is_empty() { + return permission_result; + } + + // Create a map of tool requests by ID for easy lookup + let mut all_requests: std::collections::HashMap = + std::collections::HashMap::new(); + + // Collect all tool requests + for req in &permission_result.approved { + all_requests.insert(req.id.clone(), req.clone()); + } + for req in &permission_result.needs_approval { + all_requests.insert(req.id.clone(), req.clone()); + } + for req in &permission_result.denied { + all_requests.insert(req.id.clone(), req.clone()); + } + + // Collect the combined requests first to avoid borrowing issues + let combined_requests: Vec = permission_result + .approved + .iter() + .chain(permission_result.needs_approval.iter()) + .cloned() + .collect(); + + // Process security results + for (i, security_result) in security_results.iter().enumerate() { + if !security_result.is_malicious { + continue; + } + + // Find the corresponding tool request by index + if let Some(tool_request) = combined_requests.get(i) { + let request_id = &tool_request.id; + + tracing::warn!( + tool_request_id = %request_id, + confidence = security_result.confidence, + explanation = %security_result.explanation, + finding_id = %security_result.finding_id, + "πŸ”’ Security threat detected - modifying tool approval status" + ); + + // Remove from approved if present + permission_result + .approved + .retain(|req| req.id != *request_id); + + if security_result.should_ask_user { + // Move to needs_approval with security context + if let Some(request) = all_requests.get(request_id) { + // Only add if not already in needs_approval + if !permission_result + .needs_approval + .iter() + .any(|req| req.id == *request_id) + { + permission_result.needs_approval.push(request.clone()); + } + } + } else { + // High confidence threat - move to denied + permission_result + .needs_approval + .retain(|req| req.id != *request_id); + + if let Some(request) = all_requests.get(request_id) { + // Only add if not already in denied + if !permission_result + .denied + .iter() + .any(|req| req.id == *request_id) + { + permission_result.denied.push(request.clone()); + } + } + } + } + } + + permission_result + } + /// Extend the system prompt with one line of additional instruction pub async fn extend_system_prompt(&self, instruction: String) { let mut prompt_manager = self.prompt_manager.lock().await; diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index 045cdd0229dd..cf668bfb38c2 100644 --- 
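// Editorial sketch, not part of the diff: the comment above notes the
// positional matching between security results and tool requests is
// "simplified" and that better tool-request tracking would be ideal.
// One way to do that is to key results by request id instead of by index;
// the types below are illustrative stand-ins, not the crate's own.
use std::collections::HashMap;

struct SecurityResult { is_malicious: bool, should_ask_user: bool }
struct ToolRequest { id: String }

fn route<'a>(
    requests: &'a [ToolRequest],
    results: &HashMap<String, SecurityResult>, // keyed by request id
) -> (Vec<&'a ToolRequest>, Vec<&'a ToolRequest>) {
    // Mirrors the routing in apply_security_results_to_permissions:
    // flagged-and-ask goes to needs_approval, flagged-and-no-ask to denied.
    let (mut needs_approval, mut denied) = (Vec::new(), Vec::new());
    for req in requests {
        match results.get(&req.id) {
            Some(r) if r.is_malicious && r.should_ask_user => needs_approval.push(req),
            Some(r) if r.is_malicious => denied.push(req),
            _ => {}
        }
    }
    (needs_approval, denied)
}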
a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -54,21 +54,99 @@ impl Agent { permission_manager: &'a mut PermissionManager, message_tool_response: Arc>, cancellation_token: Option, + ) -> BoxStream<'a, anyhow::Result> { + self.handle_approval_tool_requests_with_security( + tool_requests, + tool_futures, + permission_manager, + message_tool_response, + cancellation_token, + None, // No security context by default + ) + } + + pub(crate) fn handle_approval_tool_requests_with_security<'a>( + &'a self, + tool_requests: &'a [ToolRequest], + tool_futures: Arc>>, + permission_manager: &'a mut PermissionManager, + message_tool_response: Arc>, + cancellation_token: Option, + security_results: Option<&'a [Option<&'a crate::security::SecurityResult>]>, ) -> BoxStream<'a, anyhow::Result> { try_stream! { - for request in tool_requests { + for (i, request) in tool_requests.iter().enumerate() { if let Ok(tool_call) = request.tool_call.clone() { + // Check if this tool has security concerns + // Match by index since security results are provided in the same order as tool requests + let security_context = security_results + .and_then(|results| results.get(i)) + .and_then(|result| *result) + .filter(|result| result.is_malicious); + + let confirmation_prompt = if let Some(security_result) = security_context { + format!( + "🚨 SECURITY WARNING: This tool call has been flagged as potentially malicious.\n\ + Finding ID: {}\n\ + Confidence: {:.1}%\n\ + Reason: {}\n\n\ + Goose would still like to call the above tool. \n\ + Please review carefully. Allow? (y/n):", + security_result.finding_id, + security_result.confidence * 100.0, + security_result.explanation + ) + } else { + "Goose would like to call the above tool. Allow? (y/n):".to_string() + }; + let confirmation = Message::user().with_tool_confirmation_request( request.id.clone(), tool_call.name.clone(), tool_call.arguments.clone(), - Some("Goose would like to call the above tool. Allow? 
(y/n):".to_string()), + Some(confirmation_prompt), ); yield confirmation; let mut rx = self.confirmation_rx.lock().await; while let Some((req_id, confirmation)) = rx.recv().await { if req_id == request.id { + // Log user decision, especially for security-flagged tools + if let Some(security_result) = security_context { + match confirmation.permission { + Permission::AllowOnce | Permission::AlwaysAllow => { + tracing::warn!( + tool_name = %tool_call.name, + request_id = %request.id, + permission = ?confirmation.permission, + security_confidence = %format!("{:.1}%", security_result.confidence * 100.0), + security_reason = %security_result.explanation, + finding_id = %security_result.finding_id, + "πŸ”’ USER APPROVED security-flagged tool despite warning" + ); + } + _ => { + tracing::info!( + tool_name = %tool_call.name, + request_id = %request.id, + permission = ?confirmation.permission, + security_confidence = %format!("{:.1}%", security_result.confidence * 100.0), + security_reason = %security_result.explanation, + finding_id = %security_result.finding_id, + "πŸ”’ USER DENIED security-flagged tool" + ); + } + } + } else { + // Log regular tool decisions at debug level + tracing::debug!( + tool_name = %tool_call.name, + request_id = %request.id, + permission = ?confirmation.permission, + "πŸ”’ User decision for tool execution" + ); + } + if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow { let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone()).await; let mut futures = tool_futures.lock().await; diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index 7d774dddeddc..5609aaa74e8d 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -13,6 +13,7 @@ pub mod recipe_deeplink; pub mod scheduler; pub mod scheduler_factory; pub mod scheduler_trait; +pub mod security; pub mod session; pub mod temporal_scheduler; pub mod token_counter; diff --git a/crates/goose/src/security/mod.rs b/crates/goose/src/security/mod.rs new file mode 100644 index 000000000000..8ade755d417a --- /dev/null +++ b/crates/goose/src/security/mod.rs @@ -0,0 +1,291 @@ +pub mod model_downloader; +pub mod scanner; + +use crate::conversation::message::Message; +use crate::permission::permission_judge::PermissionCheckResult; +use anyhow::Result; +use scanner::PromptInjectionScanner; + +/// Simple security manager for the POC +/// Focuses on tool call analysis with conversation context +pub struct SecurityManager { + scanner: Option, +} + +#[derive(Debug, Clone)] +pub struct SecurityResult { + pub is_malicious: bool, + pub confidence: f32, + pub explanation: String, + pub should_ask_user: bool, + pub finding_id: String, +} + +impl SecurityManager { + pub fn new() -> Self { + println!("πŸ”’ SecurityManager::new() called - checking if security should be enabled"); + + // Initialize scanner based on config + let should_enable = Self::should_enable_security(); + println!("πŸ”’ Security enabled check result: {}", should_enable); + + let scanner = match should_enable { + true => { + println!("πŸ”’ Initializing security scanner"); + tracing::info!("πŸ”’ Initializing security scanner"); + Some(PromptInjectionScanner::new()) + } + false => { + println!("πŸ”“ Security scanning disabled"); + tracing::info!("πŸ”“ Security scanning disabled"); + None + } + }; + + Self { scanner } + } + + /// Check if security should be enabled based on config + fn should_enable_security() -> bool { + // Check config file 
for security settings + use crate::config::Config; + let config = Config::global(); + + // Try to get security.enabled from config + let result = config + .get_param::("security") + .ok() + .and_then(|security_config| security_config.get("enabled")?.as_bool()) + .unwrap_or(false); + + println!( + "πŸ”’ Config check - security config result: {:?}", + config.get_param::("security") + ); + println!("πŸ”’ Final security enabled result: {}", result); + + result + } + + /// Main security check function - called from reply_internal + /// Uses the proper two-step security analysis process + /// Scans ALL tools (approved + needs_approval) for security threats + /// Also scans system prompt if provided for persistent injection attacks + pub async fn filter_malicious_tool_calls( + &self, + messages: &[Message], + permission_check_result: &PermissionCheckResult, + system_prompt: Option<&str>, + ) -> Result> { + let Some(scanner) = &self.scanner else { + // Security disabled, return empty results + return Ok(vec![]); + }; + + let mut results = Vec::new(); + + // First, scan system prompt if provided for persistent injection attacks + if let Some(system_prompt) = system_prompt { + tracing::info!("πŸ” Scanning system prompt for persistent injection attacks"); + + let system_prompt_result = scanner.scan_system_prompt(system_prompt).await?; + + if system_prompt_result.is_malicious { + let finding_id = format!("SYS-{}", uuid::Uuid::new_v4().simple().to_string().to_uppercase()[..8].to_string()); + + tracing::warn!( + confidence = system_prompt_result.confidence, + explanation = %system_prompt_result.explanation, + finding_id = %finding_id, + "πŸ”’ System prompt contains persistent injection attack" + ); + + let config_threshold = scanner.get_threshold_from_config(); + + results.push(SecurityResult { + is_malicious: system_prompt_result.is_malicious, + confidence: system_prompt_result.confidence, + explanation: format!("System prompt injection: {}", system_prompt_result.explanation), + should_ask_user: system_prompt_result.confidence > config_threshold, + finding_id, + }); + } else { + tracing::debug!("βœ… System prompt passed security analysis"); + } + } + + // Check ALL tools (approved + needs_approval) for potential security issues + for tool_request in permission_check_result + .approved + .iter() + .chain(permission_check_result.needs_approval.iter()) + { + if let Ok(tool_call) = &tool_request.tool_call { + tracing::info!( + tool_name = %tool_call.name, + "πŸ” Starting two-step security analysis for tool call" + ); + + // Use the new two-step analysis method + let analysis_result = scanner + .analyze_tool_call_with_context(tool_call, messages) + .await?; + + if analysis_result.is_malicious { + // Generate a unique finding ID for this security detection + let finding_id = format!("SEC-{}", uuid::Uuid::new_v4().simple().to_string().to_uppercase()[..8].to_string()); + + tracing::warn!( + tool_name = %tool_call.name, + confidence = analysis_result.confidence, + explanation = %analysis_result.explanation, + finding_id = %finding_id, + "πŸ”’ Tool call flagged as malicious after two-step analysis" + ); + + // Get threshold from config - if confidence > threshold, ask user + let config_threshold = scanner.get_threshold_from_config(); + + results.push(SecurityResult { + is_malicious: analysis_result.is_malicious, + confidence: analysis_result.confidence, + explanation: analysis_result.explanation, + should_ask_user: analysis_result.confidence > config_threshold, + finding_id, + }); + } else { + tracing::debug!( 
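// Editorial sketch, not part of the diff: the finding IDs generated in this
// module ("SYS-", "SEC-", "RCP-", "RCC-") all follow one pattern -- a source
// prefix plus the first eight characters of an uppercased simple UUID, using
// the uuid crate already in the dependency tree.
fn finding_id(prefix: &str) -> String {
    // Produces e.g. "SEC-1A2B3C4D"; the prefix marks where the finding came from.
    let id = uuid::Uuid::new_v4().simple().to_string().to_uppercase();
    format!("{}-{}", prefix, &id[..8])
}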
+ tool_name = %tool_call.name, + confidence = analysis_result.confidence, + explanation = %analysis_result.explanation, + "βœ… Tool call passed two-step security analysis" + ); + } + } + } + + Ok(results) + } + + /// Check if models need to be downloaded and return appropriate user message + pub async fn check_model_download_status(&self) -> Option { + let Some(_scanner) = &self.scanner else { + return None; + }; + + // Check if models are already available in memory + if let Some(_model) = scanner::get_model_if_available().await { + return None; // Models ready, no message needed + } + + // Check if models exist on disk but aren't loaded + if Self::models_exist_on_disk() { + return Some("πŸ”’ Loading security models...".to_string()); + } + + // Models need to be downloaded + Some( + "πŸ”’ Setting up security scanning for first time use - this could take a minute..." + .to_string(), + ) + } + + /// Scan recipe components for security threats + /// This should be called when loading/applying recipes + pub async fn scan_recipe_components(&self, recipe: &crate::recipe::Recipe) -> Result> { + let Some(scanner) = &self.scanner else { + // Security disabled, return empty results + return Ok(vec![]); + }; + + let mut results = Vec::new(); + + // Scan recipe prompt (becomes initial user message) + if let Some(prompt) = &recipe.prompt { + if !prompt.trim().is_empty() { + tracing::info!("πŸ” Scanning recipe prompt for injection attacks"); + + let prompt_result = scanner.scan_with_prompt_injection_model(prompt).await?; + + if prompt_result.is_malicious { + let finding_id = format!("RCP-{}", uuid::Uuid::new_v4().simple().to_string().to_uppercase()[..8].to_string()); + + tracing::warn!( + confidence = prompt_result.confidence, + explanation = %prompt_result.explanation, + finding_id = %finding_id, + "πŸ”’ Recipe prompt contains malicious content" + ); + + let config_threshold = scanner.get_threshold_from_config(); + + results.push(SecurityResult { + is_malicious: prompt_result.is_malicious, + confidence: prompt_result.confidence, + explanation: format!("Recipe prompt injection: {}", prompt_result.explanation), + should_ask_user: prompt_result.confidence > config_threshold, + finding_id, + }); + } + } + } + + // Scan recipe context (additional context data) + if let Some(context_items) = &recipe.context { + for (i, context_item) in context_items.iter().enumerate() { + if !context_item.trim().is_empty() { + tracing::info!("πŸ” Scanning recipe context item {} for injection attacks", i); + + let context_result = scanner.scan_with_prompt_injection_model(context_item).await?; + + if context_result.is_malicious { + let finding_id = format!("RCC-{}", uuid::Uuid::new_v4().simple().to_string().to_uppercase()[..8].to_string()); + + tracing::warn!( + context_index = i, + confidence = context_result.confidence, + explanation = %context_result.explanation, + finding_id = %finding_id, + "πŸ”’ Recipe context contains malicious content" + ); + + let config_threshold = scanner.get_threshold_from_config(); + + results.push(SecurityResult { + is_malicious: context_result.is_malicious, + confidence: context_result.confidence, + explanation: format!("Recipe context[{}] injection: {}", i, context_result.explanation), + should_ask_user: context_result.confidence > config_threshold, + finding_id, + }); + } + } + } + } + + Ok(results) + } + + /// Check if model files exist on disk + fn models_exist_on_disk() -> bool { + use crate::security::scanner::PromptInjectionScanner; + + let model_info = 
PromptInjectionScanner::get_model_info_from_config(); + + if let Some(cache_dir) = dirs::cache_dir() { + let security_models_dir = cache_dir.join("goose").join("security_models"); + let model_path = security_models_dir.join(&model_info.onnx_filename); + let tokenizer_path = security_models_dir.join(&model_info.tokenizer_filename); + + return model_path.exists() && tokenizer_path.exists(); + } + + false + } +} + +impl Default for SecurityManager { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/goose/src/security/model_downloader.rs b/crates/goose/src/security/model_downloader.rs new file mode 100644 index 000000000000..3a97181f859d --- /dev/null +++ b/crates/goose/src/security/model_downloader.rs @@ -0,0 +1,621 @@ +use anyhow::anyhow; +use serde::Deserialize; +use std::path::{Path, PathBuf}; +use std::process::Command; +use tokio::fs; +use tokio::sync::OnceCell; + +pub struct ModelDownloader { + cache_dir: PathBuf, +} + +#[derive(Debug)] +pub enum ModelFormat { + OnnxDirect { + model_path: String, // e.g., "onnx/model.onnx" + tokenizer_path: String, // e.g., "onnx/tokenizer.json" + }, + OnnxCustomPaths { + model_path: String, // e.g., "model.onnx" (root level) + tokenizer_path: String, // e.g., "tokenizer.json" + }, + ConvertToOnnx, // Fallback: convert PyTorch + Unsupported, +} + +#[derive(Deserialize)] +struct RepoInfo { + siblings: Vec, +} + +#[derive(Deserialize)] +struct FileInfo { + rfilename: String, +} + +impl ModelDownloader { + pub fn new() -> anyhow::Result { + // Use platform-appropriate cache directory + let cache_dir = if let Some(cache_dir) = dirs::cache_dir() { + cache_dir.join("goose").join("security_models") + } else { + // Fallback to home directory + dirs::home_dir() + .ok_or_else(|| anyhow!("Could not determine home directory"))? + .join(".cache") + .join("goose") + .join("security_models") + }; + + Ok(Self { cache_dir }) + } + + pub async fn ensure_model_available( + &self, + model_info: &ModelInfo, + ) -> anyhow::Result<(PathBuf, PathBuf)> { + let model_path = self.cache_dir.join(&model_info.onnx_filename); + let tokenizer_path = self.cache_dir.join(&model_info.tokenizer_filename); + + // Check if both model and tokenizer exist + if model_path.exists() && tokenizer_path.exists() { + tracing::info!( + model = %model_info.hf_model_name, + path = ?model_path, + "Using cached ONNX model" + ); + return Ok((model_path, tokenizer_path)); + } + + tracing::info!( + model = %model_info.hf_model_name, + "πŸ”’ Goose is being set up, this could take up to a minute…" + ); + + // Create cache directory if it doesn't exist + fs::create_dir_all(&self.cache_dir).await?; + + // Use smart model loading - try ONNX direct first, fallback to conversion + self.load_model_smart(model_info).await?; + + // Verify the files were created + if !model_path.exists() || !tokenizer_path.exists() { + return Err(anyhow!( + "Model download completed but files not found at expected paths. 
Model: {:?}, Tokenizer: {:?}", + model_path, tokenizer_path + )); + } + + tracing::info!( + model = %model_info.hf_model_name, + model_path = ?model_path, + tokenizer_path = ?tokenizer_path, + "βœ… Successfully downloaded model" + ); + + Ok((model_path, tokenizer_path)) + } + + /// Smart model loading - tries ONNX direct download first, falls back to conversion + async fn load_model_smart(&self, model_info: &ModelInfo) -> anyhow::Result<()> { + let format = self + .discover_model_format(&model_info.hf_model_name) + .await?; + + match format { + ModelFormat::OnnxDirect { + model_path, + tokenizer_path, + } => { + tracing::info!( + "πŸ” Found ONNX files in standard location for {}", + model_info.hf_model_name + ); + self.download_onnx_files( + &model_info.hf_model_name, + &model_path, + &tokenizer_path, + model_info, + ) + .await + } + + ModelFormat::OnnxCustomPaths { + model_path, + tokenizer_path, + } => { + tracing::info!( + "πŸ” Found ONNX files in custom location for {}", + model_info.hf_model_name + ); + self.download_onnx_files( + &model_info.hf_model_name, + &model_path, + &tokenizer_path, + model_info, + ) + .await + } + + ModelFormat::ConvertToOnnx => { + tracing::info!( + "πŸ”„ No ONNX files found, will convert PyTorch model for {}", + model_info.hf_model_name + ); + self.download_and_convert_model(model_info).await // Existing approach + } + + ModelFormat::Unsupported => Err(anyhow!( + "Model {} has no supported format (no ONNX or PyTorch files)", + model_info.hf_model_name + )), + } + } + + /// Discover what format a model is available in + async fn discover_model_format(&self, repo: &str) -> anyhow::Result { + let files = self.get_repo_files(repo).await?; + + // Strategy 1: Look for standard onnx/ folder (like protectai model) + if files.iter().any(|f| f.starts_with("onnx/")) { + return Ok(ModelFormat::OnnxDirect { + model_path: "onnx/model.onnx".to_string(), + tokenizer_path: "onnx/tokenizer.json".to_string(), + }); + } + + // Strategy 2: Look for ONNX files in root or custom locations + let onnx_files: Vec<_> = files.iter().filter(|f| f.ends_with(".onnx")).collect(); + + let tokenizer_files: Vec<_> = files + .iter() + .filter(|f| f.contains("tokenizer") && f.ends_with(".json")) + .collect(); + + if !onnx_files.is_empty() && !tokenizer_files.is_empty() { + return Ok(ModelFormat::OnnxCustomPaths { + model_path: onnx_files[0].clone(), + tokenizer_path: tokenizer_files[0].clone(), + }); + } + + // Strategy 3: Check if we can convert PyTorch model + if files + .iter() + .any(|f| f == "pytorch_model.bin" || f == "model.safetensors") + { + return Ok(ModelFormat::ConvertToOnnx); + } + + Ok(ModelFormat::Unsupported) + } + + /// Get list of files in a HuggingFace repository + async fn get_repo_files(&self, repo: &str) -> anyhow::Result> { + let api_url = format!("https://huggingface.co/api/models/{}", repo); + let client = reqwest::Client::new(); + + let mut request = client.get(&api_url); + + // Optional authentication + if let Ok(token) = std::env::var("HUGGINGFACE_TOKEN") { + request = request.header("Authorization", format!("Bearer {}", token)); + } + + let response = request.send().await?; + + if response.status() == 404 { + return Err(anyhow!("Model repository '{}' not found", repo)); + } + + if response.status() == 401 { + return Err(anyhow!( + "Model '{}' requires authentication. 
Set HUGGINGFACE_TOKEN environment variable.\n\ + Get a token from: https://huggingface.co/settings/tokens", + repo + )); + } + + let repo_info: RepoInfo = response.json().await?; + Ok(repo_info + .siblings + .into_iter() + .map(|f| f.rfilename) + .collect()) + } + + /// Download ONNX files directly from HuggingFace + async fn download_onnx_files( + &self, + model_name: &str, + model_path: &str, + tokenizer_path: &str, + model_info: &ModelInfo, + ) -> anyhow::Result<()> { + let base_url = format!("https://huggingface.co/{}/resolve/main/", model_name); + + // Download model file + let model_url = format!("{}{}", base_url, model_path); + let local_model_path = self.cache_dir.join(&model_info.onnx_filename); + tracing::info!("πŸ“₯ Downloading ONNX model from: {}", model_url); + self.download_file_with_auth(&model_url, &local_model_path) + .await?; + + // Download tokenizer file + let tokenizer_url = format!("{}{}", base_url, tokenizer_path); + let local_tokenizer_path = self.cache_dir.join(&model_info.tokenizer_filename); + tracing::info!("πŸ“₯ Downloading tokenizer from: {}", tokenizer_url); + self.download_file_with_auth(&tokenizer_url, &local_tokenizer_path) + .await?; + + Ok(()) + } + + /// Download a file with optional authentication + async fn download_file_with_auth(&self, url: &str, local_path: &PathBuf) -> anyhow::Result<()> { + let client = reqwest::Client::new(); + let mut request = client.get(url); + + // Use HF token if available + if let Ok(token) = std::env::var("HUGGINGFACE_TOKEN") { + request = request.header("Authorization", format!("Bearer {}", token)); + } + + let response = request.send().await?; + + if response.status() == 401 { + return Err(anyhow!( + "File requires authentication. Set HUGGINGFACE_TOKEN environment variable." + )); + } + + if !response.status().is_success() { + return Err(anyhow!( + "Failed to download file from {}: HTTP {}", + url, + response.status() + )); + } + + let bytes = response.bytes().await?; + fs::write(local_path, bytes).await?; + + tracing::info!( + "βœ… Downloaded: {} ({} bytes)", + local_path.display(), + fs::metadata(local_path).await?.len() + ); + + Ok(()) + } + + async fn download_and_convert_model(&self, model_info: &ModelInfo) -> anyhow::Result<()> { + // Set up Python virtual environment with required dependencies + let venv_dir = self.cache_dir.join("python_venv"); + self.ensure_python_venv(&venv_dir).await?; + + let python_script = self.create_conversion_script(model_info).await?; + + tracing::info!("Running model conversion script in virtual environment..."); + + // Use the virtual environment's Python + let python_exe = if cfg!(windows) { + venv_dir.join("Scripts").join("python.exe") + } else { + venv_dir.join("bin").join("python") + }; + + let output = Command::new(&python_exe) + .arg(&python_script) + .env("CACHE_DIR", &self.cache_dir) + .env("MODEL_NAME", &model_info.hf_model_name) + .output() + .map_err(|e| anyhow!("Failed to execute Python conversion script: {}", e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + return Err(anyhow!( + "Model conversion failed:\nStdout: {}\nStderr: {}", + stdout, + stderr + )); + } + + // Clean up the temporary script + let _ = fs::remove_file(&python_script).await; + + Ok(()) + } + + async fn ensure_python_venv(&self, venv_dir: &std::path::Path) -> anyhow::Result<()> { + // Check if virtual environment already exists and has required packages + let python_exe = if cfg!(windows) { + 
venv_dir.join("Scripts").join("python.exe") + } else { + venv_dir.join("bin").join("python") + }; + + if python_exe.exists() { + // Check if required packages are installed + let output = Command::new(&python_exe) + .args(&[ + "-c", + "import torch, transformers, onnx, tokenizers; print('OK')", + ]) + .output(); + + if let Ok(output) = output { + if output.status.success() && String::from_utf8_lossy(&output.stdout).trim() == "OK" + { + tracing::info!( + "Python virtual environment already set up with required packages" + ); + return Ok(()); + } + } + } + + tracing::info!("Setting up Python virtual environment..."); + + // Create virtual environment + fs::create_dir_all(venv_dir).await?; + + let output = Command::new("python3") + .args(&[ + "-m", + "venv", + venv_dir + .to_str() + .ok_or_else(|| anyhow!("Invalid venv directory path"))?, + ]) + .output() + .map_err(|e| anyhow!("Failed to create Python virtual environment: {}", e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(anyhow!("Failed to create virtual environment: {}", stderr)); + } + + tracing::info!("Installing required Python packages..."); + + // Install required packages + let pip_exe = if cfg!(windows) { + venv_dir.join("Scripts").join("pip.exe") + } else { + venv_dir.join("bin").join("pip") + }; + + let packages = ["torch", "transformers", "onnx", "tokenizers"]; + + for package in &packages { + tracing::info!("Installing {}...", package); + let output = Command::new(&pip_exe) + .args(&["install", package]) + .output() + .map_err(|e| anyhow!("Failed to install {}: {}", package, e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(anyhow!("Failed to install {}: {}", package, stderr)); + } + } + + tracing::info!("Python virtual environment setup complete"); + Ok(()) + } + + async fn create_conversion_script(&self, model_info: &ModelInfo) -> anyhow::Result { + let script_content = format!( + r#"#!/usr/bin/env python3 +""" +Runtime model conversion script for Goose security models +""" + +import os +import sys + +def install_packages(): + """Install required packages""" + import subprocess + packages = ["torch", "transformers", "onnx", "tokenizers"] + for package in packages: + print(f"πŸ“¦ Installing {{package}}...") + try: + subprocess.check_call([sys.executable, "-m", "pip", "install", package], + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + except subprocess.CalledProcessError as e: + print(f"❌ Failed to install {{package}}: {{e}}") + return False + return True + +def check_and_install_packages(): + """Check if packages are available, install if needed""" + try: + import torch + import transformers + import onnx + import tokenizers + print("βœ… All required packages are available") + return True + except ImportError as e: + print(f"❌ Missing packages: {{e}}") + print("πŸ“¦ Installing required packages...") + if install_packages(): + # Try importing again after installation + try: + import torch + import transformers + import onnx + import tokenizers + print("βœ… Successfully installed and imported all packages") + return True + except ImportError as e2: + print(f"❌ Still missing packages after installation: {{e2}}") + return False + else: + return False + +# Check and install packages first +if not check_and_install_packages(): + print("❌ Failed to install required packages") + sys.exit(1) + +# Now import everything we need +from transformers import AutoTokenizer, AutoModelForSequenceClassification 
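# Note (editorial sketch, not part of the diff): the exported graph is
# expected to take int64 tensors named `input_ids` and `attention_mask`,
# each shaped [batch, sequence], and to return `logits` shaped
# [batch, num_labels]; the Rust scanner feeds its inputs and reads its
# output under exactly these names.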
+import torch + +def convert_model_to_onnx(model_name: str, output_dir: str): + """Convert a Hugging Face model to ONNX format""" + print(f"Converting {{model_name}} to ONNX...") + + # Create output directory + from pathlib import Path + Path(output_dir).mkdir(parents=True, exist_ok=True) + + try: + # Handle authentication for gated models + hf_token = os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HF_TOKEN') + auth_kwargs = {{}} + if hf_token: + auth_kwargs['token'] = hf_token + print(f" Using HF token for authentication") + + # Load model and tokenizer + print(f" Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(model_name, **auth_kwargs) + + print(f" Loading model...") + model = AutoModelForSequenceClassification.from_pretrained(model_name, **auth_kwargs) + model.eval() + + # Create dummy input + dummy_text = "This is a test input for ONNX conversion" + inputs = tokenizer(dummy_text, return_tensors="pt", padding=True, truncation=True, max_length=512) + + # Export to ONNX + model_filename = model_name.replace("/", "_") + ".onnx" + model_path = os.path.join(output_dir, model_filename) + + print(f" Exporting to ONNX...") + torch.onnx.export( + model, + (inputs['input_ids'], inputs['attention_mask']), + model_path, + export_params=True, + opset_version=14, + do_constant_folding=True, + input_names=['input_ids', 'attention_mask'], + output_names=['logits'], + dynamic_axes={{ + 'input_ids': {{0: 'batch_size', 1: 'sequence'}}, + 'attention_mask': {{0: 'batch_size', 1: 'sequence'}}, + 'logits': {{0: 'batch_size'}} + }} + ) + + # Save tokenizer with model-specific filename + tokenizer_filename = model_name.replace("/", "_") + "_tokenizer.json" + tokenizer_path = os.path.join(output_dir, tokenizer_filename) + + # First save to temp directory to get the tokenizer.json file + temp_dir = os.path.join(output_dir, "temp_tokenizer") + tokenizer.save_pretrained(temp_dir, legacy_format=False) + + # Copy the tokenizer.json file to the expected location with model-specific name + import shutil + temp_tokenizer_json = os.path.join(temp_dir, "tokenizer.json") + if os.path.exists(temp_tokenizer_json): + shutil.copy2(temp_tokenizer_json, tokenizer_path) + # Clean up temp directory + shutil.rmtree(temp_dir) + else: + print(f" Warning: tokenizer.json not found in {{temp_dir}}") + # Fallback: save the entire tokenizer directory + tokenizer_dir = os.path.join(output_dir, model_name.replace("/", "_") + "_tokenizer") + tokenizer.save_pretrained(tokenizer_dir, legacy_format=False) + print(f" Saved tokenizer to directory: {{tokenizer_dir}}") + + print(f"βœ… Successfully converted {{model_name}}") + print(f" Model: {{model_path}}") + print(f" Tokenizer: {{tokenizer_path}}") + return True + + except Exception as e: + print(f"❌ Failed to convert {{model_name}}: {{e}}") + if "gated repo" in str(e).lower() or "access" in str(e).lower(): + print(f" This might be a gated model. Make sure you:") + print(f" 1. Have access to {{model_name}} on Hugging Face") + print(f" 2. Set your HF token: export HUGGINGFACE_TOKEN='your_token'") + print(f" 3. 
Get a token from: https://huggingface.co/settings/tokens") + import traceback + traceback.print_exc() + return False + +def main(): + model_name = os.getenv('MODEL_NAME') + cache_dir = os.getenv('CACHE_DIR') + + if not model_name or not cache_dir: + print("Error: MODEL_NAME and CACHE_DIR environment variables must be set") + sys.exit(1) + + success = convert_model_to_onnx(model_name, cache_dir) + if not success: + sys.exit(1) + +if __name__ == "__main__": + main() +"# + ); + + let script_path = self.cache_dir.join(format!( + "convert_model_{}.py", + model_info.hf_model_name.replace("/", "_").replace("-", "_") + )); + fs::write(&script_path, script_content).await?; + + // Make the script executable + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&script_path).await?.permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).await?; + } + + Ok(script_path) + } + + pub fn get_cache_dir(&self) -> &Path { + &self.cache_dir + } +} + +#[derive(Debug, Clone)] +pub struct ModelInfo { + pub hf_model_name: String, + pub onnx_filename: String, + pub tokenizer_filename: String, +} + +impl ModelInfo { + pub fn from_config_model(model_name: &str) -> Self { + // Keep the original model name format for filenames to match what model_downloader creates + let safe_filename = model_name.replace("/", "_"); + Self { + hf_model_name: model_name.to_string(), + onnx_filename: format!("{}.onnx", safe_filename), + tokenizer_filename: format!("{}_tokenizer.json", safe_filename), + } + } +} + +// Global downloader instance +static GLOBAL_DOWNLOADER: OnceCell = OnceCell::const_new(); + +pub async fn get_global_downloader() -> anyhow::Result<&'static ModelDownloader> { + GLOBAL_DOWNLOADER + .get_or_try_init(|| async { ModelDownloader::new() }) + .await +} diff --git a/crates/goose/src/security/scanner.rs b/crates/goose/src/security/scanner.rs new file mode 100644 index 000000000000..822be606968b --- /dev/null +++ b/crates/goose/src/security/scanner.rs @@ -0,0 +1,768 @@ +use crate::conversation::message::Message; +use anyhow::{anyhow, Result}; +use mcp_core::tool::ToolCall; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::OnceCell; + +use crate::security::model_downloader::{get_global_downloader, ModelInfo}; + +// ML inference backends +use ort::{session::builder::GraphOptimizationLevel, session::Session}; +use tokenizers::Tokenizer; + +#[derive(Debug, Clone)] +pub struct ScanResult { + pub is_malicious: bool, + pub confidence: f32, + pub explanation: String, +} + +/// Trait for different ML inference backends +#[async_trait::async_trait] +pub trait PromptInjectionModel: Send + Sync { + async fn predict(&self, text: &str) -> Result<(f32, String)>; + fn model_name(&self) -> &str; +} + +/// ONNX Runtime implementation +pub struct OnnxPromptInjectionModel { + session: Arc>, + tokenizer: Arc, + model_name: String, +} + +impl OnnxPromptInjectionModel { + pub async fn new( + model_path: PathBuf, + tokenizer_path: PathBuf, + model_name: String, + ) -> Result { + tracing::info!("πŸ”’ Starting ONNX model initialization..."); + + // Initialize ONNX Runtime session + tracing::info!("πŸ”’ Creating ONNX session from: {:?}", model_path); + let session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? 
+ .commit_from_file(&model_path)?; + + tracing::info!("πŸ”’ ONNX session created successfully"); + + // Load tokenizer + tracing::info!("πŸ”’ Loading tokenizer from: {:?}", tokenizer_path); + let tokenizer = Tokenizer::from_file(&tokenizer_path) + .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?; + + tracing::info!("πŸ”’ Tokenizer loaded successfully"); + + Ok(Self { + session: Arc::new(std::sync::Mutex::new(session)), + tokenizer: Arc::new(tokenizer), + model_name, + }) + } +} + +#[async_trait::async_trait] +impl PromptInjectionModel for OnnxPromptInjectionModel { + async fn predict(&self, text: &str) -> Result<(f32, String)> { + tracing::info!("πŸ”’ ONNX predict() called with text length: {}", text.len()); + tracing::info!( + "πŸ”’ ONNX predict() received text: '{}'", + text.chars().take(200).collect::() + ); + + // Tokenize the input text + tracing::debug!("πŸ”’ Tokenizing input text..."); + let encoding = self + .tokenizer + .encode(text, true) + .map_err(|e| anyhow!("Tokenization failed: {}", e))?; + + let input_ids = encoding.get_ids(); + let attention_mask = encoding.get_attention_mask(); + + tracing::debug!( + "πŸ”’ Tokenization complete. Sequence length: {}", + input_ids.len() + ); + + // Convert to the format expected by ONNX (batch_size=1) + let input_ids: Vec = input_ids.iter().map(|&id| id as i64).collect(); + let attention_mask: Vec = attention_mask.iter().map(|&mask| mask as i64).collect(); + + let seq_len = input_ids.len(); + + tracing::debug!("πŸ”’ Creating ONNX tensors..."); + // Create ONNX tensors + let input_ids_tensor = + ort::value::Tensor::from_array(([1, seq_len], input_ids.into_boxed_slice()))?; + let attention_mask_tensor = + ort::value::Tensor::from_array(([1, seq_len], attention_mask.into_boxed_slice()))?; + + tracing::debug!("πŸ”’ Running ONNX inference..."); + // Run inference and extract the logits immediately + let (logit_0, logit_1) = { + let mut session = self + .session + .lock() + .map_err(|e| anyhow!("Failed to lock session: {}", e))?; + tracing::debug!("πŸ”’ Session locked, running inference..."); + let outputs = session.run(ort::inputs![ + "input_ids" => input_ids_tensor, + "attention_mask" => attention_mask_tensor + ])?; + + tracing::debug!("πŸ”’ Inference complete, extracting logits..."); + // Extract logits from output immediately while we have the lock + let logits = outputs["logits"].try_extract_tensor::()?; + let logits_slice = logits.1; + + // Extract the values we need + let logit_0 = logits_slice[0]; // Non-injection class + let logit_1 = logits_slice[1]; // Injection class + + tracing::debug!("πŸ”’ Logits extracted: [{:.3}, {:.3}]", logit_0, logit_1); + + (logit_0, logit_1) + }; + + // Apply softmax to get probabilities + let exp_0 = logit_0.exp(); + let exp_1 = logit_1.exp(); + let sum_exp = exp_0 + exp_1; + + let prob_injection = exp_1 / sum_exp; + + let explanation = format!( + "ONNX model '{}': Injection probability = {:.3} (logits: [{:.3}, {:.3}])", + self.model_name, prob_injection, logit_0, logit_1 + ); + + tracing::info!( + "πŸ”’ ONNX prediction complete: confidence={:.3}, explanation={}", + prob_injection, + explanation + ); + + tracing::debug!( + model = %self.model_name, + text_length = text.len(), + seq_length = seq_len, + logit_0 = logit_0, + logit_1 = logit_1, + prob_injection = prob_injection, + "ONNX inference completed" + ); + + Ok((prob_injection, explanation)) + } + + fn model_name(&self) -> &str { + &self.model_name + } +} + +type ModelCache = Arc>>>; + +/// Global model cache with reload capability +static 
+
+/// Global model cache with reload capability
+static MODEL_CACHE: OnceCell<ModelCache> = OnceCell::const_new();
+
+/// Initialize the global model
+async fn initialize_model() -> Result<Option<Arc<dyn PromptInjectionModel>>> {
+    tracing::info!("πŸ”’ Attempting to initialize ONNX security model...");
+
+    // Try to load the ONNX model
+    match get_global_downloader().await {
+        Ok(downloader) => {
+            let model_info = PromptInjectionScanner::get_model_info_from_config();
+            match downloader.ensure_model_available(&model_info).await {
+                Ok((model_path, tokenizer_path)) => {
+                    tracing::info!("πŸ”’ Loading ONNX model from: {:?}", model_path);
+                    match OnnxPromptInjectionModel::new(
+                        model_path,
+                        tokenizer_path,
+                        model_info.hf_model_name.clone(),
+                    )
+                    .await
+                    {
+                        Ok(model) => {
+                            tracing::info!("πŸ”’ βœ… ONNX security model loaded successfully");
+                            return Ok(Some(Arc::new(model)));
+                        }
+                        Err(e) => {
+                            tracing::warn!("πŸ”’ Failed to initialize ONNX model: {}", e);
+                        }
+                    }
+                }
+                Err(e) => {
+                    tracing::warn!("πŸ”’ Failed to ensure model available: {}", e);
+                }
+            }
+        }
+        Err(e) => {
+            tracing::warn!("πŸ”’ Failed to get model downloader: {}", e);
+        }
+    }
+
+    tracing::info!("πŸ”’ ONNX model not available, will use pattern-based scanning");
+    Ok(None)
+}
+
+/// Get or initialize the global model
+async fn get_model() -> Option<Arc<dyn PromptInjectionModel>> {
+    let cache = MODEL_CACHE
+        .get_or_init(|| async { Arc::new(tokio::sync::RwLock::new(None)) })
+        .await;
+
+    // Check if model is already loaded in memory
+    let read_guard = cache.read().await;
+    if let Some(model) = read_guard.as_ref() {
+        tracing::debug!("πŸ”’ Model found in memory cache, using cached instance");
+        return Some(model.clone());
+    }
+    drop(read_guard);
+
+    // Model not loaded in memory, try to initialize from disk
+    tracing::info!("πŸ”’ Model not loaded in memory, loading from disk cache...");
+    let mut write_guard = cache.write().await;
+
+    // Double-check in case another task loaded it while we were waiting
+    if let Some(model) = write_guard.as_ref() {
+        tracing::debug!("πŸ”’ Model was loaded by another task while waiting, using that instance");
+        return Some(model.clone());
+    }
+
+    // Load the model from disk
+    match initialize_model().await {
+        Ok(Some(model)) => {
+            tracing::info!("πŸ”’ βœ… Model successfully loaded into memory cache");
+            *write_guard = Some(model.clone());
+            Some(model)
+        }
+        Ok(None) => {
+            tracing::info!("πŸ”’ No model available, using pattern-based fallback");
+            None
+        }
+        Err(e) => {
+            tracing::warn!("πŸ”’ Failed to initialize model: {}", e);
+            None
+        }
+    }
+}
+
+/// Check if model is available without triggering download/initialization
+pub async fn get_model_if_available() -> Option<Arc<dyn PromptInjectionModel>> {
+    let cache = MODEL_CACHE
+        .get_or_init(|| async { Arc::new(tokio::sync::RwLock::new(None)) })
+        .await;
+
+    // Only check if model is already loaded in memory - don't trigger loading
+    let read_guard = cache.read().await;
+    read_guard.as_ref().cloned()
+}
+
+/// Simple prompt injection scanner
+/// Uses the existing model_downloader infrastructure
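+///
+/// Typical call flow (sketch; assumes an async caller holding a tool call and
+/// message history):
+///
+/// ```ignore
+/// let scanner = PromptInjectionScanner::new();
+/// let verdict = scanner
+///     .analyze_tool_call_with_context(&tool_call, &messages)
+///     .await?;
+/// if verdict.is_malicious {
+///     // Surface verdict.explanation and ask the user before running the tool.
+/// }
+/// ```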
+pub struct PromptInjectionScanner {
+    enabled: bool,
+}
+
+impl PromptInjectionScanner {
+    pub fn new() -> Self {
+        println!("πŸ”’ PromptInjectionScanner::new() called");
+
+        // Check if models are available, trigger download if needed
+        let scanner = Self {
+            enabled: Self::check_and_prepare_models(),
+        };
+
+        println!("πŸ”’ Scanner enabled: {}", scanner.enabled);
+
+        scanner
+    }
+
+    /// Check if models are available and trigger download if needed
+    fn check_and_prepare_models() -> bool {
+        // Check if models are already cached
+        let model_info = Self::get_model_info_from_config();
+
+        // Check if model files exist in cache
+        if let Some(cache_dir) = dirs::cache_dir() {
+            let security_models_dir = cache_dir.join("goose").join("security_models");
+            let model_path = security_models_dir.join(&model_info.onnx_filename);
+            let tokenizer_path = security_models_dir.join(&model_info.tokenizer_filename);
+
+            if model_path.exists() && tokenizer_path.exists() {
+                tracing::info!(
+                    "πŸ”’ Security model files found on disk - loading model into memory now..."
+                );
+
+                // Load model into memory immediately at startup
+                tokio::spawn(async move {
+                    tracing::info!("πŸ”’ Pre-loading security model at startup...");
+                    if let Some(_model) = get_model().await {
+                        tracing::info!(
+                            "πŸ”’ βœ… Security model pre-loaded successfully - ready for scanning"
+                        );
+                    } else {
+                        tracing::warn!("πŸ”’ Failed to pre-load security model");
+                    }
+                });
+
+                return true;
+            }
+        }
+
+        // Models not cached - we need to download them
+        tracing::info!("πŸ”’ Security model files not found on disk");
+        tracing::info!(
+            "πŸ”’ Models will be downloaded on first security scan (this may cause a delay)"
+        );
+
+        // For now, return true to enable security scanning
+        // The models will be downloaded lazily on first scan
+        // TODO: Consider blocking startup to download models synchronously
+        true
+    }
+
+    /// Get model information from config file
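+    ///
+    /// Expects a `security` section shaped like this (hypothetical values, YAML form):
+    ///
+    /// ```yaml
+    /// security:
+    ///   models:
+    ///     - model: "acme/prompt-guard"
+    ///       threshold: 0.7
+    /// ```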
+    pub fn get_model_info_from_config() -> ModelInfo {
+        use crate::config::Config;
+        let config = Config::global();
+
+        // Try to get model from config
+        let security_config = config.get_param::<serde_json::Value>("security").ok();
+
+        let model_name = security_config
+            .as_ref()
+            .and_then(|security| security.get("models")?.as_array()?.first())
+            .and_then(|model| model.get("model")?.as_str())
+            .map(|s| s.to_string())
+            .unwrap_or_else(|| {
+                tracing::warn!(
+                    "πŸ”’ No security model configured, security scanning will be disabled"
+                );
+                // Return a placeholder that won't work, forcing pattern-only mode
+                "no-model-configured".to_string()
+            });
+
+        tracing::info!("πŸ”’ Using security model from config: {}", model_name);
+
+        // Create ModelInfo from config
+        ModelInfo::from_config_model(&model_name)
+    }
+
+    /// Two-step security analysis: scan tool call first, then analyze context if suspicious
+    pub async fn analyze_tool_call_with_context(
+        &self,
+        tool_call: &ToolCall,
+        messages: &[Message],
+    ) -> Result<ScanResult> {
+        // Step 1: Scan the tool call itself for suspicious patterns
+        let tool_call_result = self.scan_tool_call_only(tool_call).await?;
+
+        if !tool_call_result.is_malicious {
+            // Tool call looks safe, no need for context analysis
+            tracing::debug!(
+                tool_name = %tool_call.name,
+                confidence = tool_call_result.confidence,
+                "βœ… Tool call passed initial security scan"
+            );
+            return Ok(tool_call_result);
+        }
+
+        // Step 2: Tool call looks suspicious, analyze conversation context
+        tracing::info!(
+            tool_name = %tool_call.name,
+            confidence = tool_call_result.confidence,
+            "πŸ” Tool call flagged as suspicious, analyzing conversation context"
+        );
+
+        let user_messages_result = self.scan_user_messages_only(messages).await?;
+
+        // Decision logic: combine both results
+        let final_result =
+            self.make_final_security_decision(&tool_call_result, &user_messages_result, tool_call);
+
+        tracing::info!(
+            tool_name = %tool_call.name,
+            tool_confidence = tool_call_result.confidence,
+            conversation_confidence = user_messages_result.confidence,
+            final_malicious = final_result.is_malicious,
+            final_confidence = final_result.confidence,
+            "πŸ”’ Two-step security analysis complete"
+        );
+
+        Ok(final_result)
+    }
+
+    /// Step 1: Scan only the tool call for suspicious patterns
+    async fn scan_tool_call_only(&self, tool_call: &ToolCall) -> Result<ScanResult> {
+        // Debug: Log the raw tool call arguments first
+        tracing::info!("πŸ”’ Raw tool call arguments: {:?}", tool_call.arguments);
+
+        // Create text representation of the tool call for analysis
+        let arguments_json =
+            serde_json::to_string_pretty(&tool_call.arguments).unwrap_or_else(|e| {
+                tracing::warn!("πŸ”’ Failed to serialize tool arguments: {}", e);
+                format!("{{\"error\": \"Failed to serialize arguments: {}\"}}", e)
+            });
+
+        let tool_text = format!("Tool: {}\nArguments: {}", tool_call.name, arguments_json);
+
+        tracing::info!(
+            "πŸ”’ Complete tool text being analyzed (length: {}): '{}'",
+            tool_text.len(),
+            tool_text
+        );
+
+        self.scan_with_prompt_injection_model(&tool_text).await
+    }
+
+    /// Step 2: Scan only the user messages (conversation history) for prompt injection
+    async fn scan_user_messages_only(&self, messages: &[Message]) -> Result<ScanResult> {
+        // Extract only user messages from recent conversation history
+        let user_messages: Vec<String> = messages
+            .iter()
+            .rev()
+            .take(5) // Take last 5 messages for context
+            .rev()
+            .filter_map(|msg| {
+                // Only analyze user messages, not assistant responses
+                if matches!(msg.role, rmcp::model::Role::User) {
+                    msg.content.first()?.as_text().map(|text| text.to_string())
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        if user_messages.is_empty() {
+            return Ok(ScanResult {
+                is_malicious: false,
+                confidence: 0.0,
+                explanation: "No user messages found in conversation history".to_string(),
+            });
+        }
+
+        let user_context = user_messages.join("\n\n");
+        self.scan_with_prompt_injection_model(&user_context).await
+    }
+
+    /// Make final security decision based on both tool call and user message analysis
+    fn make_final_security_decision(
+        &self,
+        tool_call_result: &ScanResult,
+        user_messages_result: &ScanResult,
+        _tool_call: &ToolCall,
+    ) -> ScanResult {
+        // Simple decision logic using config threshold:
+        // 1. If user messages contain prompt injection, tool call is likely malicious
+        // 2. Otherwise, use the tool call confidence as-is
+        // 3. Let the config threshold determine if user should be asked
+
+        let (confidence, explanation) = if user_messages_result.is_malicious {
+            // User messages contain prompt injection - tool call is likely malicious
+            let combined_confidence =
+                (tool_call_result.confidence + user_messages_result.confidence) / 2.0;
+            let explanation =
+                "Tool appears to be the result of a prompt injection attack.".to_string();
+            (combined_confidence, explanation)
+        } else {
+            // Use tool call confidence as-is, let config threshold decide
+            let explanation = if tool_call_result.confidence > 0.0 {
+                format!(
+                    "Tool flagged with confidence: {:.2}",
+                    tool_call_result.confidence
+                )
+            } else {
+                "Tool appears safe".to_string()
+            };
+            (tool_call_result.confidence, explanation)
+        };
+
+        // Get threshold from config to determine if malicious
+        let config_threshold = self.get_threshold_from_config();
+        let is_malicious = confidence > config_threshold;
+
+        ScanResult {
+            is_malicious,
+            confidence,
+            explanation,
+        }
+    }
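+
+    // Worked example of the decision above: a tool call scored 0.80 combined with
+    // injection-flagged user messages scored 0.60 averages to 0.70, which is
+    // reported as malicious only if it exceeds the configured threshold.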
+
+    /// Scan system prompt for persistent injection attacks
+    pub async fn scan_system_prompt(&self, system_prompt: &str) -> Result<ScanResult> {
+        tracing::info!(
+            "πŸ”’ Scanning system prompt for persistent injection attacks (length: {})",
+            system_prompt.len()
+        );
+
+        // Use the ML model to scan the system prompt - this is exactly what the model is for
+        self.scan_with_prompt_injection_model(system_prompt).await
+    }
+
+    /// Model-agnostic prompt injection scanning - public for recipe scanning
+    pub async fn scan_with_prompt_injection_model(&self, text: &str) -> Result<ScanResult> {
+        tracing::info!(
+            "πŸ”’ Starting scan_with_prompt_injection_model for text (length: {}): '{}'",
+            text.len(),
+            text.chars().take(100).collect::<String>()
+        );
+
+        // Always run pattern-based scanning first
+        let pattern_result = self.scan_with_patterns(text).await?;
+
+        // Try to get the ML model for additional scanning
+        tracing::info!("πŸ”’ Attempting to get ML model...");
+        if let Some(model) = get_model().await {
+            tracing::info!("πŸ”’ ML model retrieved successfully, calling predict...");
+            tracing::info!(
+                "πŸ”’ About to call model.predict() with text length: {}",
+                text.len()
+            );
+            match model.predict(text).await {
+                Ok((ml_confidence, ml_explanation)) => {
+                    tracing::info!("πŸ”’ ML model predict returned successfully");
+                    // Get threshold from config
+                    let threshold = self.get_threshold_from_config();
+                    let ml_is_malicious = ml_confidence > threshold;
+
+                    tracing::info!(
+                        "πŸ”’ ML model prediction: confidence={:.3}, threshold={:.3}, malicious={}",
+                        ml_confidence,
+                        threshold,
+                        ml_is_malicious
+                    );
+
+                    // Combine ML and pattern results
+                    let combined_result = self.combine_scan_results(
+                        &pattern_result,
+                        ml_confidence,
+                        &ml_explanation,
+                        ml_is_malicious,
+                    );
+
+                    tracing::info!(
+                        "πŸ”’ Combined scan result: ML confidence={:.3}, Pattern confidence={:.3}, Final confidence={:.3}, Final malicious={}",
+                        ml_confidence, pattern_result.confidence, combined_result.confidence, combined_result.is_malicious
+                    );
+
+                    return Ok(combined_result);
+                }
+                Err(e) => {
+                    tracing::warn!("πŸ”’ ML model prediction failed: {}", e);
+                    // Fall through to pattern-only result
+                }
+            }
+        } else {
+            tracing::info!("πŸ”’ No ML model available, using pattern-based scanning only");
+        }
+
+        tracing::info!("πŸ”’ Using pattern-based scan result only");
+        Ok(pattern_result)
+    }
+
+    /// Combine ML model and pattern matching results
+    fn combine_scan_results(
+        &self,
+        pattern_result: &ScanResult,
+        ml_confidence: f32,
+        _ml_explanation: &str,
+        ml_is_malicious: bool,
+    ) -> ScanResult {
+        // Take the higher confidence score
+        let final_confidence = pattern_result.confidence.max(ml_confidence);
+
+        // Mark as malicious if either method detects it
+        let final_is_malicious = pattern_result.is_malicious || ml_is_malicious;
+
+        // Simplified explanation - just show what detected the threat
+        let combined_explanation = if pattern_result.is_malicious && ml_is_malicious {
+            "Detected by both pattern analysis and ML model".to_string()
+        } else if pattern_result.is_malicious {
+            format!(
+                "Detected by pattern analysis: {}",
+                pattern_result
+                    .explanation
+                    .replace("Pattern-based detection: ", "")
+            )
+        } else if ml_is_malicious {
+            "Detected by machine learning model".to_string()
+        } else {
+            "No threats detected".to_string()
+        };
+
+        ScanResult {
+            is_malicious: final_is_malicious,
+            confidence: final_confidence,
+            explanation: combined_explanation,
+        }
+    }
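+
+    // Example of the combination rule: pattern confidence 0.60 (not malicious) and
+    // ML confidence 0.85 (malicious) yield confidence 0.85 and is_malicious = true;
+    // the maximum confidence wins, and either detector alone can flag the result.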
first_model.get("threshold").and_then(|t| t.as_f64()) { + return threshold as f32; + } + } + } + } + + 0.7 // Default threshold + } + + /// Fallback pattern-based scanning + async fn scan_with_patterns(&self, text: &str) -> Result { + let text_lower = text.to_lowercase(); + + // Command injection patterns - detect potentially dangerous commands + let dangerous_patterns = [ + // File system operations + "rm -rf /", + "rm -rf /*", + "rm -rf ~", + "rm -rf $home", + "rmdir /", + "del /s /q", + "format c:", + // System manipulation + "shutdown", + "reboot", + "halt", + "poweroff", + "kill -9", + "killall", + // Network/data exfiltration and remote execution + "bash <(curl", + "sh <(curl", + "bash <(wget", + "sh <(wget", + "curl http", + "wget http", + "nc -l", + "netcat", + "ssh ", + "scp ", + "rsync", + // Process manipulation + "sudo ", + "su -", + "chmod 777", + "chown root", + // Command chaining that could hide malicious intent + "; rm ", + "&& rm ", + "| rm ", + "; curl ", + "&& curl ", + "| curl ", + "; wget ", + "&& wget ", + // Suspicious file operations + "rm -f /", + "rm -rf .", + "rm -rf ..", + "> /dev/", + "dd if=", + "mkfs", + // Potential data theft + "cat /etc/passwd", + "cat /etc/shadow", + "/etc/hosts", + "~/.ssh/", + "id_rsa", + // Obfuscation attempts + "base64 -d", + "echo | sh", + "eval ", + "exec ", + ]; + + let mut detected_patterns = Vec::new(); + let mut max_risk_score: f32 = 0.0; + + for pattern in &dangerous_patterns { + if text_lower.contains(pattern) { + detected_patterns.push(pattern.to_string()); + + // Assign risk scores based on severity + let risk_score = match *pattern { + // Critical - system destruction + "rm -rf /" | "rm -rf /*" | "format c:" | "mkfs" => 0.95, + "rm -rf ~" | "rm -rf $home" => 0.90, + + // Critical - remote code execution patterns + "bash <(curl" | "sh <(curl" | "bash <(wget" | "sh <(wget" => 0.95, + + // High - system control + "shutdown" | "reboot" | "halt" | "poweroff" => 0.85, + "sudo " | "su -" | "chmod 777" | "chown root" => 0.80, + + // Medium-High - network/data access + "curl http" | "wget http" | "ssh " | "scp " => 0.75, + "cat /etc/passwd" | "cat /etc/shadow" | "~/.ssh/" => 0.85, + + // Medium - suspicious operations + "; rm " | "&& rm " | "| rm " => 0.70, + "kill -9" | "killall" => 0.65, + + // Lower - potentially legitimate but suspicious + "base64 -d" | "eval " | "exec " => 0.60, + + _ => 0.50, + }; + + max_risk_score = max_risk_score.max(risk_score); + } + } + + if !detected_patterns.is_empty() { + // Use config threshold for pattern-based detection too + let config_threshold = self.get_threshold_from_config(); + let is_malicious = max_risk_score > config_threshold; + let explanation = format!( + "Pattern-based detection: Found {} suspicious command pattern(s): [{}]. 
Risk score: {:.2}", + detected_patterns.len(), + detected_patterns.join(", "), + max_risk_score + ); + + tracing::info!( + "πŸ”’ Pattern-based scan detected {} suspicious patterns with max risk score: {:.2}", + detected_patterns.len(), + max_risk_score + ); + + Ok(ScanResult { + is_malicious, + confidence: max_risk_score, + explanation, + }) + } else { + Ok(ScanResult { + is_malicious: false, + confidence: 0.0, + explanation: "Pattern-based scan: No suspicious command patterns detected" + .to_string(), + }) + } + } + + +} + +impl Default for PromptInjectionScanner { + fn default() -> Self { + Self::new() + } +} diff --git a/ui/desktop/src/components/GooseMessage.tsx b/ui/desktop/src/components/GooseMessage.tsx index ca19513f1c7f..eeaa995db4d3 100644 --- a/ui/desktop/src/components/GooseMessage.tsx +++ b/ui/desktop/src/components/GooseMessage.tsx @@ -223,6 +223,7 @@ export default function GooseMessage({ isClicked={messageIndex < messageHistoryIndex} toolConfirmationId={toolConfirmationContent.id} toolName={toolConfirmationContent.toolName} + prompt={toolConfirmationContent.prompt} /> )} diff --git a/ui/desktop/src/components/ToolCallConfirmation.tsx b/ui/desktop/src/components/ToolCallConfirmation.tsx index 764a460402a9..6ab32476aedf 100644 --- a/ui/desktop/src/components/ToolCallConfirmation.tsx +++ b/ui/desktop/src/components/ToolCallConfirmation.tsx @@ -25,6 +25,7 @@ interface ToolConfirmationProps { isClicked: boolean; toolConfirmationId: string; toolName: string; + prompt?: string; // Security warning or custom prompt } export default function ToolConfirmation({ @@ -32,6 +33,7 @@ export default function ToolConfirmation({ isClicked, toolConfirmationId, toolName, + prompt, }: ToolConfirmationProps) { // Check if we have a stored state for this tool confirmation const storedState = toolConfirmationState.get(toolConfirmationId); @@ -121,7 +123,7 @@ export default function ToolConfirmation({ ) : ( <>
-          Goose would like to call the above tool. Allow?
+          {prompt || 'Goose would like to call the above tool. Allow?'}
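+          {/* `prompt` may carry a scanner-generated message (e.g. a string containing
+              "SECURITY WARNING"); plain confirmations fall back to the default copy. */}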
{clicked ? (
@@ -188,9 +190,12 @@ export default function ToolConfirmation({
) : (
- + {/* Hide "Always Allow" for security warnings to prevent bypassing future security checks */} + {!prompt?.includes('SECURITY WARNING') && ( + + )}