diff --git a/crates/goose/src/security/scanner.rs b/crates/goose/src/security/scanner.rs index d0bc3d9ff8d4..2cb7612b23c5 100644 --- a/crates/goose/src/security/scanner.rs +++ b/crates/goose/src/security/scanner.rs @@ -153,18 +153,29 @@ impl PromptInjectionScanner { threshold ); - let final_result = - self.select_result_with_context_awareness(tool_result, context_result, threshold); + let final_confidence = + self.combine_confidences(tool_result.confidence, context_result.confidence); tracing::info!( - "Security analysis complete: final_confidence={:.3}, malicious={}", - final_result.confidence, - final_result.confidence >= threshold + tool_confidence = %tool_result.confidence, + context_confidence = %context_result.confidence, + final_confidence = %final_confidence, + has_ml = tool_result.ml_confidence.is_some(), + has_patterns = !tool_result.pattern_matches.is_empty(), + threshold = %threshold, + malicious = final_confidence >= threshold, + "Security analysis complete" ); + let final_result = DetailedScanResult { + confidence: final_confidence, + pattern_matches: tool_result.pattern_matches, + ml_confidence: tool_result.ml_confidence, + }; + Ok(ScanResult { - is_malicious: final_result.confidence >= threshold, - confidence: final_result.confidence, + is_malicious: final_confidence >= threshold, + confidence: final_confidence, explanation: self.build_explanation(&final_result, threshold, &tool_content), }) } @@ -228,33 +239,23 @@ impl PromptInjectionScanner { }) } - fn select_result_with_context_awareness( - &self, - tool_result: DetailedScanResult, - context_result: DetailedScanResult, - threshold: f32, - ) -> DetailedScanResult { - let context_is_safe = context_result - .ml_confidence - .is_some_and(|conf| conf < threshold); - - let tool_has_only_non_critical = !tool_result.pattern_matches.is_empty() - && tool_result - .pattern_matches - .iter() - .all(|m| m.threat.risk_level != crate::security::patterns::RiskLevel::Critical); - - if context_is_safe && tool_has_only_non_critical { - DetailedScanResult { - confidence: 0.0, - pattern_matches: Vec::new(), - ml_confidence: context_result.ml_confidence, - } - } else if tool_result.confidence >= context_result.confidence { - tool_result - } else { - context_result + fn combine_confidences(&self, tool_confidence: f32, context_confidence: f32) -> f32 { + // If tool is safe, context is not taken into account + if tool_confidence < 0.3 { + return tool_confidence; + } + + if context_confidence < 0.3 { + return tool_confidence * 0.9; } + + if tool_confidence > 0.8 && context_confidence > 0.8 { + let max_conf = tool_confidence.max(context_confidence); + return (max_conf * 1.05).min(1.0); + } + + // Default: weighted average (tool is primary signal) + tool_confidence * 0.8 + context_confidence * 0.2 } async fn scan_with_classifier(