-
Notifications
You must be signed in to change notification settings - Fork 2.8k
feat: Prompt injection detection #4021
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cc30637
fe0c392
d641ea7
d6b1793
9177ea8
4ae158f
e6aa6db
15bc55f
af39972
b09d04e
d7dfee9
9b28c60
2847b80
8d9c190
aeee29d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,7 @@ use crate::agents::todo_tools::{ | |
| todo_read_tool, todo_write_tool, TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME, | ||
| }; | ||
| use crate::conversation::message::{Message, ToolRequest}; | ||
| use crate::security::SecurityManager; | ||
|
|
||
| const DEFAULT_MAX_TURNS: u32 = 1000; | ||
|
|
||
|
|
@@ -102,6 +103,7 @@ pub struct Agent { | |
| pub(super) tool_route_manager: ToolRouteManager, | ||
| pub(super) scheduler_service: Mutex<Option<Arc<dyn SchedulerTrait>>>, | ||
| pub(super) retry_manager: RetryManager, | ||
| pub(super) security_manager: SecurityManager, | ||
| pub(super) todo_list: Arc<Mutex<String>>, | ||
| } | ||
|
|
||
|
|
@@ -188,6 +190,7 @@ impl Agent { | |
| tool_route_manager: ToolRouteManager::new(), | ||
| scheduler_service: Mutex::new(None), | ||
| retry_manager, | ||
| security_manager: SecurityManager::new(), | ||
| todo_list: Arc::new(Mutex::new(String::new())), | ||
| } | ||
| } | ||
|
|
@@ -1078,6 +1081,29 @@ impl Agent { | |
| ); | ||
| } | ||
| } else { | ||
| // Check if we need to show model download status before security scanning | ||
| if let Some(download_message) = self.security_manager.check_model_download_status().await { | ||
| yield AgentEvent::Message(Message::assistant().with_text(download_message)); | ||
| } | ||
|
|
||
| // SECURITY FIX: Scan tools for prompt injection BEFORE permission checking | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is probably fine for testing, but we should refactor this so we have a generic way to decide whether to run a tool: each check answers yes, no, or ask the user (with a prompt), and we then take the least permissive of those answers. For example, if the counter says no, don't even run the security scan, if that makes sense. |
||
| // This ensures security results can override auto-mode approvals | ||
| let initial_permission_result = PermissionCheckResult { | ||
| approved: remaining_requests.clone(), | ||
| needs_approval: vec![], | ||
| denied: vec![], | ||
| }; | ||
|
|
||
| println!("🔍 DEBUG: About to call security manager with {} total tools", remaining_requests.len()); | ||
| let security_results = self.security_manager | ||
| .filter_malicious_tool_calls(messages.messages(), &initial_permission_result, Some(&system_prompt)) | ||
| .await | ||
| .unwrap_or_else(|e| { | ||
| tracing::warn!("Security scanning failed: {}", e); | ||
| vec![] | ||
| }); | ||
|
|
||
| // Now run permission checking with security context | ||
| let mut permission_manager = PermissionManager::default(); | ||
| let (permission_check_result, enable_extension_request_ids) = | ||
| check_tool_permissions( | ||
|
|
@@ -1089,21 +1115,45 @@ impl Agent { | |
| self.provider().await?, | ||
| ).await; | ||
|
|
||
| // Apply security results to override permission decisions | ||
| let final_permission_result = self.apply_security_results_to_permissions( | ||
| permission_check_result, | ||
| &security_results | ||
| ).await; | ||
|
|
||
| println!("🔍 DEBUG: After security integration - {} approved, {} need approval, {} denied", | ||
| final_permission_result.approved.len(), | ||
| final_permission_result.needs_approval.len(), | ||
| final_permission_result.denied.len()); | ||
|
|
||
| let mut tool_futures = self.handle_approved_and_denied_tools( | ||
| &permission_check_result, | ||
| &final_permission_result, | ||
| message_tool_response.clone(), | ||
| cancel_token.clone() | ||
| ).await?; | ||
|
|
||
| let tool_futures_arc = Arc::new(Mutex::new(tool_futures)); | ||
|
|
||
| // Process tools requiring approval | ||
| let mut tool_approval_stream = self.handle_approval_tool_requests( | ||
| &permission_check_result.needs_approval, | ||
| // Process tools requiring approval (including security-flagged tools) | ||
| // Create a mapping of security results for tools that need approval | ||
| let mut security_results_for_approval: Vec<Option<&crate::security::SecurityResult>> = Vec::new(); | ||
| for _approval_request in &final_permission_result.needs_approval { | ||
| // Find the corresponding security result for this tool request | ||
| let security_result = security_results.iter().find(|result| { | ||
| // Match by checking if this tool was flagged as malicious | ||
| // This is a simplified matching - ideally we'd have better tool request tracking | ||
| result.is_malicious | ||
| }); | ||
| security_results_for_approval.push(security_result); | ||
| } | ||
|
|
||
| let mut tool_approval_stream = self.handle_approval_tool_requests_with_security( | ||
| &final_permission_result.needs_approval, | ||
| tool_futures_arc.clone(), | ||
| &mut permission_manager, | ||
| message_tool_response.clone(), | ||
| cancel_token.clone(), | ||
| Some(&security_results_for_approval), | ||
| ); | ||
|
|
||
| while let Some(msg) = tool_approval_stream.try_next().await? { | ||
|
|
@@ -1230,6 +1280,98 @@ impl Agent { | |
| } | ||
| } | ||
|
|
||
| /// Apply security scan results to permission check results. | ||
| /// | ||
| /// Integrates security scanning with the existing tool approval system. Every | ||
| /// result with `is_malicious` set removes its tool request from `approved`; | ||
| /// the request is then queued in `needs_approval` when `should_ask_user` is | ||
| /// set, or otherwise dropped from `needs_approval` and appended to `denied`. | ||
| /// Returns `permission_result` unchanged when `security_results` is empty. | ||
| /// | ||
| /// NOTE(review): results are paired with requests purely by position in | ||
| /// `approved` followed by `needs_approval`. This assumes the scanner emits one | ||
| /// result per request in that same order - TODO confirm against the caller, or | ||
| /// carry the tool request id on `SecurityResult` so the pairing is explicit. | ||
| async fn apply_security_results_to_permissions( | ||
| &self, | ||
| mut permission_result: PermissionCheckResult, | ||
| security_results: &[crate::security::SecurityResult], | ||
| ) -> PermissionCheckResult { | ||
| if security_results.is_empty() { | ||
| return permission_result; | ||
| } | ||
|
|
||
| // Snapshot the candidate requests up front: the lists below are mutated | ||
| // while iterating, and the positional pairing must use the original order. | ||
| let combined_requests: Vec<ToolRequest> = permission_result | ||
| .approved | ||
| .iter() | ||
| .chain(permission_result.needs_approval.iter()) | ||
| .cloned() | ||
| .collect(); | ||
|
|
||
| // `zip` stops at the shorter side, so results without a matching request are ignored. | ||
| for (security_result, tool_request) in security_results.iter().zip(&combined_requests) { | ||
| if !security_result.is_malicious { | ||
| continue; | ||
| } | ||
|
|
||
| let request_id = &tool_request.id; | ||
|
|
||
| tracing::warn!( | ||
| tool_request_id = %request_id, | ||
| confidence = security_result.confidence, | ||
| explanation = %security_result.explanation, | ||
| finding_id = %security_result.finding_id, | ||
| "🔒 Security threat detected - modifying tool approval status" | ||
| ); | ||
|
|
||
| // A flagged tool must never stay auto-approved. | ||
| permission_result | ||
| .approved | ||
| .retain(|req| req.id != *request_id); | ||
|
|
||
| if security_result.should_ask_user { | ||
| // Ask the user: queue for approval unless it is already queued. | ||
| if !permission_result | ||
| .needs_approval | ||
| .iter() | ||
| .any(|req| req.id == *request_id) | ||
| { | ||
| permission_result.needs_approval.push(tool_request.clone()); | ||
| } | ||
| } else { | ||
| // No user prompt requested: deny outright. | ||
| permission_result | ||
| .needs_approval | ||
| .retain(|req| req.id != *request_id); | ||
|
|
||
| if !permission_result | ||
| .denied | ||
| .iter() | ||
| .any(|req| req.id == *request_id) | ||
| { | ||
| permission_result.denied.push(tool_request.clone()); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| permission_result | ||
| } | ||
|
|
||
| /// Extend the system prompt with one line of additional instruction | ||
| pub async fn extend_system_prompt(&self, instruction: String) { | ||
| let mut prompt_manager = self.prompt_manager.lock().await; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,21 +54,99 @@ impl Agent { | |
| permission_manager: &'a mut PermissionManager, | ||
| message_tool_response: Arc<Mutex<Message>>, | ||
| cancellation_token: Option<CancellationToken>, | ||
| ) -> BoxStream<'a, anyhow::Result<Message>> { | ||
| self.handle_approval_tool_requests_with_security( | ||
| tool_requests, | ||
| tool_futures, | ||
| permission_manager, | ||
| message_tool_response, | ||
| cancellation_token, | ||
| None, // No security context by default | ||
| ) | ||
| } | ||
|
|
||
| pub(crate) fn handle_approval_tool_requests_with_security<'a>( | ||
| &'a self, | ||
| tool_requests: &'a [ToolRequest], | ||
| tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>, | ||
| permission_manager: &'a mut PermissionManager, | ||
| message_tool_response: Arc<Mutex<Message>>, | ||
| cancellation_token: Option<CancellationToken>, | ||
| security_results: Option<&'a [Option<&'a crate::security::SecurityResult>]>, | ||
| ) -> BoxStream<'a, anyhow::Result<Message>> { | ||
| try_stream! { | ||
| for request in tool_requests { | ||
| for (i, request) in tool_requests.iter().enumerate() { | ||
| if let Ok(tool_call) = request.tool_call.clone() { | ||
| // Check if this tool has security concerns | ||
| // Match by index since security results are provided in the same order as tool requests | ||
| let security_context = security_results | ||
| .and_then(|results| results.get(i)) | ||
| .and_then(|result| *result) | ||
| .filter(|result| result.is_malicious); | ||
|
|
||
| let confirmation_prompt = if let Some(security_result) = security_context { | ||
| format!( | ||
| "🚨 SECURITY WARNING: This tool call has been flagged as potentially malicious.\n\ | ||
| Finding ID: {}\n\ | ||
| Confidence: {:.1}%\n\ | ||
| Reason: {}\n\n\ | ||
| Goose would still like to call the above tool. \n\ | ||
| Please review carefully. Allow? (y/n):", | ||
| security_result.finding_id, | ||
| security_result.confidence * 100.0, | ||
| security_result.explanation | ||
| ) | ||
| } else { | ||
| "Goose would like to call the above tool. Allow? (y/n):".to_string() | ||
| }; | ||
|
|
||
| let confirmation = Message::user().with_tool_confirmation_request( | ||
| request.id.clone(), | ||
| tool_call.name.clone(), | ||
| tool_call.arguments.clone(), | ||
| Some("Goose would like to call the above tool. Allow? (y/n):".to_string()), | ||
| Some(confirmation_prompt), | ||
| ); | ||
| yield confirmation; | ||
|
|
||
| let mut rx = self.confirmation_rx.lock().await; | ||
| while let Some((req_id, confirmation)) = rx.recv().await { | ||
| if req_id == request.id { | ||
| // Log user decision, especially for security-flagged tools | ||
| if let Some(security_result) = security_context { | ||
| match confirmation.permission { | ||
| Permission::AllowOnce | Permission::AlwaysAllow => { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What does `AlwaysAllow` mean in this context?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It means the request gets processed further.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I mean, `AlwaysAllow` doesn't make much sense in this context, but I guess you do want to cover every arm of the match even though the client shouldn't return that variant here. So maybe `unreachable!()` is better? |
||
| tracing::warn!( | ||
| tool_name = %tool_call.name, | ||
| request_id = %request.id, | ||
| permission = ?confirmation.permission, | ||
| security_confidence = %format!("{:.1}%", security_result.confidence * 100.0), | ||
| security_reason = %security_result.explanation, | ||
| finding_id = %security_result.finding_id, | ||
| "🔒 USER APPROVED security-flagged tool despite warning" | ||
| ); | ||
| } | ||
| _ => { | ||
| tracing::info!( | ||
| tool_name = %tool_call.name, | ||
| request_id = %request.id, | ||
| permission = ?confirmation.permission, | ||
| security_confidence = %format!("{:.1}%", security_result.confidence * 100.0), | ||
| security_reason = %security_result.explanation, | ||
| finding_id = %security_result.finding_id, | ||
| "🔒 USER DENIED security-flagged tool" | ||
| ); | ||
| } | ||
| } | ||
| } else { | ||
| // Log regular tool decisions at debug level | ||
| tracing::debug!( | ||
| tool_name = %tool_call.name, | ||
| request_id = %request.id, | ||
| permission = ?confirmation.permission, | ||
| "🔒 User decision for tool execution" | ||
| ); | ||
| } | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we should be able to simplify this logging block and reduce the code duplication. |
||
|
|
||
| if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow { | ||
| let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone()).await; | ||
| let mut futures = tool_futures.lock().await; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How much does this add to the executable size? Do we need the Hugging Face tokenizer? Can we not get this done using tiktoken?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies, didn't realise you'd left feedback - having a look now. Thanks for the feedback 🙏