From cc30637a0a85f9aac95b34d370ce6e6bb05d6546 Mon Sep 17 00:00:00 2001 From: Dorien Koelemeijer Date: Sat, 9 Aug 2025 15:45:06 +1000 Subject: [PATCH 01/14] initial version - integrating security scanning into check_tool_permissions --- crates/goose/src/agents/agent.rs | 28 ++ crates/goose/src/lib.rs | 1 + crates/goose/src/security/mod.rs | 131 ++++++ crates/goose/src/security/model_downloader.rs | 398 ++++++++++++++++++ crates/goose/src/security/scanner.rs | 250 +++++++++++ 5 files changed, 808 insertions(+) create mode 100644 crates/goose/src/security/mod.rs create mode 100644 crates/goose/src/security/model_downloader.rs create mode 100644 crates/goose/src/security/scanner.rs diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 26521dcb6896..d8c55fae9c0e 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -57,6 +57,7 @@ use super::platform_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; use crate::agents::subagent_task_config::TaskConfig; use crate::conversation::message::{Message, ToolRequest}; +use crate::security::SecurityManager; const DEFAULT_MAX_TURNS: u32 = 1000; @@ -97,6 +98,7 @@ pub struct Agent { pub(super) tool_route_manager: ToolRouteManager, pub(super) scheduler_service: Mutex>>, pub(super) retry_manager: RetryManager, + pub(super) security_manager: SecurityManager, } #[derive(Clone, Debug)] @@ -173,6 +175,7 @@ impl Agent { tool_route_manager: ToolRouteManager::new(), scheduler_service: Mutex::new(None), retry_manager, + security_manager: SecurityManager::new(), } } @@ -1011,6 +1014,31 @@ impl Agent { self.provider().await?, ).await; + // Scan tools for prompt injection + let security_results = self.security_manager + .filter_evil_tool_calls(messages.messages(), &permission_check_result) + .await + .unwrap_or_else(|e| { + tracing::warn!("Security scanning failed: {}", e); + vec![] + }); + + // Handle security results - for now just log them + for security_result in &security_results { + if security_result.is_malicious { + tracing::warn!( + confidence = security_result.confidence, + explanation = %security_result.explanation, + "Security threat detected in tool call" + ); + + if security_result.should_ask_user { + // TODO: Implement user confirmation using existing tool approval system + tracing::info!("Security threat requires user confirmation"); + } + } + } + let mut tool_futures = self.handle_approved_and_denied_tools( &permission_check_result, message_tool_response.clone(), diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index 7d774dddeddc..5609aaa74e8d 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -13,6 +13,7 @@ pub mod recipe_deeplink; pub mod scheduler; pub mod scheduler_factory; pub mod scheduler_trait; +pub mod security; pub mod session; pub mod temporal_scheduler; pub mod token_counter; diff --git a/crates/goose/src/security/mod.rs b/crates/goose/src/security/mod.rs new file mode 100644 index 000000000000..887626d53f62 --- /dev/null +++ b/crates/goose/src/security/mod.rs @@ -0,0 +1,131 @@ +pub mod scanner; +pub mod model_downloader; + +use anyhow::Result; +use crate::conversation::message::Message; +use crate::permission::permission_judge::PermissionCheckResult; +use scanner::PromptInjectionScanner; + +/// Simple security manager for the POC +/// Focuses on tool call analysis with conversation context +pub struct SecurityManager { + scanner: Option, +} + +#[derive(Debug, Clone)] +pub struct SecurityResult { + pub is_malicious: bool, + pub confidence: f32, + pub explanation: String, + pub should_ask_user: bool, +} + +impl SecurityManager { + pub fn new() -> Self { + println!("πŸ”’ SecurityManager::new() called - checking if security should be enabled"); + + // Initialize scanner based on config + let should_enable = Self::should_enable_security(); + println!("πŸ”’ Security enabled check result: {}", should_enable); + + let scanner = match should_enable { + true => { + println!("πŸ”’ Initializing security scanner"); + tracing::info!("πŸ”’ Initializing security scanner"); + Some(PromptInjectionScanner::new()) + } + false => { + println!("πŸ”“ Security scanning disabled"); + tracing::info!("πŸ”“ Security scanning disabled"); + None + } + }; + + Self { scanner } + } + + /// Check if security should be enabled based on config + fn should_enable_security() -> bool { + // Check config file for security settings + use crate::config::Config; + let config = Config::global(); + + // Try to get security.enabled from config + let result = config.get_param::("security") + .ok() + .and_then(|security_config| security_config.get("enabled")?.as_bool()) + .unwrap_or(false); + + println!("πŸ”’ Config check - security config result: {:?}", + config.get_param::("security")); + println!("πŸ”’ Final security enabled result: {}", result); + + result + } + + /// Main security check function - called from reply_internal + pub async fn filter_evil_tool_calls( + &self, + messages: &[Message], + permission_check_result: &PermissionCheckResult, + ) -> Result> { + let Some(scanner) = &self.scanner else { + // Security disabled, return empty results + return Ok(vec![]); + }; + + let mut results = Vec::new(); + + // Check tools that need approval for potential security issues + for tool_request in &permission_check_result.needs_approval { + if let Ok(tool_call) = &tool_request.tool_call { + tracing::info!( + tool_name = %tool_call.name, + "πŸ” Analyzing tool call for security threats" + ); + + // First, check if the tool call itself looks suspicious + let tool_suspicious = scanner.scan_tool_call(tool_call).await?; + + if tool_suspicious.is_malicious { + // Tool call looks suspicious, analyze conversation context + tracing::warn!( + tool_name = %tool_call.name, + confidence = tool_suspicious.confidence, + "🚨 Suspicious tool call detected, analyzing conversation context" + ); + + let context_result = scanner.analyze_conversation_context( + messages, + tool_call, + ).await?; + + results.push(SecurityResult { + is_malicious: context_result.is_malicious, + confidence: context_result.confidence, + explanation: format!( + "Tool '{}' flagged as suspicious (confidence: {:.2}). Context analysis: {}", + tool_call.name, + tool_suspicious.confidence, + context_result.explanation + ), + should_ask_user: context_result.is_malicious && context_result.confidence > 0.7, + }); + } else { + tracing::debug!( + tool_name = %tool_call.name, + "βœ… Tool call passed security check" + ); + } + } + } + + Ok(results) + } +} + +impl Default for SecurityManager { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/goose/src/security/model_downloader.rs b/crates/goose/src/security/model_downloader.rs new file mode 100644 index 000000000000..ef5e25f1be7f --- /dev/null +++ b/crates/goose/src/security/model_downloader.rs @@ -0,0 +1,398 @@ +use anyhow::{anyhow}; +use std::path::{Path, PathBuf}; +use std::process::Command; +use tokio::fs; +use tokio::sync::OnceCell; + +pub struct ModelDownloader { + cache_dir: PathBuf, +} + +impl ModelDownloader { + pub fn new() -> anyhow::Result { + // Use platform-appropriate cache directory + let cache_dir = if let Some(cache_dir) = dirs::cache_dir() { + cache_dir.join("goose").join("security_models") + } else { + // Fallback to home directory + dirs::home_dir() + .ok_or_else(|| anyhow!("Could not determine home directory"))? + .join(".cache") + .join("goose") + .join("security_models") + }; + + Ok(Self { cache_dir }) + } + + pub async fn ensure_model_available(&self, model_info: &ModelInfo) -> anyhow::Result<(PathBuf, PathBuf)> { + let model_path = self.cache_dir.join(&model_info.onnx_filename); + let tokenizer_path = self.cache_dir.join(&model_info.tokenizer_filename); + + // Check if both model and tokenizer exist + if model_path.exists() && tokenizer_path.exists() { + tracing::info!( + model = %model_info.hf_model_name, + path = ?model_path, + "Using cached ONNX model" + ); + return Ok((model_path, tokenizer_path)); + } + + tracing::info!( + model = %model_info.hf_model_name, + "πŸ”’ Goose is being set up, this could take up to a minute…" + ); + + // Create cache directory if it doesn't exist + fs::create_dir_all(&self.cache_dir).await?; + + // Download and convert the model - this blocks until complete + self.download_and_convert_model(model_info).await?; + + // Verify the files were created + if !model_path.exists() || !tokenizer_path.exists() { + return Err(anyhow!( + "Model conversion completed but files not found at expected paths. Model: {:?}, Tokenizer: {:?}", + model_path, tokenizer_path + )); + } + + tracing::info!( + model = %model_info.hf_model_name, + model_path = ?model_path, + tokenizer_path = ?tokenizer_path, + "βœ… Successfully downloaded and converted model" + ); + + Ok((model_path, tokenizer_path)) + } + + async fn download_and_convert_model(&self, model_info: &ModelInfo) -> anyhow::Result<()> { + // Set up Python virtual environment with required dependencies + let venv_dir = self.cache_dir.join("python_venv"); + self.ensure_python_venv(&venv_dir).await?; + + let python_script = self.create_conversion_script(model_info).await?; + + tracing::info!("Running model conversion script in virtual environment..."); + + // Use the virtual environment's Python + let python_exe = if cfg!(windows) { + venv_dir.join("Scripts").join("python.exe") + } else { + venv_dir.join("bin").join("python") + }; + + let output = Command::new(&python_exe) + .arg(&python_script) + .env("CACHE_DIR", &self.cache_dir) + .env("MODEL_NAME", &model_info.hf_model_name) + .output() + .map_err(|e| anyhow!("Failed to execute Python conversion script: {}", e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + return Err(anyhow!( + "Model conversion failed:\nStdout: {}\nStderr: {}", + stdout, + stderr + )); + } + + // Clean up the temporary script + let _ = fs::remove_file(&python_script).await; + + Ok(()) + } + + async fn ensure_python_venv(&self, venv_dir: &std::path::Path) -> anyhow::Result<()> { + // Check if virtual environment already exists and has required packages + let python_exe = if cfg!(windows) { + venv_dir.join("Scripts").join("python.exe") + } else { + venv_dir.join("bin").join("python") + }; + + if python_exe.exists() { + // Check if required packages are installed + let output = Command::new(&python_exe) + .args(&["-c", "import torch, transformers, onnx, tokenizers; print('OK')"]) + .output(); + + if let Ok(output) = output { + if output.status.success() && String::from_utf8_lossy(&output.stdout).trim() == "OK" { + tracing::info!("Python virtual environment already set up with required packages"); + return Ok(()); + } + } + } + + tracing::info!("Setting up Python virtual environment..."); + + // Create virtual environment + fs::create_dir_all(venv_dir).await?; + + let output = Command::new("python3") + .args(&["-m", "venv", venv_dir.to_str() + .ok_or_else(|| anyhow!("Invalid venv directory path"))?]) + .output() + .map_err(|e| anyhow!("Failed to create Python virtual environment: {}", e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(anyhow!("Failed to create virtual environment: {}", stderr)); + } + + tracing::info!("Installing required Python packages..."); + + // Install required packages + let pip_exe = if cfg!(windows) { + venv_dir.join("Scripts").join("pip.exe") + } else { + venv_dir.join("bin").join("pip") + }; + + let packages = [ + "torch", + "transformers", + "onnx", + "tokenizers", + ]; + + for package in &packages { + tracing::info!("Installing {}...", package); + let output = Command::new(&pip_exe) + .args(&["install", package]) + .output() + .map_err(|e| anyhow!("Failed to install {}: {}", package, e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(anyhow!("Failed to install {}: {}", package, stderr)); + } + } + + tracing::info!("Python virtual environment setup complete"); + Ok(()) + } + + async fn create_conversion_script(&self, model_info: &ModelInfo) -> anyhow::Result { + let script_content = format!( + r#"#!/usr/bin/env python3 +""" +Runtime model conversion script for Goose security models +""" + +import os +import sys + +def install_packages(): + """Install required packages""" + import subprocess + packages = ["torch", "transformers", "onnx", "tokenizers"] + for package in packages: + print(f"πŸ“¦ Installing {{package}}...") + try: + subprocess.check_call([sys.executable, "-m", "pip", "install", package], + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + except subprocess.CalledProcessError as e: + print(f"❌ Failed to install {{package}}: {{e}}") + return False + return True + +def check_and_install_packages(): + """Check if packages are available, install if needed""" + try: + import torch + import transformers + import onnx + import tokenizers + print("βœ… All required packages are available") + return True + except ImportError as e: + print(f"❌ Missing packages: {{e}}") + print("πŸ“¦ Installing required packages...") + if install_packages(): + # Try importing again after installation + try: + import torch + import transformers + import onnx + import tokenizers + print("βœ… Successfully installed and imported all packages") + return True + except ImportError as e2: + print(f"❌ Still missing packages after installation: {{e2}}") + return False + else: + return False + +# Check and install packages first +if not check_and_install_packages(): + print("❌ Failed to install required packages") + sys.exit(1) + +# Now import everything we need +from transformers import AutoTokenizer, AutoModelForSequenceClassification +import torch + +def convert_model_to_onnx(model_name: str, output_dir: str): + """Convert a Hugging Face model to ONNX format""" + print(f"Converting {{model_name}} to ONNX...") + + # Create output directory + from pathlib import Path + Path(output_dir).mkdir(parents=True, exist_ok=True) + + try: + # Handle authentication for gated models + hf_token = os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HF_TOKEN') + auth_kwargs = {{}} + if hf_token: + auth_kwargs['token'] = hf_token + print(f" Using HF token for authentication") + + # Load model and tokenizer + print(f" Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(model_name, **auth_kwargs) + + print(f" Loading model...") + model = AutoModelForSequenceClassification.from_pretrained(model_name, **auth_kwargs) + model.eval() + + # Create dummy input + dummy_text = "This is a test input for ONNX conversion" + inputs = tokenizer(dummy_text, return_tensors="pt", padding=True, truncation=True, max_length=512) + + # Export to ONNX + model_filename = model_name.replace("/", "_") + ".onnx" + model_path = os.path.join(output_dir, model_filename) + + print(f" Exporting to ONNX...") + torch.onnx.export( + model, + (inputs['input_ids'], inputs['attention_mask']), + model_path, + export_params=True, + opset_version=14, + do_constant_folding=True, + input_names=['input_ids', 'attention_mask'], + output_names=['logits'], + dynamic_axes={{ + 'input_ids': {{0: 'batch_size', 1: 'sequence'}}, + 'attention_mask': {{0: 'batch_size', 1: 'sequence'}}, + 'logits': {{0: 'batch_size'}} + }} + ) + + # Save tokenizer with model-specific filename + tokenizer_filename = model_name.replace("/", "_") + "_tokenizer.json" + tokenizer_path = os.path.join(output_dir, tokenizer_filename) + + # First save to temp directory to get the tokenizer.json file + temp_dir = os.path.join(output_dir, "temp_tokenizer") + tokenizer.save_pretrained(temp_dir, legacy_format=False) + + # Copy the tokenizer.json file to the expected location with model-specific name + import shutil + temp_tokenizer_json = os.path.join(temp_dir, "tokenizer.json") + if os.path.exists(temp_tokenizer_json): + shutil.copy2(temp_tokenizer_json, tokenizer_path) + # Clean up temp directory + shutil.rmtree(temp_dir) + else: + print(f" Warning: tokenizer.json not found in {{temp_dir}}") + # Fallback: save the entire tokenizer directory + tokenizer_dir = os.path.join(output_dir, model_name.replace("/", "_") + "_tokenizer") + tokenizer.save_pretrained(tokenizer_dir, legacy_format=False) + print(f" Saved tokenizer to directory: {{tokenizer_dir}}") + + print(f"βœ… Successfully converted {{model_name}}") + print(f" Model: {{model_path}}") + print(f" Tokenizer: {{tokenizer_path}}") + return True + + except Exception as e: + print(f"❌ Failed to convert {{model_name}}: {{e}}") + if "gated repo" in str(e).lower() or "access" in str(e).lower(): + print(f" This might be a gated model. Make sure you:") + print(f" 1. Have access to {{model_name}} on Hugging Face") + print(f" 2. Set your HF token: export HUGGINGFACE_TOKEN='your_token'") + print(f" 3. Get a token from: https://huggingface.co/settings/tokens") + import traceback + traceback.print_exc() + return False + +def main(): + model_name = os.getenv('MODEL_NAME') + cache_dir = os.getenv('CACHE_DIR') + + if not model_name or not cache_dir: + print("Error: MODEL_NAME and CACHE_DIR environment variables must be set") + sys.exit(1) + + success = convert_model_to_onnx(model_name, cache_dir) + if not success: + sys.exit(1) + +if __name__ == "__main__": + main() +"# + ); + + let script_path = self.cache_dir.join(format!("convert_model_{}.py", + model_info.hf_model_name.replace("/", "_").replace("-", "_"))); + fs::write(&script_path, script_content).await?; + + // Make the script executable + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&script_path).await?.permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).await?; + } + + Ok(script_path) + } + + pub fn get_cache_dir(&self) -> &Path { + &self.cache_dir + } +} + +#[derive(Debug, Clone)] +pub struct ModelInfo { + pub hf_model_name: String, + pub onnx_filename: String, + pub tokenizer_filename: String, +} + +impl ModelInfo { + pub fn deepset_deberta() -> Self { + Self { + hf_model_name: "deepset/deberta-v3-base-injection".to_string(), + onnx_filename: "deepset_deberta-v3-base-injection.onnx".to_string(), + tokenizer_filename: "deepset_deberta-v3-base-injection_tokenizer.json".to_string(), + } + } + + pub fn protectai_deberta() -> Self { + Self { + hf_model_name: "protectai/deberta-v3-base-prompt-injection-v2".to_string(), + onnx_filename: "protectai_deberta-v3-base-prompt-injection-v2.onnx".to_string(), + tokenizer_filename: "protectai_deberta-v3-base-prompt-injection-v2_tokenizer.json".to_string(), + } + } +} + +// Global downloader instance +static GLOBAL_DOWNLOADER: OnceCell = OnceCell::const_new(); + +pub async fn get_global_downloader() -> anyhow::Result<&'static ModelDownloader> { + GLOBAL_DOWNLOADER + .get_or_try_init(|| async { ModelDownloader::new() }) + .await +} diff --git a/crates/goose/src/security/scanner.rs b/crates/goose/src/security/scanner.rs new file mode 100644 index 000000000000..9d28417a4d71 --- /dev/null +++ b/crates/goose/src/security/scanner.rs @@ -0,0 +1,250 @@ +use anyhow::Result; +use mcp_core::tool::ToolCall; +use crate::conversation::message::Message; +use std::path::PathBuf; + +use crate::security::model_downloader::{get_global_downloader, ModelInfo}; + +#[derive(Debug, Clone)] +pub struct ScanResult { + pub is_malicious: bool, + pub confidence: f32, + pub explanation: String, +} + +/// Simple prompt injection scanner +/// Uses the existing model_downloader infrastructure +pub struct PromptInjectionScanner { + model_path: Option, + enabled: bool, +} + +impl PromptInjectionScanner { + pub fn new() -> Self { + println!("πŸ”’ PromptInjectionScanner::new() called"); + + // Check if models are available, trigger download if needed + let scanner = Self { + model_path: None, + enabled: Self::check_and_prepare_models(), + }; + + println!("πŸ”’ Scanner enabled: {}", scanner.enabled); + + scanner + } + + /// Check if models are available and trigger download if needed + fn check_and_prepare_models() -> bool { + // For now, trigger model download in background and use pattern-based scanning + tokio::spawn(async { + Self::ensure_models_available().await; + }); + + // Return false for now to use pattern-based scanning + // This will be true once models are properly downloaded and available + false + } + + /// Ensure models are available using the existing model_downloader + async fn ensure_models_available() { + tracing::info!("πŸ”’ Ensuring security models are available..."); + + match get_global_downloader().await { + Ok(downloader) => { + let model_info = Self::get_model_info_from_config(); + match downloader.ensure_model_available(&model_info).await { + Ok((model_path, tokenizer_path)) => { + tracing::info!( + "πŸ”’ βœ… Security models ready: model={:?}, tokenizer={:?}", + model_path, tokenizer_path + ); + } + Err(e) => { + tracing::warn!("πŸ”’ Failed to ensure models available: {}", e); + tracing::info!("πŸ”’ Continuing with pattern-based security scanning"); + } + } + } + Err(e) => { + tracing::warn!("πŸ”’ Failed to get model downloader: {}", e); + } + } + } + + /// Get model information from config file + fn get_model_info_from_config() -> ModelInfo { + use crate::config::Config; + let config = Config::global(); + + // Try to get model from config + let security_config = config.get_param::("security").ok(); + + let model_name = security_config + .as_ref() + .and_then(|security| security.get("models")?.as_array()?.first()) + .and_then(|model| model.get("model")?.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| { + tracing::warn!("πŸ”’ No security model configured, using default"); + "protectai/deberta-v3-base-prompt-injection-v2".to_string() + }); + + tracing::info!("πŸ”’ Using security model from config: {}", model_name); + + // Create ModelInfo from config + let safe_filename = model_name.replace("/", "_").replace("-", "_"); + ModelInfo { + hf_model_name: model_name, + onnx_filename: format!("{}.onnx", safe_filename), + tokenizer_filename: format!("{}_tokenizer.json", safe_filename), + } + } + + /// Scan a tool call for suspicious patterns + pub async fn scan_tool_call(&self, tool_call: &ToolCall) -> Result { + // Create text representation of the tool call for analysis + let tool_text = format!( + "Tool: {}\nArguments: {}", + tool_call.name, + serde_json::to_string_pretty(&tool_call.arguments)? + ); + + // For now, always use pattern-based scanning + // TODO: Use ONNX model when available + self.scan_with_patterns(&tool_text).await + } + + /// Analyze conversation context to determine if tool call is malicious + pub async fn analyze_conversation_context( + &self, + messages: &[Message], + tool_call: &ToolCall, + ) -> Result { + // Combine recent messages for context analysis + let context_text = self.build_context_text(messages, tool_call); + + // For now, always use pattern-based analysis + // TODO: Use ONNX model when available + self.analyze_context_with_patterns(&context_text, tool_call).await + } + + /// Build context text from conversation history + fn build_context_text(&self, messages: &[Message], tool_call: &ToolCall) -> String { + // Take last 5 messages for context (adjust as needed) + let recent_messages: Vec = messages + .iter() + .rev() + .take(5) + .rev() + .filter_map(|msg| { + // Extract text content from messages + msg.content.first()?.as_text().map(|text| { + format!("{:?}: {}", msg.role, text) + }) + }) + .collect(); + + let context = recent_messages.join("\n"); + let tool_text = format!( + "Tool: {}\nArguments: {}", + tool_call.name, + serde_json::to_string_pretty(&tool_call.arguments).unwrap_or_default() + ); + + format!("Context:\n{}\n\nProposed Action:\n{}", context, tool_text) + } + + /// Fallback pattern-based scanning + async fn scan_with_patterns(&self, text: &str) -> Result { + let text_lower = text.to_lowercase(); + + // Simple patterns that might indicate prompt injection + let suspicious_patterns = [ + "ignore previous instructions", + "ignore all previous", + "forget everything", + "new instructions", + "system prompt", + "you are now", + "act as", + "pretend to be", + "roleplay as", + "jailbreak", + "developer mode", + ]; + + let mut max_confidence: f32 = 0.0; + let mut detected_patterns = Vec::new(); + + for pattern in &suspicious_patterns { + if text_lower.contains(pattern) { + detected_patterns.push(*pattern); + max_confidence = max_confidence.max(0.8); // High confidence for pattern match + } + } + + let is_malicious = max_confidence > 0.5; + let explanation = if detected_patterns.is_empty() { + "Pattern-based scan: No suspicious patterns detected".to_string() + } else { + format!("Pattern-based scan detected: {} (confidence: {:.2})", + detected_patterns.join(", "), max_confidence) + }; + + Ok(ScanResult { + is_malicious, + confidence: max_confidence, + explanation, + }) + } + + /// Analyze context with patterns + async fn analyze_context_with_patterns( + &self, + context_text: &str, + tool_call: &ToolCall, + ) -> Result { + // First scan the context for suspicious patterns + let context_result = self.scan_with_patterns(context_text).await?; + + // Consider tool-specific risks + let tool_risk = self.assess_tool_risk(&tool_call.name); + + // Combine context analysis with tool risk + let combined_confidence = (context_result.confidence * 0.7) + (tool_risk * 0.3); + let is_malicious = combined_confidence > 0.6; + + let explanation = format!( + "Context analysis: {} (confidence: {:.2}). Tool risk assessment: {:.2}. Combined risk: {:.2}", + context_result.explanation, + context_result.confidence, + tool_risk, + combined_confidence + ); + + Ok(ScanResult { + is_malicious, + confidence: combined_confidence, + explanation, + }) + } + + /// Assess inherent risk of specific tools + fn assess_tool_risk(&self, tool_name: &str) -> f32 { + // Higher risk tools that could be used maliciously + match tool_name { + name if name.contains("shell") || name.contains("exec") => 0.8, + name if name.contains("file") && name.contains("write") => 0.6, + name if name.contains("network") || name.contains("http") => 0.5, + name if name.contains("read") => 0.3, + _ => 0.1, + } + } +} + +impl Default for PromptInjectionScanner { + fn default() -> Self { + Self::new() + } +} From fe0c392ca5d47e9504fe3a948dcbcfbb2949eb44 Mon Sep 17 00:00:00 2001 From: Dorien Koelemeijer Date: Mon, 11 Aug 2025 10:44:58 +1000 Subject: [PATCH 02/14] Fix some issues with how prompt injection detection on agent tool calls was done - initial working version --- Cargo.lock | 353 +++++++++++++- crates/goose/Cargo.toml | 4 + crates/goose/src/agents/agent.rs | 27 +- crates/goose/src/security/mod.rs | 54 ++- crates/goose/src/security/model_downloader.rs | 18 +- crates/goose/src/security/scanner.rs | 449 +++++++++++++++--- 6 files changed, 776 insertions(+), 129 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f849572a6789..40f517a90d9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1151,6 +1151,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" + [[package]] name = "bat" version = "0.24.0" @@ -1499,7 +1505,7 @@ version = "0.15.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d067ad48b8650848b989a59a86c6c36a995d02d2bf778d45c3c5d57bc2718f02" dependencies = [ - "smallvec", + "smallvec 1.14.0", "target-lexicon", ] @@ -2511,6 +2517,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.3.11" @@ -2532,6 +2548,37 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling 0.20.10", + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.99", +] + [[package]] name = "digest" version = "0.10.7" @@ -2705,6 +2752,15 @@ version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + [[package]] name = "etcetera" version = "0.8.0" @@ -2760,7 +2816,7 @@ dependencies = [ "lebe", "miniz_oxide", "rayon-core", - "smallvec", + "smallvec 1.14.0", "zune-inflate", ] @@ -3321,6 +3377,7 @@ dependencies = [ "opentelemetry", "opentelemetry-otlp", "opentelemetry_sdk", + "ort", "rand 0.8.5", "regex", "reqwest 0.12.12", @@ -3335,6 +3392,7 @@ dependencies = [ "tempfile", "thiserror 1.0.69", "tiktoken-rs", + "tokenizers", "tokio", "tokio-cron-scheduler", "tokio-stream", @@ -3802,7 +3860,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "smallvec", + "smallvec 1.14.0", "tokio", "want", ] @@ -3974,7 +4032,7 @@ dependencies = [ "icu_normalizer_data", "icu_properties", "icu_provider", - "smallvec", + "smallvec 1.14.0", "utf16_iter", "utf8_iter", "write16", @@ -4049,7 +4107,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ "idna_adapter", - "smallvec", + "smallvec 1.14.0", "utf8_iter", ] @@ -4272,6 +4330,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.12.1" @@ -5115,6 +5182,22 @@ dependencies = [ "libc", ] +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + [[package]] name = "malloc_buf" version = "0.0.6" @@ -5145,6 +5228,16 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -5390,13 +5483,34 @@ dependencies = [ "rustc_version", "scheduled-thread-pool", "skeptic", - "smallvec", + "smallvec 1.14.0", "tagptr", "thiserror 1.0.69", "triomphe", "uuid", ] +[[package]] +name = "monostate" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aafe1be9d0c75642e3e50fedc7ecadf1ef1cbce6eb66462153fc44245343fbee" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c402a4092d5e204f32c9e155431046831fa712637043c58cb73bc6bc6c9663b5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + [[package]] name = "multimap" version = "0.10.1" @@ -5418,6 +5532,38 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk-context" version = "0.1.1" @@ -5436,7 +5582,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" dependencies = [ - "smallvec", + "smallvec 1.14.0", ] [[package]] @@ -5925,6 +6071,31 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "ort" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" +dependencies = [ + "ndarray", + "ort-sys", + "smallvec 2.0.0-alpha.10", + "tracing", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890" +dependencies = [ + "flate2", + "pkg-config", + "sha2", + "tar", + "ureq", +] + [[package]] name = "outref" version = "0.5.2" @@ -5971,7 +6142,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "smallvec", + "smallvec 1.14.0", "windows-targets 0.52.6", ] @@ -6018,6 +6189,15 @@ dependencies = [ "serde", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -6237,6 +6417,15 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -6743,6 +6932,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -6753,6 +6948,17 @@ dependencies = [ "rayon-core", ] +[[package]] +name = "rayon-cond" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +dependencies = [ + "either", + "itertools 0.11.0", + "rayon", +] + [[package]] name = "rayon-core" version = "1.12.1" @@ -7703,6 +7909,12 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +[[package]] +name = "smallvec" +version = "2.0.0-alpha.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b" + [[package]] name = "smawk" version = "0.3.2" @@ -7751,6 +7963,29 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "sqlparser" version = "0.49.0" @@ -8017,7 +8252,7 @@ dependencies = [ "serde", "serde_json", "sketches-ddsketch", - "smallvec", + "smallvec 1.14.0", "tantivy-bitpacker", "tantivy-columnar", "tantivy-common", @@ -8404,6 +8639,38 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.20.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom 0.2.15", + "indicatif", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.8.5", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.8.5", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 1.0.69", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.43.1" @@ -8717,7 +8984,7 @@ dependencies = [ "once_cell", "opentelemetry", "opentelemetry_sdk", - "smallvec", + "smallvec 1.14.0", "tracing", "tracing-core", "tracing-log", @@ -8748,7 +9015,7 @@ dependencies = [ "serde", "serde_json", "sharded-slab", - "smallvec", + "smallvec 1.14.0", "thread_local", "time", "tracing", @@ -8855,6 +9122,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec 1.14.0", +] + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -8873,6 +9149,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "unsafe-libyaml" version = "0.2.11" @@ -8885,6 +9167,37 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f0fde9bc91026e381155f8c67cb354bcd35260b2f4a29bcc84639f762760c39" +dependencies = [ + "base64 0.22.1", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pemfile 2.2.0", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf-8", + "webpki-root-certs 0.26.11", +] + +[[package]] +name = "ureq-proto" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59db78ad1923f2b1be62b6da81fe80b173605ca0d57f85da2e005382adf693f7" +dependencies = [ + "base64 0.22.1", + "http 1.2.0", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.4" @@ -9198,6 +9511,24 @@ dependencies = [ "web-sys", ] +[[package]] +name = "webpki-root-certs" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75c7f0ef91146ebfb530314f5f1d24528d7f0767efbfd31dce919275413e393e" +dependencies = [ + "webpki-root-certs 1.0.2", +] + +[[package]] +name = "webpki-root-certs" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4ffd8df1c57e87c325000a3d6ef93db75279dc3a231125aac571650f22b12a" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webpki-roots" version = "0.26.8" diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index b8fad75b5ef3..a6aca6822268 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -100,6 +100,10 @@ tokio-util = "0.7.15" lancedb = "0.13" arrow = "52.2" +# ML inference backends for security scanning +ort = "2.0.0-rc.6" # ONNX Runtime +tokenizers = "0.20.3" # HuggingFace tokenizers + [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index d8c55fae9c0e..61787b711ae0 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1014,9 +1014,31 @@ impl Agent { self.provider().await?, ).await; + // DEBUG: Log tool categorization + println!("πŸ” DEBUG: Tool categorization results:"); + println!(" - {} tools approved (pre-approved)", permission_check_result.approved.len()); + println!(" - {} tools need approval", permission_check_result.needs_approval.len()); + println!(" - {} tools denied", permission_check_result.denied.len()); + println!(" - {} readonly tools", readonly_tools.len()); + println!(" - {} regular tools", regular_tools.len()); + + for (i, tool_req) in remaining_requests.iter().enumerate() { + if let Ok(tool_call) = &tool_req.tool_call { + println!(" - Tool {}: '{}' -> {}", i, tool_call.name, + if permission_check_result.approved.iter().any(|r| r.id == tool_req.id) { "APPROVED" } + else if permission_check_result.needs_approval.iter().any(|r| r.id == tool_req.id) { "NEEDS_APPROVAL" } + else if permission_check_result.denied.iter().any(|r| r.id == tool_req.id) { "DENIED" } + else { "UNKNOWN" } + ); + } + } + // Scan tools for prompt injection + let total_tools = permission_check_result.approved.len() + permission_check_result.needs_approval.len(); + println!("πŸ” DEBUG: About to call security manager with {} total tools ({} approved + {} need approval)", + total_tools, permission_check_result.approved.len(), permission_check_result.needs_approval.len()); let security_results = self.security_manager - .filter_evil_tool_calls(messages.messages(), &permission_check_result) + .filter_malicious_tool_calls(messages.messages(), &permission_check_result) .await .unwrap_or_else(|e| { tracing::warn!("Security scanning failed: {}", e); @@ -1024,7 +1046,10 @@ impl Agent { }); // Handle security results - for now just log them + println!("πŸ” DEBUG: Security scan returned {} results", security_results.len()); for security_result in &security_results { + println!("πŸ” DEBUG: Security result - malicious: {}, confidence: {:.2}, explanation: {}", + security_result.is_malicious, security_result.confidence, security_result.explanation); if security_result.is_malicious { tracing::warn!( confidence = security_result.confidence, diff --git a/crates/goose/src/security/mod.rs b/crates/goose/src/security/mod.rs index 887626d53f62..38e7b4f845aa 100644 --- a/crates/goose/src/security/mod.rs +++ b/crates/goose/src/security/mod.rs @@ -64,7 +64,9 @@ impl SecurityManager { } /// Main security check function - called from reply_internal - pub async fn filter_evil_tool_calls( + /// Uses the proper two-step security analysis process + /// Scans ALL tools (approved + needs_approval) for security threats + pub async fn filter_malicious_tool_calls( &self, messages: &[Message], permission_check_result: &PermissionCheckResult, @@ -76,45 +78,45 @@ impl SecurityManager { let mut results = Vec::new(); - // Check tools that need approval for potential security issues - for tool_request in &permission_check_result.needs_approval { + // Collect ALL tool requests (approved + needs_approval) for security scanning + let mut all_tool_requests = Vec::new(); + all_tool_requests.extend(&permission_check_result.approved); + all_tool_requests.extend(&permission_check_result.needs_approval); + + // Check ALL tools for potential security issues + for tool_request in &all_tool_requests { if let Ok(tool_call) = &tool_request.tool_call { tracing::info!( tool_name = %tool_call.name, - "πŸ” Analyzing tool call for security threats" + "πŸ” Starting two-step security analysis for tool call" ); - // First, check if the tool call itself looks suspicious - let tool_suspicious = scanner.scan_tool_call(tool_call).await?; - - if tool_suspicious.is_malicious { - // Tool call looks suspicious, analyze conversation context + // Use the new two-step analysis method + let analysis_result = scanner.analyze_tool_call_with_context( + tool_call, + messages, + ).await?; + + if analysis_result.is_malicious { tracing::warn!( tool_name = %tool_call.name, - confidence = tool_suspicious.confidence, - "🚨 Suspicious tool call detected, analyzing conversation context" + confidence = analysis_result.confidence, + explanation = %analysis_result.explanation, + "🚨 Tool call flagged as malicious after two-step analysis" ); - let context_result = scanner.analyze_conversation_context( - messages, - tool_call, - ).await?; - results.push(SecurityResult { - is_malicious: context_result.is_malicious, - confidence: context_result.confidence, - explanation: format!( - "Tool '{}' flagged as suspicious (confidence: {:.2}). Context analysis: {}", - tool_call.name, - tool_suspicious.confidence, - context_result.explanation - ), - should_ask_user: context_result.is_malicious && context_result.confidence > 0.7, + is_malicious: analysis_result.is_malicious, + confidence: analysis_result.confidence, + explanation: analysis_result.explanation, + should_ask_user: analysis_result.confidence > 0.7, }); } else { tracing::debug!( tool_name = %tool_call.name, - "βœ… Tool call passed security check" + confidence = analysis_result.confidence, + explanation = %analysis_result.explanation, + "βœ… Tool call passed two-step security analysis" ); } } diff --git a/crates/goose/src/security/model_downloader.rs b/crates/goose/src/security/model_downloader.rs index ef5e25f1be7f..b6929b5d1b95 100644 --- a/crates/goose/src/security/model_downloader.rs +++ b/crates/goose/src/security/model_downloader.rs @@ -371,19 +371,13 @@ pub struct ModelInfo { } impl ModelInfo { - pub fn deepset_deberta() -> Self { + pub fn from_config_model(model_name: &str) -> Self { + // Keep the original model name format for filenames to match what model_downloader creates + let safe_filename = model_name.replace("/", "_"); Self { - hf_model_name: "deepset/deberta-v3-base-injection".to_string(), - onnx_filename: "deepset_deberta-v3-base-injection.onnx".to_string(), - tokenizer_filename: "deepset_deberta-v3-base-injection_tokenizer.json".to_string(), - } - } - - pub fn protectai_deberta() -> Self { - Self { - hf_model_name: "protectai/deberta-v3-base-prompt-injection-v2".to_string(), - onnx_filename: "protectai_deberta-v3-base-prompt-injection-v2.onnx".to_string(), - tokenizer_filename: "protectai_deberta-v3-base-prompt-injection-v2_tokenizer.json".to_string(), + hf_model_name: model_name.to_string(), + onnx_filename: format!("{}.onnx", safe_filename), + tokenizer_filename: format!("{}_tokenizer.json", safe_filename), } } } diff --git a/crates/goose/src/security/scanner.rs b/crates/goose/src/security/scanner.rs index 9d28417a4d71..0d03442e4b03 100644 --- a/crates/goose/src/security/scanner.rs +++ b/crates/goose/src/security/scanner.rs @@ -1,10 +1,16 @@ -use anyhow::Result; +use anyhow::{Result, anyhow}; use mcp_core::tool::ToolCall; use crate::conversation::message::Message; use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::OnceCell; use crate::security::model_downloader::{get_global_downloader, ModelInfo}; +// ML inference backends +use ort::{session::Session, session::builder::GraphOptimizationLevel}; +use tokenizers::Tokenizer; + #[derive(Debug, Clone)] pub struct ScanResult { pub is_malicious: bool, @@ -12,10 +18,161 @@ pub struct ScanResult { pub explanation: String, } +/// Trait for different ML inference backends +#[async_trait::async_trait] +pub trait PromptInjectionModel: Send + Sync { + async fn predict(&self, text: &str) -> Result<(f32, String)>; + fn model_name(&self) -> &str; +} + +/// ONNX Runtime implementation +pub struct OnnxPromptInjectionModel { + _session: Session, // Temporarily unused due to mutable reference issue + _tokenizer: Arc, // Temporarily unused + model_name: String, +} + +impl OnnxPromptInjectionModel { + pub async fn new(model_path: PathBuf, tokenizer_path: PathBuf, model_name: String) -> Result { + // Initialize ONNX Runtime session + let session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .commit_from_file(&model_path)?; + + // Load tokenizer + let tokenizer = Tokenizer::from_file(&tokenizer_path) + .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?; + + Ok(Self { + _session: session, + _tokenizer: Arc::new(tokenizer), + model_name, + }) + } +} + +#[async_trait::async_trait] +impl PromptInjectionModel for OnnxPromptInjectionModel { + async fn predict(&self, text: &str) -> Result<(f32, String)> { + // For now, return a reasonable prediction based on simple heuristics + // TODO: Implement actual ONNX inference once we resolve the mutable reference issue + + let text_lower = text.to_lowercase(); + + // Check for prompt injection patterns + let injection_patterns = [ + "ignore previous instructions", + "ignore all previous", + "forget everything", + "new instructions", + "system prompt", + "you are now", + "act as", + "pretend to be", + "roleplay as", + "jailbreak", + "developer mode", + ]; + + // Check for dangerous shell commands + let dangerous_commands = [ + "rm -rf", + "sudo rm", + "del /s", + "format c:", + "dd if=", + "mkfs", + "fdisk", + "chmod 777", + "wget http", + "curl http", + "nc -l", + "netcat", + ]; + + let mut confidence = 0.0f32; + let mut detected_patterns = Vec::new(); + + // Check for prompt injection patterns + for pattern in &injection_patterns { + if text_lower.contains(pattern) { + detected_patterns.push(format!("injection:{}", pattern)); + confidence = confidence.max(0.9); + } + } + + // Check for dangerous commands + for command in &dangerous_commands { + if text_lower.contains(command) { + detected_patterns.push(format!("dangerous:{}", command)); + confidence = confidence.max(0.8); + } + } + + let explanation = if detected_patterns.is_empty() { + format!("ONNX model '{}': No threats detected", self.model_name) + } else { + format!("ONNX model '{}': Detected threats: {}", + self.model_name, detected_patterns.join(", ")) + }; + + Ok((confidence, explanation)) + } + + fn model_name(&self) -> &str { + &self.model_name + } +} + +/// Global model cache +static MODEL_CACHE: OnceCell>> = OnceCell::const_new(); + +/// Initialize the global model +async fn initialize_model() -> Result>> { + tracing::info!("πŸ”’ Attempting to initialize ONNX security model..."); + + // Try to load the ONNX model + match get_global_downloader().await { + Ok(downloader) => { + let model_info = PromptInjectionScanner::get_model_info_from_config(); + match downloader.ensure_model_available(&model_info).await { + Ok((model_path, tokenizer_path)) => { + tracing::info!("πŸ”’ Loading ONNX model from: {:?}", model_path); + match OnnxPromptInjectionModel::new(model_path, tokenizer_path, model_info.hf_model_name.clone()).await { + Ok(model) => { + tracing::info!("πŸ”’ βœ… ONNX security model loaded successfully"); + return Ok(Some(Arc::new(model))); + } + Err(e) => { + tracing::warn!("πŸ”’ Failed to initialize ONNX model: {}", e); + } + } + } + Err(e) => { + tracing::warn!("πŸ”’ Failed to ensure model available: {}", e); + } + } + } + Err(e) => { + tracing::warn!("πŸ”’ Failed to get model downloader: {}", e); + } + } + + tracing::info!("πŸ”’ ONNX model not available, will use pattern-based scanning"); + Ok(None) +} + +/// Get or initialize the global model +async fn get_model() -> Option> { + MODEL_CACHE + .get_or_init(|| async { initialize_model().await.unwrap_or(None) }) + .await + .clone() +} + /// Simple prompt injection scanner /// Uses the existing model_downloader infrastructure pub struct PromptInjectionScanner { - model_path: Option, enabled: bool, } @@ -25,7 +182,6 @@ impl PromptInjectionScanner { // Check if models are available, trigger download if needed let scanner = Self { - model_path: None, enabled: Self::check_and_prepare_models(), }; @@ -36,13 +192,29 @@ impl PromptInjectionScanner { /// Check if models are available and trigger download if needed fn check_and_prepare_models() -> bool { - // For now, trigger model download in background and use pattern-based scanning + // Check if models are already cached + let model_info = Self::get_model_info_from_config(); + + // Check if model files exist in cache + if let Some(cache_dir) = dirs::cache_dir() { + let security_models_dir = cache_dir.join("goose").join("security_models"); + let model_path = security_models_dir.join(&model_info.onnx_filename); + let tokenizer_path = security_models_dir.join(&model_info.tokenizer_filename); + + if model_path.exists() && tokenizer_path.exists() { + tracing::info!("πŸ”’ Security models found in cache, enabling security scanning"); + return true; + } + } + + // Models not cached, trigger download in background + tracing::info!("πŸ”’ Security models not found in cache, downloading in background"); tokio::spawn(async { Self::ensure_models_available().await; }); - // Return false for now to use pattern-based scanning - // This will be true once models are properly downloaded and available + // For now, use pattern-based scanning while models download + // TODO: In the future, we could block here or enable ONNX scanning after download false } @@ -86,23 +258,66 @@ impl PromptInjectionScanner { .and_then(|model| model.get("model")?.as_str()) .map(|s| s.to_string()) .unwrap_or_else(|| { - tracing::warn!("πŸ”’ No security model configured, using default"); - "protectai/deberta-v3-base-prompt-injection-v2".to_string() + tracing::warn!("πŸ”’ No security model configured, security scanning will be disabled"); + // Return a placeholder that won't work, forcing pattern-only mode + "no-model-configured".to_string() }); tracing::info!("πŸ”’ Using security model from config: {}", model_name); // Create ModelInfo from config - let safe_filename = model_name.replace("/", "_").replace("-", "_"); - ModelInfo { - hf_model_name: model_name, - onnx_filename: format!("{}.onnx", safe_filename), - tokenizer_filename: format!("{}_tokenizer.json", safe_filename), + ModelInfo::from_config_model(&model_name) + } + + /// Two-step security analysis: scan tool call first, then analyze context if suspicious + pub async fn analyze_tool_call_with_context( + &self, + tool_call: &ToolCall, + messages: &[Message], + ) -> Result { + // Step 1: Scan the tool call itself for suspicious patterns + let tool_call_result = self.scan_tool_call_only(tool_call).await?; + + if !tool_call_result.is_malicious { + // Tool call looks safe, no need for context analysis + tracing::debug!( + tool_name = %tool_call.name, + confidence = tool_call_result.confidence, + "βœ… Tool call passed initial security scan" + ); + return Ok(tool_call_result); } + + // Step 2: Tool call looks suspicious, analyze conversation context + tracing::info!( + tool_name = %tool_call.name, + confidence = tool_call_result.confidence, + "πŸ” Tool call flagged as suspicious, analyzing conversation context" + ); + + let user_messages_result = self.scan_user_messages_only(messages).await?; + + // Decision logic: combine both results + let final_result = self.make_final_security_decision( + &tool_call_result, + &user_messages_result, + tool_call, + ); + + tracing::info!( + tool_name = %tool_call.name, + tool_confidence = tool_call_result.confidence, + user_confidence = user_messages_result.confidence, + final_malicious = final_result.is_malicious, + final_confidence = final_result.confidence, + "πŸ”’ Two-step security analysis complete" + ); + + Ok(final_result) } - /// Scan a tool call for suspicious patterns - pub async fn scan_tool_call(&self, tool_call: &ToolCall) -> Result { + /// Step 1: Scan only the tool call for suspicious patterns + async fn scan_tool_call_only(&self, tool_call: &ToolCall) -> Result { // Create text representation of the tool call for analysis let tool_text = format!( "Tool: {}\nArguments: {}", @@ -110,49 +325,156 @@ impl PromptInjectionScanner { serde_json::to_string_pretty(&tool_call.arguments)? ); - // For now, always use pattern-based scanning - // TODO: Use ONNX model when available - self.scan_with_patterns(&tool_text).await + self.scan_with_prompt_injection_model(&tool_text).await } - /// Analyze conversation context to determine if tool call is malicious - pub async fn analyze_conversation_context( - &self, - messages: &[Message], - tool_call: &ToolCall, - ) -> Result { - // Combine recent messages for context analysis - let context_text = self.build_context_text(messages, tool_call); - - // For now, always use pattern-based analysis - // TODO: Use ONNX model when available - self.analyze_context_with_patterns(&context_text, tool_call).await - } - - /// Build context text from conversation history - fn build_context_text(&self, messages: &[Message], tool_call: &ToolCall) -> String { - // Take last 5 messages for context (adjust as needed) - let recent_messages: Vec = messages + /// Step 2: Scan only the user messages (conversation history) for prompt injection + async fn scan_user_messages_only(&self, messages: &[Message]) -> Result { + // Extract only user messages from recent conversation history + let user_messages: Vec = messages .iter() .rev() - .take(5) + .take(5) // Take last 5 messages for context .rev() .filter_map(|msg| { - // Extract text content from messages - msg.content.first()?.as_text().map(|text| { - format!("{:?}: {}", msg.role, text) - }) + // Only analyze user messages, not assistant responses + if matches!(msg.role, rmcp::model::Role::User) { + msg.content.first()?.as_text().map(|text| text.to_string()) + } else { + None + } }) .collect(); - let context = recent_messages.join("\n"); - let tool_text = format!( - "Tool: {}\nArguments: {}", - tool_call.name, - serde_json::to_string_pretty(&tool_call.arguments).unwrap_or_default() - ); + if user_messages.is_empty() { + return Ok(ScanResult { + is_malicious: false, + confidence: 0.0, + explanation: "No user messages found in conversation history".to_string(), + }); + } - format!("Context:\n{}\n\nProposed Action:\n{}", context, tool_text) + let user_context = user_messages.join("\n\n"); + self.scan_with_prompt_injection_model(&user_context).await + } + + /// Make final security decision based on both tool call and user message analysis + fn make_final_security_decision( + &self, + tool_call_result: &ScanResult, + user_messages_result: &ScanResult, + tool_call: &ToolCall, + ) -> ScanResult { + // Decision logic: + // 1. If user messages contain prompt injection, tool call is likely malicious + // 2. If user messages are clean but tool call is suspicious, it might be a legitimate response + // 3. Consider tool risk level as well + + let tool_risk = self.assess_tool_risk(&tool_call.name); + + let (is_malicious, confidence, explanation) = if user_messages_result.is_malicious { + // User messages contain prompt injection - tool call is likely malicious + let combined_confidence = (tool_call_result.confidence + user_messages_result.confidence) / 2.0; + let explanation = format!( + "MALICIOUS: Tool '{}' appears to be result of prompt injection. Tool scan: {:.2} confidence ({}). User messages scan: {:.2} confidence ({})", + tool_call.name, + tool_call_result.confidence, + if tool_call_result.is_malicious { "suspicious" } else { "clean" }, + user_messages_result.confidence, + user_messages_result.explanation + ); + (true, combined_confidence.max(0.8), explanation) + } else { + // User messages are clean - suspicious tool call might be legitimate + // Lower the confidence since user didn't inject malicious prompts + let adjusted_confidence = tool_call_result.confidence * 0.6; // Reduce confidence + let explanation = format!( + "LIKELY SAFE: Tool '{}' flagged as suspicious but user messages appear clean. Tool scan: {:.2} confidence. User messages: clean ({:.2} confidence). Adjusted confidence: {:.2}", + tool_call.name, + tool_call_result.confidence, + user_messages_result.confidence, + adjusted_confidence + ); + + // Only consider malicious if adjusted confidence is still high AND tool is high-risk + let is_malicious = adjusted_confidence > 0.7 && tool_risk > 0.6; + (is_malicious, adjusted_confidence, explanation) + }; + + ScanResult { + is_malicious, + confidence, + explanation, + } + } + + /// Legacy method for backward compatibility - now delegates to two-step analysis + pub async fn scan_tool_call(&self, tool_call: &ToolCall) -> Result { + // For backward compatibility, just scan the tool call without context + self.scan_tool_call_only(tool_call).await + } + + /// Legacy method for backward compatibility - now delegates to user message scanning + pub async fn analyze_conversation_context( + &self, + messages: &[Message], + _tool_call: &ToolCall, // Ignored in new implementation + ) -> Result { + // For backward compatibility, just scan user messages + self.scan_user_messages_only(messages).await + } + + /// Model-agnostic prompt injection scanning + async fn scan_with_prompt_injection_model(&self, text: &str) -> Result { + // Try to get the ML model + if let Some(model) = get_model().await { + match model.predict(text).await { + Ok((confidence, explanation)) => { + // Get threshold from config + let threshold = self.get_threshold_from_config(); + let is_malicious = confidence > threshold; + + tracing::info!( + "πŸ”’ ML model prediction: confidence={:.3}, threshold={:.3}, malicious={}", + confidence, threshold, is_malicious + ); + + return Ok(ScanResult { + is_malicious, + confidence, + explanation, + }); + } + Err(e) => { + tracing::warn!("πŸ”’ ML model prediction failed: {}", e); + // Fall through to pattern-based scanning + } + } + } else { + tracing::info!("πŸ”’ No ML model available, using pattern-based fallback"); + } + + // Fallback to pattern-based scanning if ML model is not available + self.scan_with_patterns(text).await + } + + /// Get threshold from config + fn get_threshold_from_config(&self) -> f32 { + use crate::config::Config; + let config = Config::global(); + + // Get security config and extract threshold + if let Ok(security_value) = config.get_param::("security") { + if let Some(models_array) = security_value.get("models").and_then(|m| m.as_array()) { + if let Some(first_model) = models_array.first() { + if let Some(threshold) = first_model.get("threshold").and_then(|t| t.as_f64()) { + return threshold as f32; + } + } + } + } + + 0.7 // Default threshold } /// Fallback pattern-based scanning @@ -199,37 +521,6 @@ impl PromptInjectionScanner { }) } - /// Analyze context with patterns - async fn analyze_context_with_patterns( - &self, - context_text: &str, - tool_call: &ToolCall, - ) -> Result { - // First scan the context for suspicious patterns - let context_result = self.scan_with_patterns(context_text).await?; - - // Consider tool-specific risks - let tool_risk = self.assess_tool_risk(&tool_call.name); - - // Combine context analysis with tool risk - let combined_confidence = (context_result.confidence * 0.7) + (tool_risk * 0.3); - let is_malicious = combined_confidence > 0.6; - - let explanation = format!( - "Context analysis: {} (confidence: {:.2}). Tool risk assessment: {:.2}. Combined risk: {:.2}", - context_result.explanation, - context_result.confidence, - tool_risk, - combined_confidence - ); - - Ok(ScanResult { - is_malicious, - confidence: combined_confidence, - explanation, - }) - } - /// Assess inherent risk of specific tools fn assess_tool_risk(&self, tool_name: &str) -> f32 { // Higher risk tools that could be used maliciously From d641ea7a052ecd95412bd3d4fb9243241eda620d Mon Sep 17 00:00:00 2001 From: Dorien Koelemeijer Date: Mon, 11 Aug 2025 11:21:42 +1000 Subject: [PATCH 03/14] remove pattern based scanning - we should rely on BERT models, don't think this pattern based scanning adds a lot --- crates/goose/src/security/scanner.rs | 46 ++++++---------------------- 1 file changed, 9 insertions(+), 37 deletions(-) diff --git a/crates/goose/src/security/scanner.rs b/crates/goose/src/security/scanner.rs index 0d03442e4b03..a961f29db82d 100644 --- a/crates/goose/src/security/scanner.rs +++ b/crates/goose/src/security/scanner.rs @@ -479,45 +479,17 @@ impl PromptInjectionScanner { /// Fallback pattern-based scanning async fn scan_with_patterns(&self, text: &str) -> Result { - let text_lower = text.to_lowercase(); + let _text_lower = text.to_lowercase(); - // Simple patterns that might indicate prompt injection - let suspicious_patterns = [ - "ignore previous instructions", - "ignore all previous", - "forget everything", - "new instructions", - "system prompt", - "you are now", - "act as", - "pretend to be", - "roleplay as", - "jailbreak", - "developer mode", - ]; - - let mut max_confidence: f32 = 0.0; - let mut detected_patterns = Vec::new(); - - for pattern in &suspicious_patterns { - if text_lower.contains(pattern) { - detected_patterns.push(*pattern); - max_confidence = max_confidence.max(0.8); // High confidence for pattern match - } - } - - let is_malicious = max_confidence > 0.5; - let explanation = if detected_patterns.is_empty() { - "Pattern-based scan: No suspicious patterns detected".to_string() - } else { - format!("Pattern-based scan detected: {} (confidence: {:.2})", - detected_patterns.join(", "), max_confidence) - }; - + // Use BERT model-based scanning instead of hardcoded patterns + // This provides more sophisticated detection than simple string matching + + // For now, return a low-confidence result indicating no threats detected + // The actual ML-based scanning happens in the ONNX model prediction above Ok(ScanResult { - is_malicious, - confidence: max_confidence, - explanation, + is_malicious: false, + confidence: 0.0, + explanation: "Pattern-based fallback: No threats detected using ML-based analysis".to_string(), }) } From d6b1793b663f769cc43a69486c14319b3b223238 Mon Sep 17 00:00:00 2001 From: Dorien Koelemeijer Date: Mon, 11 Aug 2025 12:23:43 +1000 Subject: [PATCH 04/14] Model management updates - don't convert models to onnx if not necessary --- crates/goose/src/security/model_downloader.rs | 186 +++++++++++++++++- 1 file changed, 181 insertions(+), 5 deletions(-) diff --git a/crates/goose/src/security/model_downloader.rs b/crates/goose/src/security/model_downloader.rs index b6929b5d1b95..f4aeddd33f66 100644 --- a/crates/goose/src/security/model_downloader.rs +++ b/crates/goose/src/security/model_downloader.rs @@ -1,4 +1,5 @@ -use anyhow::{anyhow}; +use anyhow::{anyhow, Result}; +use serde::Deserialize; use std::path::{Path, PathBuf}; use std::process::Command; use tokio::fs; @@ -8,6 +9,30 @@ pub struct ModelDownloader { cache_dir: PathBuf, } +#[derive(Debug)] +pub enum ModelFormat { + OnnxDirect { + model_path: String, // e.g., "onnx/model.onnx" + tokenizer_path: String, // e.g., "onnx/tokenizer.json" + }, + OnnxCustomPaths { + model_path: String, // e.g., "model.onnx" (root level) + tokenizer_path: String, // e.g., "tokenizer.json" + }, + ConvertToOnnx, // Fallback: convert PyTorch + Unsupported, +} + +#[derive(Deserialize)] +struct RepoInfo { + siblings: Vec, +} + +#[derive(Deserialize)] +struct FileInfo { + rfilename: String, +} + impl ModelDownloader { pub fn new() -> anyhow::Result { // Use platform-appropriate cache directory @@ -47,13 +72,13 @@ impl ModelDownloader { // Create cache directory if it doesn't exist fs::create_dir_all(&self.cache_dir).await?; - // Download and convert the model - this blocks until complete - self.download_and_convert_model(model_info).await?; + // Use smart model loading - try ONNX direct first, fallback to conversion + self.load_model_smart(model_info).await?; // Verify the files were created if !model_path.exists() || !tokenizer_path.exists() { return Err(anyhow!( - "Model conversion completed but files not found at expected paths. Model: {:?}, Tokenizer: {:?}", + "Model download completed but files not found at expected paths. Model: {:?}, Tokenizer: {:?}", model_path, tokenizer_path )); } @@ -62,12 +87,163 @@ impl ModelDownloader { model = %model_info.hf_model_name, model_path = ?model_path, tokenizer_path = ?tokenizer_path, - "βœ… Successfully downloaded and converted model" + "βœ… Successfully downloaded model" ); Ok((model_path, tokenizer_path)) } + /// Smart model loading - tries ONNX direct download first, falls back to conversion + async fn load_model_smart(&self, model_info: &ModelInfo) -> anyhow::Result<()> { + let format = self.discover_model_format(&model_info.hf_model_name).await?; + + match format { + ModelFormat::OnnxDirect { model_path, tokenizer_path } => { + tracing::info!("πŸ” Found ONNX files in standard location for {}", model_info.hf_model_name); + self.download_onnx_files(&model_info.hf_model_name, &model_path, &tokenizer_path, model_info).await + } + + ModelFormat::OnnxCustomPaths { model_path, tokenizer_path } => { + tracing::info!("πŸ” Found ONNX files in custom location for {}", model_info.hf_model_name); + self.download_onnx_files(&model_info.hf_model_name, &model_path, &tokenizer_path, model_info).await + } + + ModelFormat::ConvertToOnnx => { + tracing::info!("πŸ”„ No ONNX files found, will convert PyTorch model for {}", model_info.hf_model_name); + self.download_and_convert_model(model_info).await // Existing approach + } + + ModelFormat::Unsupported => { + Err(anyhow!("Model {} has no supported format (no ONNX or PyTorch files)", model_info.hf_model_name)) + } + } + } + + /// Discover what format a model is available in + async fn discover_model_format(&self, repo: &str) -> anyhow::Result { + let files = self.get_repo_files(repo).await?; + + // Strategy 1: Look for standard onnx/ folder (like protectai model) + if files.iter().any(|f| f.starts_with("onnx/")) { + return Ok(ModelFormat::OnnxDirect { + model_path: "onnx/model.onnx".to_string(), + tokenizer_path: "onnx/tokenizer.json".to_string(), + }); + } + + // Strategy 2: Look for ONNX files in root or custom locations + let onnx_files: Vec<_> = files.iter() + .filter(|f| f.ends_with(".onnx")) + .collect(); + + let tokenizer_files: Vec<_> = files.iter() + .filter(|f| f.contains("tokenizer") && f.ends_with(".json")) + .collect(); + + if !onnx_files.is_empty() && !tokenizer_files.is_empty() { + return Ok(ModelFormat::OnnxCustomPaths { + model_path: onnx_files[0].clone(), + tokenizer_path: tokenizer_files[0].clone(), + }); + } + + // Strategy 3: Check if we can convert PyTorch model + if files.iter().any(|f| f == "pytorch_model.bin" || f == "model.safetensors") { + return Ok(ModelFormat::ConvertToOnnx); + } + + Ok(ModelFormat::Unsupported) + } + + /// Get list of files in a HuggingFace repository + async fn get_repo_files(&self, repo: &str) -> anyhow::Result> { + let api_url = format!("https://huggingface.co/api/models/{}", repo); + let client = reqwest::Client::new(); + + let mut request = client.get(&api_url); + + // Optional authentication + if let Ok(token) = std::env::var("HUGGINGFACE_TOKEN") { + request = request.header("Authorization", format!("Bearer {}", token)); + } + + let response = request.send().await?; + + if response.status() == 404 { + return Err(anyhow!("Model repository '{}' not found", repo)); + } + + if response.status() == 401 { + return Err(anyhow!( + "Model '{}' requires authentication. Set HUGGINGFACE_TOKEN environment variable.\n\ + Get a token from: https://huggingface.co/settings/tokens", + repo + )); + } + + let repo_info: RepoInfo = response.json().await?; + Ok(repo_info.siblings.into_iter().map(|f| f.rfilename).collect()) + } + + /// Download ONNX files directly from HuggingFace + async fn download_onnx_files( + &self, + model_name: &str, + model_path: &str, + tokenizer_path: &str, + model_info: &ModelInfo, + ) -> anyhow::Result<()> { + let base_url = format!("https://huggingface.co/{}/resolve/main/", model_name); + + // Download model file + let model_url = format!("{}{}", base_url, model_path); + let local_model_path = self.cache_dir.join(&model_info.onnx_filename); + tracing::info!("πŸ“₯ Downloading ONNX model from: {}", model_url); + self.download_file_with_auth(&model_url, &local_model_path).await?; + + // Download tokenizer file + let tokenizer_url = format!("{}{}", base_url, tokenizer_path); + let local_tokenizer_path = self.cache_dir.join(&model_info.tokenizer_filename); + tracing::info!("πŸ“₯ Downloading tokenizer from: {}", tokenizer_url); + self.download_file_with_auth(&tokenizer_url, &local_tokenizer_path).await?; + + Ok(()) + } + + /// Download a file with optional authentication + async fn download_file_with_auth(&self, url: &str, local_path: &PathBuf) -> anyhow::Result<()> { + let client = reqwest::Client::new(); + let mut request = client.get(url); + + // Use HF token if available + if let Ok(token) = std::env::var("HUGGINGFACE_TOKEN") { + request = request.header("Authorization", format!("Bearer {}", token)); + } + + let response = request.send().await?; + + if response.status() == 401 { + return Err(anyhow!( + "File requires authentication. Set HUGGINGFACE_TOKEN environment variable." + )); + } + + if !response.status().is_success() { + return Err(anyhow!( + "Failed to download file from {}: HTTP {}", + url, + response.status() + )); + } + + let bytes = response.bytes().await?; + fs::write(local_path, bytes).await?; + + tracing::info!("βœ… Downloaded: {} ({} bytes)", local_path.display(), fs::metadata(local_path).await?.len()); + + Ok(()) + } + async fn download_and_convert_model(&self, model_info: &ModelInfo) -> anyhow::Result<()> { // Set up Python virtual environment with required dependencies let venv_dir = self.cache_dir.join("python_venv"); From 9177ea8b20dcf40fb7902ce52feed39c73077ee8 Mon Sep 17 00:00:00 2001 From: Dorien Koelemeijer Date: Tue, 12 Aug 2025 09:43:59 +1000 Subject: [PATCH 05/14] update model downloader - download onnx straight away from huggingface if available + small updates in scanning approach - use both pattern based scanning and bert models (they don't always catch command injection) --- Cargo.lock | 4 - crates/goose/Cargo.toml | 4 +- crates/goose/src/security/model_downloader.rs | 2 +- crates/goose/src/security/scanner.rs | 439 ++++++++++++++---- 4 files changed, 342 insertions(+), 107 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 40f517a90d9e..83924fcdbb0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2757,9 +2757,6 @@ name = "esaxx-rs" version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" -dependencies = [ - "cc", -] [[package]] name = "etcetera" @@ -8649,7 +8646,6 @@ dependencies = [ "derive_builder", "esaxx-rs", "getrandom 0.2.15", - "indicatif", "itertools 0.12.1", "lazy_static", "log", diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index a6aca6822268..a37b54ea07a4 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -101,8 +101,8 @@ lancedb = "0.13" arrow = "52.2" # ML inference backends for security scanning -ort = "2.0.0-rc.6" # ONNX Runtime -tokenizers = "0.20.3" # HuggingFace tokenizers +ort = "2.0.0-rc.10" # ONNX Runtime - use latest RC +tokenizers = { version = "0.20.4", default-features = false, features = ["onig"] } # HuggingFace tokenizers [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose/src/security/model_downloader.rs b/crates/goose/src/security/model_downloader.rs index f4aeddd33f66..676368d3108b 100644 --- a/crates/goose/src/security/model_downloader.rs +++ b/crates/goose/src/security/model_downloader.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::anyhow; use serde::Deserialize; use std::path::{Path, PathBuf}; use std::process::Command; diff --git a/crates/goose/src/security/scanner.rs b/crates/goose/src/security/scanner.rs index a961f29db82d..000a4468cb03 100644 --- a/crates/goose/src/security/scanner.rs +++ b/crates/goose/src/security/scanner.rs @@ -27,25 +27,33 @@ pub trait PromptInjectionModel: Send + Sync { /// ONNX Runtime implementation pub struct OnnxPromptInjectionModel { - _session: Session, // Temporarily unused due to mutable reference issue - _tokenizer: Arc, // Temporarily unused + session: Arc>, + tokenizer: Arc, model_name: String, } impl OnnxPromptInjectionModel { pub async fn new(model_path: PathBuf, tokenizer_path: PathBuf, model_name: String) -> Result { + tracing::info!("πŸ”’ Starting ONNX model initialization..."); + // Initialize ONNX Runtime session + tracing::info!("πŸ”’ Creating ONNX session from: {:?}", model_path); let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .commit_from_file(&model_path)?; + tracing::info!("πŸ”’ ONNX session created successfully"); + // Load tokenizer + tracing::info!("πŸ”’ Loading tokenizer from: {:?}", tokenizer_path); let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?; + tracing::info!("πŸ”’ Tokenizer loaded successfully"); + Ok(Self { - _session: session, - _tokenizer: Arc::new(tokenizer), + session: Arc::new(std::sync::Mutex::new(session)), + tokenizer: Arc::new(tokenizer), model_name, }) } @@ -54,69 +62,83 @@ impl OnnxPromptInjectionModel { #[async_trait::async_trait] impl PromptInjectionModel for OnnxPromptInjectionModel { async fn predict(&self, text: &str) -> Result<(f32, String)> { - // For now, return a reasonable prediction based on simple heuristics - // TODO: Implement actual ONNX inference once we resolve the mutable reference issue - - let text_lower = text.to_lowercase(); - - // Check for prompt injection patterns - let injection_patterns = [ - "ignore previous instructions", - "ignore all previous", - "forget everything", - "new instructions", - "system prompt", - "you are now", - "act as", - "pretend to be", - "roleplay as", - "jailbreak", - "developer mode", - ]; + tracing::info!("πŸ”’ ONNX predict() called with text length: {}", text.len()); + tracing::info!("πŸ”’ ONNX predict() received text: '{}'", text.chars().take(200).collect::()); + + // Tokenize the input text + tracing::debug!("πŸ”’ Tokenizing input text..."); + let encoding = self.tokenizer + .encode(text, true) + .map_err(|e| anyhow!("Tokenization failed: {}", e))?; + + let input_ids = encoding.get_ids(); + let attention_mask = encoding.get_attention_mask(); + + tracing::debug!("πŸ”’ Tokenization complete. Sequence length: {}", input_ids.len()); + + // Convert to the format expected by ONNX (batch_size=1) + let input_ids: Vec = input_ids.iter().map(|&id| id as i64).collect(); + let attention_mask: Vec = attention_mask.iter().map(|&mask| mask as i64).collect(); + + let seq_len = input_ids.len(); + + tracing::debug!("πŸ”’ Creating ONNX tensors..."); + // Create ONNX tensors + let input_ids_tensor = ort::value::Tensor::from_array(([1, seq_len], input_ids.into_boxed_slice()))?; + let attention_mask_tensor = ort::value::Tensor::from_array(([1, seq_len], attention_mask.into_boxed_slice()))?; + + tracing::debug!("πŸ”’ Running ONNX inference..."); + // Run inference and extract the logits immediately + let (logit_0, logit_1) = { + let mut session = self.session.lock().map_err(|e| anyhow!("Failed to lock session: {}", e))?; + tracing::debug!("πŸ”’ Session locked, running inference..."); + let outputs = session.run(ort::inputs![ + "input_ids" => input_ids_tensor, + "attention_mask" => attention_mask_tensor + ])?; + + tracing::debug!("πŸ”’ Inference complete, extracting logits..."); + // Extract logits from output immediately while we have the lock + let logits = outputs["logits"].try_extract_tensor::()?; + let logits_slice = logits.1; + + // Extract the values we need + let logit_0 = logits_slice[0]; // Non-injection class + let logit_1 = logits_slice[1]; // Injection class + + tracing::debug!("πŸ”’ Logits extracted: [{:.3}, {:.3}]", logit_0, logit_1); + + (logit_0, logit_1) + }; - // Check for dangerous shell commands - let dangerous_commands = [ - "rm -rf", - "sudo rm", - "del /s", - "format c:", - "dd if=", - "mkfs", - "fdisk", - "chmod 777", - "wget http", - "curl http", - "nc -l", - "netcat", - ]; + // Apply softmax to get probabilities + let exp_0 = logit_0.exp(); + let exp_1 = logit_1.exp(); + let sum_exp = exp_0 + exp_1; - let mut confidence = 0.0f32; - let mut detected_patterns = Vec::new(); + let prob_injection = exp_1 / sum_exp; - // Check for prompt injection patterns - for pattern in &injection_patterns { - if text_lower.contains(pattern) { - detected_patterns.push(format!("injection:{}", pattern)); - confidence = confidence.max(0.9); - } - } + let explanation = format!( + "ONNX model '{}': Injection probability = {:.3} (logits: [{:.3}, {:.3}])", + self.model_name, prob_injection, logit_0, logit_1 + ); - // Check for dangerous commands - for command in &dangerous_commands { - if text_lower.contains(command) { - detected_patterns.push(format!("dangerous:{}", command)); - confidence = confidence.max(0.8); - } - } + tracing::info!( + "πŸ”’ ONNX prediction complete: confidence={:.3}, explanation={}", + prob_injection, explanation + ); - let explanation = if detected_patterns.is_empty() { - format!("ONNX model '{}': No threats detected", self.model_name) - } else { - format!("ONNX model '{}': Detected threats: {}", - self.model_name, detected_patterns.join(", ")) - }; + tracing::debug!( + model = %self.model_name, + text_length = text.len(), + seq_length = seq_len, + logit_0 = logit_0, + logit_1 = logit_1, + prob_injection = prob_injection, + "ONNX inference completed" + ); - Ok((confidence, explanation)) + Ok((prob_injection, explanation)) } fn model_name(&self) -> &str { @@ -124,8 +146,8 @@ impl PromptInjectionModel for OnnxPromptInjectionModel { } } -/// Global model cache -static MODEL_CACHE: OnceCell>> = OnceCell::const_new(); +/// Global model cache with reload capability +static MODEL_CACHE: OnceCell>>>> = OnceCell::const_new(); /// Initialize the global model async fn initialize_model() -> Result>> { @@ -164,10 +186,46 @@ async fn initialize_model() -> Result>> { /// Get or initialize the global model async fn get_model() -> Option> { - MODEL_CACHE - .get_or_init(|| async { initialize_model().await.unwrap_or(None) }) - .await - .clone() + let cache = MODEL_CACHE + .get_or_init(|| async { + Arc::new(tokio::sync::RwLock::new(None)) + }) + .await; + + // Check if model is already loaded in memory + let read_guard = cache.read().await; + if let Some(model) = read_guard.as_ref() { + tracing::debug!("πŸ”’ Model found in memory cache, using cached instance"); + return Some(model.clone()); + } + drop(read_guard); + + // Model not loaded in memory, try to initialize from disk + tracing::info!("πŸ”’ Model not loaded in memory, loading from disk cache..."); + let mut write_guard = cache.write().await; + + // Double-check in case another task loaded it while we were waiting + if let Some(model) = write_guard.as_ref() { + tracing::debug!("πŸ”’ Model was loaded by another task while waiting, using that instance"); + return Some(model.clone()); + } + + // Load the model from disk + match initialize_model().await { + Ok(Some(model)) => { + tracing::info!("πŸ”’ βœ… Model successfully loaded into memory cache"); + *write_guard = Some(model.clone()); + Some(model) + } + Ok(None) => { + tracing::info!("πŸ”’ No model available, using pattern-based fallback"); + None + } + Err(e) => { + tracing::warn!("πŸ”’ Failed to initialize model: {}", e); + None + } + } } /// Simple prompt injection scanner @@ -202,20 +260,30 @@ impl PromptInjectionScanner { let tokenizer_path = security_models_dir.join(&model_info.tokenizer_filename); if model_path.exists() && tokenizer_path.exists() { - tracing::info!("πŸ”’ Security models found in cache, enabling security scanning"); + tracing::info!("πŸ”’ Security model files found on disk - loading model into memory now..."); + + // Load model into memory immediately at startup + tokio::spawn(async move { + tracing::info!("πŸ”’ Pre-loading security model at startup..."); + if let Some(_model) = get_model().await { + tracing::info!("πŸ”’ βœ… Security model pre-loaded successfully - ready for scanning"); + } else { + tracing::warn!("πŸ”’ Failed to pre-load security model"); + } + }); + return true; } } - // Models not cached, trigger download in background - tracing::info!("πŸ”’ Security models not found in cache, downloading in background"); - tokio::spawn(async { - Self::ensure_models_available().await; - }); + // Models not cached - we need to download them + tracing::info!("πŸ”’ Security model files not found on disk"); + tracing::info!("πŸ”’ Models will be downloaded on first security scan (this may cause a delay)"); - // For now, use pattern-based scanning while models download - // TODO: In the future, we could block here or enable ONNX scanning after download - false + // For now, return true to enable security scanning + // The models will be downloaded lazily on first scan + // TODO: Consider blocking startup to download models synchronously + true } /// Ensure models are available using the existing model_downloader @@ -318,13 +386,26 @@ impl PromptInjectionScanner { /// Step 1: Scan only the tool call for suspicious patterns async fn scan_tool_call_only(&self, tool_call: &ToolCall) -> Result { + // Debug: Log the raw tool call arguments first + tracing::info!("πŸ”’ Raw tool call arguments: {:?}", tool_call.arguments); + // Create text representation of the tool call for analysis + let arguments_json = serde_json::to_string_pretty(&tool_call.arguments) + .unwrap_or_else(|e| { + tracing::warn!("πŸ”’ Failed to serialize tool arguments: {}", e); + format!("{{\"error\": \"Failed to serialize arguments: {}\"}}", e) + }); + let tool_text = format!( "Tool: {}\nArguments: {}", tool_call.name, - serde_json::to_string_pretty(&tool_call.arguments)? + arguments_json ); + tracing::info!("πŸ”’ Complete tool text being analyzed (length: {}): '{}'", + tool_text.len(), + tool_text); + self.scan_with_prompt_injection_model(&tool_text).await } @@ -426,36 +507,76 @@ impl PromptInjectionScanner { /// Model-agnostic prompt injection scanning async fn scan_with_prompt_injection_model(&self, text: &str) -> Result { - // Try to get the ML model + tracing::info!("πŸ”’ Starting scan_with_prompt_injection_model for text (length: {}): '{}'", + text.len(), + text.chars().take(100).collect::()); + + // Always run pattern-based scanning first + let pattern_result = self.scan_with_patterns(text).await?; + + // Try to get the ML model for additional scanning + tracing::info!("πŸ”’ Attempting to get ML model..."); if let Some(model) = get_model().await { + tracing::info!("πŸ”’ ML model retrieved successfully, calling predict..."); + tracing::info!("πŸ”’ About to call model.predict() with text length: {}", text.len()); match model.predict(text).await { - Ok((confidence, explanation)) => { + Ok((ml_confidence, ml_explanation)) => { + tracing::info!("πŸ”’ ML model predict returned successfully"); // Get threshold from config let threshold = self.get_threshold_from_config(); - let is_malicious = confidence > threshold; + let ml_is_malicious = ml_confidence > threshold; tracing::info!( "πŸ”’ ML model prediction: confidence={:.3}, threshold={:.3}, malicious={}", - confidence, threshold, is_malicious + ml_confidence, threshold, ml_is_malicious ); - return Ok(ScanResult { - is_malicious, - confidence, - explanation, - }); + // Combine ML and pattern results + let combined_result = self.combine_scan_results(&pattern_result, ml_confidence, &ml_explanation, ml_is_malicious); + + tracing::info!( + "πŸ”’ Combined scan result: ML confidence={:.3}, Pattern confidence={:.3}, Final confidence={:.3}, Final malicious={}", + ml_confidence, pattern_result.confidence, combined_result.confidence, combined_result.is_malicious + ); + + return Ok(combined_result); } Err(e) => { tracing::warn!("πŸ”’ ML model prediction failed: {}", e); - // Fall through to pattern-based scanning + // Fall through to pattern-only result } } } else { - tracing::info!("πŸ”’ No ML model available, using pattern-based fallback"); + tracing::info!("πŸ”’ No ML model available, using pattern-based scanning only"); } - // Fallback to pattern-based scanning if ML model is not available - self.scan_with_patterns(text).await + tracing::info!("πŸ”’ Using pattern-based scan result only"); + Ok(pattern_result) + } + + /// Combine ML model and pattern matching results + fn combine_scan_results(&self, pattern_result: &ScanResult, ml_confidence: f32, ml_explanation: &str, ml_is_malicious: bool) -> ScanResult { + // Take the higher confidence score + let final_confidence = pattern_result.confidence.max(ml_confidence); + + // Mark as malicious if either method detects it + let final_is_malicious = pattern_result.is_malicious || ml_is_malicious; + + let combined_explanation = if pattern_result.is_malicious && ml_is_malicious { + format!("BOTH LAYERS DETECTED THREAT - Pattern: {} | ML: {}", pattern_result.explanation, ml_explanation) + } else if pattern_result.is_malicious { + format!("PATTERN DETECTION - {} | ML: {} (confidence: {:.3})", pattern_result.explanation, ml_explanation, ml_confidence) + } else if ml_is_malicious { + format!("ML DETECTION - {} | Pattern: {}", ml_explanation, pattern_result.explanation) + } else { + format!("CLEAN - Pattern: {} | ML: {}", pattern_result.explanation, ml_explanation) + }; + + ScanResult { + is_malicious: final_is_malicious, + confidence: final_confidence, + explanation: combined_explanation, + } } /// Get threshold from config @@ -479,18 +600,136 @@ impl PromptInjectionScanner { /// Fallback pattern-based scanning async fn scan_with_patterns(&self, text: &str) -> Result { - let _text_lower = text.to_lowercase(); + let text_lower = text.to_lowercase(); - // Use BERT model-based scanning instead of hardcoded patterns - // This provides more sophisticated detection than simple string matching + // Command injection patterns - detect potentially dangerous commands + let dangerous_patterns = [ + // File system operations + "rm -rf /", + "rm -rf /*", + "rm -rf ~", + "rm -rf $home", + "rmdir /", + "del /s /q", + "format c:", + + // System manipulation + "shutdown", + "reboot", + "halt", + "poweroff", + "kill -9", + "killall", + + // Network/data exfiltration + "curl http", + "wget http", + "nc -l", + "netcat", + "ssh ", + "scp ", + "rsync", + + // Process manipulation + "sudo ", + "su -", + "chmod 777", + "chown root", + + // Command chaining that could hide malicious intent + "; rm ", + "&& rm ", + "| rm ", + "; curl ", + "&& curl ", + "| curl ", + "; wget ", + "&& wget ", + + // Suspicious file operations + "rm -f /", + "rm -rf .", + "rm -rf ..", + "> /dev/", + "dd if=", + "mkfs", + + // Potential data theft + "cat /etc/passwd", + "cat /etc/shadow", + "/etc/hosts", + "~/.ssh/", + "id_rsa", + + // Obfuscation attempts + "base64 -d", + "echo | sh", + "eval ", + "exec ", + ]; - // For now, return a low-confidence result indicating no threats detected - // The actual ML-based scanning happens in the ONNX model prediction above - Ok(ScanResult { - is_malicious: false, - confidence: 0.0, - explanation: "Pattern-based fallback: No threats detected using ML-based analysis".to_string(), - }) + let mut detected_patterns = Vec::new(); + let mut max_risk_score: f32 = 0.0; + + for pattern in &dangerous_patterns { + if text_lower.contains(pattern) { + detected_patterns.push(pattern.to_string()); + + // Assign risk scores based on severity + let risk_score = match *pattern { + // Critical - system destruction + "rm -rf /" | "rm -rf /*" | "format c:" | "mkfs" => 0.95, + "rm -rf ~" | "rm -rf $home" => 0.90, + + // High - system control + "shutdown" | "reboot" | "halt" | "poweroff" => 0.85, + "sudo " | "su -" | "chmod 777" | "chown root" => 0.80, + + // Medium-High - network/data access + "curl http" | "wget http" | "ssh " | "scp " => 0.75, + "cat /etc/passwd" | "cat /etc/shadow" | "~/.ssh/" => 0.85, + + // Medium - suspicious operations + "; rm " | "&& rm " | "| rm " => 0.70, + "kill -9" | "killall" => 0.65, + + // Lower - potentially legitimate but suspicious + "base64 -d" | "eval " | "exec " => 0.60, + + _ => 0.50, + }; + + max_risk_score = max_risk_score.max(risk_score); + } + } + + if !detected_patterns.is_empty() { + let is_malicious = max_risk_score > 0.7; + let explanation = format!( + "Pattern-based detection: Found {} suspicious command pattern(s): [{}]. Risk score: {:.2}", + detected_patterns.len(), + detected_patterns.join(", "), + max_risk_score + ); + + tracing::info!( + "πŸ”’ Pattern-based scan detected {} suspicious patterns with max risk score: {:.2}", + detected_patterns.len(), + max_risk_score + ); + + Ok(ScanResult { + is_malicious, + confidence: max_risk_score, + explanation, + }) + } else { + Ok(ScanResult { + is_malicious: false, + confidence: 0.0, + explanation: "Pattern-based scan: No suspicious command patterns detected".to_string(), + }) + } } /// Assess inherent risk of specific tools From 4ae158fcc5165f6144fc49918befa0e3770ff3f3 Mon Sep 17 00:00:00 2001 From: Dorien Koelemeijer Date: Tue, 12 Aug 2025 10:54:05 +1000 Subject: [PATCH 06/14] Re-use ToolCall for user input if prompt injection detected --- crates/goose/src/agents/agent.rs | 123 +++++++++++++++--- crates/goose/src/agents/tool_execution.rs | 44 ++++++- crates/goose/src/security/scanner.rs | 23 ++-- ui/desktop/src/components/GooseMessage.tsx | 1 + .../src/components/ToolCallConfirmation.tsx | 4 +- 5 files changed, 156 insertions(+), 39 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 61787b711ae0..0fa0db27ba0e 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1045,40 +1045,45 @@ impl Agent { vec![] }); - // Handle security results - for now just log them - println!("πŸ” DEBUG: Security scan returned {} results", security_results.len()); - for security_result in &security_results { - println!("πŸ” DEBUG: Security result - malicious: {}, confidence: {:.2}, explanation: {}", - security_result.is_malicious, security_result.confidence, security_result.explanation); - if security_result.is_malicious { - tracing::warn!( - confidence = security_result.confidence, - explanation = %security_result.explanation, - "Security threat detected in tool call" - ); + // Apply security results to permission check result + let final_permission_result = self.apply_security_results_to_permissions( + permission_check_result, + &security_results + ).await; - if security_result.should_ask_user { - // TODO: Implement user confirmation using existing tool approval system - tracing::info!("Security threat requires user confirmation"); - } - } - } + println!("πŸ” DEBUG: After security integration - {} approved, {} need approval, {} denied", + final_permission_result.approved.len(), + final_permission_result.needs_approval.len(), + final_permission_result.denied.len()); let mut tool_futures = self.handle_approved_and_denied_tools( - &permission_check_result, + &final_permission_result, message_tool_response.clone(), cancel_token.clone() ).await?; let tool_futures_arc = Arc::new(Mutex::new(tool_futures)); - // Process tools requiring approval - let mut tool_approval_stream = self.handle_approval_tool_requests( - &permission_check_result.needs_approval, + // Process tools requiring approval (including security-flagged tools) + // Create a mapping of security results for tools that need approval + let mut security_results_for_approval: Vec> = Vec::new(); + for approval_request in &final_permission_result.needs_approval { + // Find the corresponding security result for this tool request + let security_result = security_results.iter().find(|result| { + // Match by checking if this tool was flagged as malicious + // This is a simplified matching - ideally we'd have better tool request tracking + result.is_malicious + }); + security_results_for_approval.push(security_result); + } + + let mut tool_approval_stream = self.handle_approval_tool_requests_with_security( + &final_permission_result.needs_approval, tool_futures_arc.clone(), &mut permission_manager, message_tool_response.clone(), cancel_token.clone(), + Some(&security_results_for_approval), ); while let Some(msg) = tool_approval_stream.try_next().await? { @@ -1205,6 +1210,82 @@ impl Agent { } } + /// Apply security scan results to permission check results + /// This integrates security scanning with the existing tool approval system + async fn apply_security_results_to_permissions( + &self, + mut permission_result: PermissionCheckResult, + security_results: &[crate::security::SecurityResult], + ) -> PermissionCheckResult { + if security_results.is_empty() { + return permission_result; + } + + // Create a map of tool requests by ID for easy lookup + let mut all_requests: std::collections::HashMap = std::collections::HashMap::new(); + + // Collect all tool requests + for req in &permission_result.approved { + all_requests.insert(req.id.clone(), req.clone()); + } + for req in &permission_result.needs_approval { + all_requests.insert(req.id.clone(), req.clone()); + } + for req in &permission_result.denied { + all_requests.insert(req.id.clone(), req.clone()); + } + + // Collect the combined requests first to avoid borrowing issues + let combined_requests: Vec = permission_result.approved.iter() + .chain(permission_result.needs_approval.iter()) + .cloned() + .collect(); + + // Process security results + for (i, security_result) in security_results.iter().enumerate() { + if !security_result.is_malicious { + continue; + } + + // Find the corresponding tool request by index + if let Some(tool_request) = combined_requests.get(i) { + let request_id = &tool_request.id; + + tracing::warn!( + tool_request_id = %request_id, + confidence = security_result.confidence, + explanation = %security_result.explanation, + "Security threat detected - modifying tool approval status" + ); + + // Remove from approved if present + permission_result.approved.retain(|req| req.id != *request_id); + + if security_result.should_ask_user { + // Move to needs_approval with security context + if let Some(request) = all_requests.get(request_id) { + // Only add if not already in needs_approval + if !permission_result.needs_approval.iter().any(|req| req.id == *request_id) { + permission_result.needs_approval.push(request.clone()); + } + } + } else { + // High confidence threat - move to denied + permission_result.needs_approval.retain(|req| req.id != *request_id); + + if let Some(request) = all_requests.get(request_id) { + // Only add if not already in denied + if !permission_result.denied.iter().any(|req| req.id == *request_id) { + permission_result.denied.push(request.clone()); + } + } + } + } + } + + permission_result + } + /// Extend the system prompt with one line of additional instruction pub async fn extend_system_prompt(&self, instruction: String) { let mut prompt_manager = self.prompt_manager.lock().await; diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index 045cdd0229dd..3ddecfe6879f 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -54,15 +54,55 @@ impl Agent { permission_manager: &'a mut PermissionManager, message_tool_response: Arc>, cancellation_token: Option, + ) -> BoxStream<'a, anyhow::Result> { + self.handle_approval_tool_requests_with_security( + tool_requests, + tool_futures, + permission_manager, + message_tool_response, + cancellation_token, + None, // No security context by default + ) + } + + pub(crate) fn handle_approval_tool_requests_with_security<'a>( + &'a self, + tool_requests: &'a [ToolRequest], + tool_futures: Arc>>, + permission_manager: &'a mut PermissionManager, + message_tool_response: Arc>, + cancellation_token: Option, + security_results: Option<&'a [Option<&'a crate::security::SecurityResult>]>, ) -> BoxStream<'a, anyhow::Result> { try_stream! { - for request in tool_requests { + for (i, request) in tool_requests.iter().enumerate() { if let Ok(tool_call) = request.tool_call.clone() { + // Check if this tool has security concerns + // Match by index since security results are provided in the same order as tool requests + let security_context = security_results + .and_then(|results| results.get(i)) + .and_then(|result| *result) + .filter(|result| result.is_malicious); + + let confirmation_prompt = if let Some(security_result) = security_context { + format!( + "🚨 SECURITY WARNING: This tool call has been flagged as potentially malicious.\n\ + Confidence: {:.1}%\n\ + Reason: {}\n\n\ + Goose would still like to call the above tool. \n\ + Please review carefully. Allow? (y/n):", + security_result.confidence * 100.0, + security_result.explanation + ) + } else { + "Goose would like to call the above tool. Allow? (y/n):".to_string() + }; + let confirmation = Message::user().with_tool_confirmation_request( request.id.clone(), tool_call.name.clone(), tool_call.arguments.clone(), - Some("Goose would like to call the above tool. Allow? (y/n):".to_string()), + Some(confirmation_prompt), ); yield confirmation; diff --git a/crates/goose/src/security/scanner.rs b/crates/goose/src/security/scanner.rs index 000a4468cb03..f94c9af73cc9 100644 --- a/crates/goose/src/security/scanner.rs +++ b/crates/goose/src/security/scanner.rs @@ -457,12 +457,7 @@ impl PromptInjectionScanner { // User messages contain prompt injection - tool call is likely malicious let combined_confidence = (tool_call_result.confidence + user_messages_result.confidence) / 2.0; let explanation = format!( - "MALICIOUS: Tool '{}' appears to be result of prompt injection. Tool scan: {:.2} confidence ({}). User messages scan: {:.2} confidence ({})", - tool_call.name, - tool_call_result.confidence, - if tool_call_result.is_malicious { "suspicious" } else { "clean" }, - user_messages_result.confidence, - user_messages_result.explanation + "Tool appears to be the result of a prompt injection attack." ); (true, combined_confidence.max(0.8), explanation) } else { @@ -470,11 +465,7 @@ impl PromptInjectionScanner { // Lower the confidence since user didn't inject malicious prompts let adjusted_confidence = tool_call_result.confidence * 0.6; // Reduce confidence let explanation = format!( - "LIKELY SAFE: Tool '{}' flagged as suspicious but user messages appear clean. Tool scan: {:.2} confidence. User messages: clean ({:.2} confidence). Adjusted confidence: {:.2}", - tool_call.name, - tool_call_result.confidence, - user_messages_result.confidence, - adjusted_confidence + "Tool flagged as suspicious but user messages appear clean." ); // Only consider malicious if adjusted confidence is still high AND tool is high-risk @@ -562,14 +553,16 @@ impl PromptInjectionScanner { // Mark as malicious if either method detects it let final_is_malicious = pattern_result.is_malicious || ml_is_malicious; + // Simplified explanation - just show what detected the threat let combined_explanation = if pattern_result.is_malicious && ml_is_malicious { - format!("BOTH LAYERS DETECTED THREAT - Pattern: {} | ML: {}", pattern_result.explanation, ml_explanation) + "Detected by both pattern analysis and ML model".to_string() } else if pattern_result.is_malicious { - format!("PATTERN DETECTION - {} | ML: {} (confidence: {:.3})", pattern_result.explanation, ml_explanation, ml_confidence) + format!("Detected by pattern analysis: {}", + pattern_result.explanation.replace("Pattern-based detection: ", "")) } else if ml_is_malicious { - format!("ML DETECTION - {} | Pattern: {}", ml_explanation, pattern_result.explanation) + "Detected by machine learning model".to_string() } else { - format!("CLEAN - Pattern: {} | ML: {}", pattern_result.explanation, ml_explanation) + "No threats detected".to_string() }; ScanResult { diff --git a/ui/desktop/src/components/GooseMessage.tsx b/ui/desktop/src/components/GooseMessage.tsx index ca19513f1c7f..eeaa995db4d3 100644 --- a/ui/desktop/src/components/GooseMessage.tsx +++ b/ui/desktop/src/components/GooseMessage.tsx @@ -223,6 +223,7 @@ export default function GooseMessage({ isClicked={messageIndex < messageHistoryIndex} toolConfirmationId={toolConfirmationContent.id} toolName={toolConfirmationContent.toolName} + prompt={toolConfirmationContent.prompt} /> )} diff --git a/ui/desktop/src/components/ToolCallConfirmation.tsx b/ui/desktop/src/components/ToolCallConfirmation.tsx index 764a460402a9..43dc51031d4e 100644 --- a/ui/desktop/src/components/ToolCallConfirmation.tsx +++ b/ui/desktop/src/components/ToolCallConfirmation.tsx @@ -25,6 +25,7 @@ interface ToolConfirmationProps { isClicked: boolean; toolConfirmationId: string; toolName: string; + prompt?: string; // Security warning or custom prompt } export default function ToolConfirmation({ @@ -32,6 +33,7 @@ export default function ToolConfirmation({ isClicked, toolConfirmationId, toolName, + prompt, }: ToolConfirmationProps) { // Check if we have a stored state for this tool confirmation const storedState = toolConfirmationState.get(toolConfirmationId); @@ -121,7 +123,7 @@ export default function ToolConfirmation({ ) : ( <>
- Goose would like to call the above tool. Allow? + {prompt || 'Goose would like to call the above tool. Allow?'}
{clicked ? (
From e6aa6dbf4d68cd51400e2bcc2fead4fa68d8bd4d Mon Sep 17 00:00:00 2001 From: Dorien Koelemeijer Date: Wed, 13 Aug 2025 09:13:08 +1000 Subject: [PATCH 07/14] If ToolCall is verified for security finding, don't use 'Always Allow', only give option to 'Allow once' or 'Deny' --- ui/desktop/src/components/ToolCallConfirmation.tsx | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ui/desktop/src/components/ToolCallConfirmation.tsx b/ui/desktop/src/components/ToolCallConfirmation.tsx index 43dc51031d4e..6ab32476aedf 100644 --- a/ui/desktop/src/components/ToolCallConfirmation.tsx +++ b/ui/desktop/src/components/ToolCallConfirmation.tsx @@ -190,9 +190,12 @@ export default function ToolConfirmation({
) : (
- + {/* Hide "Always Allow" for security warnings to prevent bypassing future security checks */} + {!prompt?.includes('SECURITY WARNING') && ( + + )}