Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
349 changes: 338 additions & 11 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ unicode-normalization = "0.1"
lancedb = "0.13"
arrow = "52.2"

# ML inference backends for security scanning
ort = "2.0.0-rc.10" # ONNX Runtime - use latest RC
tokenizers = { version = "0.20.4", default-features = false, features = ["onig"] } # HuggingFace tokenizers

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how much does this add to the executable? do we need the huggingface tokenizer? can we not get this done using tiktoken?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies, didn't realise you'd left feedback - having a look now. Thanks for the feedback 🙏

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

Expand Down
150 changes: 146 additions & 4 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use crate::agents::todo_tools::{
todo_read_tool, todo_write_tool, TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME,
};
use crate::conversation::message::{Message, ToolRequest};
use crate::security::SecurityManager;

const DEFAULT_MAX_TURNS: u32 = 1000;

Expand Down Expand Up @@ -102,6 +103,7 @@ pub struct Agent {
pub(super) tool_route_manager: ToolRouteManager,
pub(super) scheduler_service: Mutex<Option<Arc<dyn SchedulerTrait>>>,
pub(super) retry_manager: RetryManager,
pub(super) security_manager: SecurityManager,
pub(super) todo_list: Arc<Mutex<String>>,
}

Expand Down Expand Up @@ -188,6 +190,7 @@ impl Agent {
tool_route_manager: ToolRouteManager::new(),
scheduler_service: Mutex::new(None),
retry_manager,
security_manager: SecurityManager::new(),
todo_list: Arc::new(Mutex::new(String::new())),
}
}
Expand Down Expand Up @@ -1078,6 +1081,29 @@ impl Agent {
);
}
} else {
// Check if we need to show model download status before security scanning
if let Some(download_message) = self.security_manager.check_model_download_status().await {
yield AgentEvent::Message(Message::assistant().with_text(download_message));
}

// SECURITY FIX: Scan tools for prompt injection BEFORE permission checking
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is probably fine for testing, but we should refactor this so we have a generic way to check whether we want to run a tool that says yes/no/ask the user with prompt and then take the least permissive version of that, i.e. if the counter says no, don't even run the security thing if that makes sense

// This ensures security results can override auto-mode approvals
let initial_permission_result = PermissionCheckResult {
approved: remaining_requests.clone(),
needs_approval: vec![],
denied: vec![],
};

println!("🔍 DEBUG: About to call security manager with {} total tools", remaining_requests.len());
let security_results = self.security_manager
.filter_malicious_tool_calls(messages.messages(), &initial_permission_result, Some(&system_prompt))
.await
.unwrap_or_else(|e| {
tracing::warn!("Security scanning failed: {}", e);
vec![]
});

// Now run permission checking with security context
let mut permission_manager = PermissionManager::default();
let (permission_check_result, enable_extension_request_ids) =
check_tool_permissions(
Expand All @@ -1089,21 +1115,45 @@ impl Agent {
self.provider().await?,
).await;

// Apply security results to override permission decisions
let final_permission_result = self.apply_security_results_to_permissions(
permission_check_result,
&security_results
).await;

println!("🔍 DEBUG: After security integration - {} approved, {} need approval, {} denied",
final_permission_result.approved.len(),
final_permission_result.needs_approval.len(),
final_permission_result.denied.len());

let mut tool_futures = self.handle_approved_and_denied_tools(
&permission_check_result,
&final_permission_result,
message_tool_response.clone(),
cancel_token.clone()
).await?;

let tool_futures_arc = Arc::new(Mutex::new(tool_futures));

// Process tools requiring approval
let mut tool_approval_stream = self.handle_approval_tool_requests(
&permission_check_result.needs_approval,
// Process tools requiring approval (including security-flagged tools)
// Create a mapping of security results for tools that need approval
let mut security_results_for_approval: Vec<Option<&crate::security::SecurityResult>> = Vec::new();
for _approval_request in &final_permission_result.needs_approval {
// Find the corresponding security result for this tool request
let security_result = security_results.iter().find(|result| {
// Match by checking if this tool was flagged as malicious
// This is a simplified matching - ideally we'd have better tool request tracking
result.is_malicious
});
security_results_for_approval.push(security_result);
}

let mut tool_approval_stream = self.handle_approval_tool_requests_with_security(
&final_permission_result.needs_approval,
tool_futures_arc.clone(),
&mut permission_manager,
message_tool_response.clone(),
cancel_token.clone(),
Some(&security_results_for_approval),
);

while let Some(msg) = tool_approval_stream.try_next().await? {
Expand Down Expand Up @@ -1230,6 +1280,98 @@ impl Agent {
}
}

/// Apply security scan results to permission check results
/// This integrates security scanning with the existing tool approval system
async fn apply_security_results_to_permissions(
&self,
mut permission_result: PermissionCheckResult,
security_results: &[crate::security::SecurityResult],
) -> PermissionCheckResult {
if security_results.is_empty() {
return permission_result;
}

// Create a map of tool requests by ID for easy lookup
let mut all_requests: std::collections::HashMap<String, ToolRequest> =
std::collections::HashMap::new();

// Collect all tool requests
for req in &permission_result.approved {
all_requests.insert(req.id.clone(), req.clone());
}
for req in &permission_result.needs_approval {
all_requests.insert(req.id.clone(), req.clone());
}
for req in &permission_result.denied {
all_requests.insert(req.id.clone(), req.clone());
}

// Collect the combined requests first to avoid borrowing issues
let combined_requests: Vec<ToolRequest> = permission_result
.approved
.iter()
.chain(permission_result.needs_approval.iter())
.cloned()
.collect();

// Process security results
for (i, security_result) in security_results.iter().enumerate() {
if !security_result.is_malicious {
continue;
}

// Find the corresponding tool request by index
if let Some(tool_request) = combined_requests.get(i) {
let request_id = &tool_request.id;

tracing::warn!(
tool_request_id = %request_id,
confidence = security_result.confidence,
explanation = %security_result.explanation,
finding_id = %security_result.finding_id,
"🔒 Security threat detected - modifying tool approval status"
);

// Remove from approved if present
permission_result
.approved
.retain(|req| req.id != *request_id);

if security_result.should_ask_user {
// Move to needs_approval with security context
if let Some(request) = all_requests.get(request_id) {
// Only add if not already in needs_approval
if !permission_result
.needs_approval
.iter()
.any(|req| req.id == *request_id)
{
permission_result.needs_approval.push(request.clone());
}
}
} else {
// High confidence threat - move to denied
permission_result
.needs_approval
.retain(|req| req.id != *request_id);

if let Some(request) = all_requests.get(request_id) {
// Only add if not already in denied
if !permission_result
.denied
.iter()
.any(|req| req.id == *request_id)
{
permission_result.denied.push(request.clone());
}
}
}
}
}

permission_result
}

/// Extend the system prompt with one line of additional instruction
pub async fn extend_system_prompt(&self, instruction: String) {
let mut prompt_manager = self.prompt_manager.lock().await;
Expand Down
82 changes: 80 additions & 2 deletions crates/goose/src/agents/tool_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,99 @@ impl Agent {
permission_manager: &'a mut PermissionManager,
message_tool_response: Arc<Mutex<Message>>,
cancellation_token: Option<CancellationToken>,
) -> BoxStream<'a, anyhow::Result<Message>> {
self.handle_approval_tool_requests_with_security(
tool_requests,
tool_futures,
permission_manager,
message_tool_response,
cancellation_token,
None, // No security context by default
)
}

pub(crate) fn handle_approval_tool_requests_with_security<'a>(
&'a self,
tool_requests: &'a [ToolRequest],
tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>,
permission_manager: &'a mut PermissionManager,
message_tool_response: Arc<Mutex<Message>>,
cancellation_token: Option<CancellationToken>,
security_results: Option<&'a [Option<&'a crate::security::SecurityResult>]>,
) -> BoxStream<'a, anyhow::Result<Message>> {
try_stream! {
for request in tool_requests {
for (i, request) in tool_requests.iter().enumerate() {
if let Ok(tool_call) = request.tool_call.clone() {
// Check if this tool has security concerns
// Match by index since security results are provided in the same order as tool requests
let security_context = security_results
.and_then(|results| results.get(i))
.and_then(|result| *result)
.filter(|result| result.is_malicious);

let confirmation_prompt = if let Some(security_result) = security_context {
format!(
"🚨 SECURITY WARNING: This tool call has been flagged as potentially malicious.\n\
Finding ID: {}\n\
Confidence: {:.1}%\n\
Reason: {}\n\n\
Goose would still like to call the above tool. \n\
Please review carefully. Allow? (y/n):",
security_result.finding_id,
security_result.confidence * 100.0,
security_result.explanation
)
} else {
"Goose would like to call the above tool. Allow? (y/n):".to_string()
};

let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
Some(confirmation_prompt),
);
yield confirmation;

let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, confirmation)) = rx.recv().await {
if req_id == request.id {
// Log user decision, especially for security-flagged tools
if let Some(security_result) = security_context {
match confirmation.permission {
Permission::AllowOnce | Permission::AlwaysAllow => {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does allows allow mean in this context?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It means the request gets processed further.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I mean, AlwaysAllow doesn't make that much sense in this context but I guess you do want to catch all the arms of the match even though the client shouldn't return that here. so maybe unreachable is better?

tracing::warn!(
tool_name = %tool_call.name,
request_id = %request.id,
permission = ?confirmation.permission,
security_confidence = %format!("{:.1}%", security_result.confidence * 100.0),
security_reason = %security_result.explanation,
finding_id = %security_result.finding_id,
"🔒 USER APPROVED security-flagged tool despite warning"
);
}
_ => {
tracing::info!(
tool_name = %tool_call.name,
request_id = %request.id,
permission = ?confirmation.permission,
security_confidence = %format!("{:.1}%", security_result.confidence * 100.0),
security_reason = %security_result.explanation,
finding_id = %security_result.finding_id,
"🔒 USER DENIED security-flagged tool"
);
}
}
} else {
// Log regular tool decisions at debug level
tracing::debug!(
tool_name = %tool_call.name,
request_id = %request.id,
permission = ?confirmation.permission,
"🔒 User decision for tool execution"
);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should be able to simplify this logging block I think and reduce the code duplication


if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone()).await;
let mut futures = tool_futures.lock().await;
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub mod recipe_deeplink;
pub mod scheduler;
pub mod scheduler_factory;
pub mod scheduler_trait;
pub mod security;
pub mod session;
pub mod temporal_scheduler;
pub mod token_counter;
Expand Down
Loading