diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 75248f34cc1..33694a3e72f 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -14,12 +14,14 @@ use codex_protocol::mcp::CallToolResult; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; use codex_protocol::protocol::AskForApproval; +use codex_protocol::protocol::ReviewDecision; use codex_protocol::protocol::SandboxPolicy; use codex_protocol::request_user_input::RequestUserInputArgs; use codex_protocol::request_user_input::RequestUserInputQuestion; use codex_protocol::request_user_input::RequestUserInputQuestionOption; use codex_protocol::request_user_input::RequestUserInputResponse; use rmcp::model::ToolAnnotations; +use serde::Serialize; use std::sync::Arc; /// Handles the specified tool call dispatches the appropriate @@ -64,7 +66,7 @@ pub(crate) async fn handle_mcp_tool_call( .await { let result = match decision { - McpToolApprovalDecision::Accept => { + McpToolApprovalDecision::Accept | McpToolApprovalDecision::AcceptAndRemember => { let tool_call_begin_event = EventMsg::McpToolCallBegin(McpToolCallBeginEvent { call_id: call_id.clone(), invocation: invocation.clone(), @@ -167,21 +169,31 @@ async fn notify_mcp_tool_call_event(sess: &Session, turn_context: &TurnContext, #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum McpToolApprovalDecision { Accept, + AcceptAndRemember, Decline, Cancel, } struct McpToolApprovalMetadata { annotations: ToolAnnotations, + connector_id: Option, connector_name: Option, tool_title: Option, } const MCP_TOOL_APPROVAL_QUESTION_ID_PREFIX: &str = "mcp_tool_call_approval"; -const MCP_TOOL_APPROVAL_ACCEPT: &str = "Accept"; -const MCP_TOOL_APPROVAL_DECLINE: &str = "Decline"; +const MCP_TOOL_APPROVAL_ACCEPT: &str = "Approve Once"; +const MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER: &str = "Approve this Session"; +const MCP_TOOL_APPROVAL_DECLINE: &str = "Deny"; const MCP_TOOL_APPROVAL_CANCEL: &str = "Cancel"; +#[derive(Debug, Serialize)] +struct McpToolApprovalKey { + server: String, + connector_id: String, + tool_name: String, +} + async fn maybe_request_mcp_tool_approval( sess: &Session, turn_context: &TurnContext, @@ -200,6 +212,19 @@ async fn maybe_request_mcp_tool_approval( if !requires_mcp_tool_approval(&metadata.annotations) { return None; } + let approval_key = metadata + .connector_id + .as_deref() + .map(|connector_id| McpToolApprovalKey { + server: server.to_string(), + connector_id: connector_id.to_string(), + tool_name: tool_name.to_string(), + }); + if let Some(key) = approval_key.as_ref() + && mcp_tool_approval_is_remembered(sess, key).await + { + return Some(McpToolApprovalDecision::Accept); + } let question_id = format!("{MCP_TOOL_APPROVAL_QUESTION_ID_PREFIX}_{call_id}"); let question = build_mcp_tool_approval_question( @@ -208,6 +233,7 @@ async fn maybe_request_mcp_tool_approval( metadata.tool_title.as_deref(), metadata.connector_name.as_deref(), &metadata.annotations, + approval_key.is_some(), ); let args = RequestUserInputArgs { questions: vec![question], @@ -215,7 +241,13 @@ async fn maybe_request_mcp_tool_approval( let response = sess .request_user_input(turn_context, call_id.to_string(), args) .await; - Some(parse_mcp_tool_approval_response(response, &question_id)) + let decision = parse_mcp_tool_approval_response(response, &question_id); + if matches!(decision, McpToolApprovalDecision::AcceptAndRemember) + && let Some(key) = approval_key + { + remember_mcp_tool_approval(sess, key).await; + } + Some(decision) } fn is_full_access_mode(turn_context: &TurnContext) -> bool { @@ -246,6 +278,7 @@ async fn lookup_mcp_tool_metadata( .annotations .map(|annotations| McpToolApprovalMetadata { annotations, + connector_id: tool_info.connector_id, connector_name: tool_info.connector_name, tool_title: tool_info.tool.title, }) @@ -261,6 +294,7 @@ fn build_mcp_tool_approval_question( tool_title: Option<&str>, connector_name: Option<&str>, annotations: &ToolAnnotations, + allow_remember_option: bool, ) -> RequestUserInputQuestion { let destructive = annotations.destructive_hint == Some(true); let open_world = annotations.open_world_hint == Some(true); @@ -279,26 +313,34 @@ fn build_mcp_tool_approval_question( "{app_label} wants to run the tool \"{tool_label}\", which {reason}. Allow this action?" ); + let mut options = vec![RequestUserInputQuestionOption { + label: MCP_TOOL_APPROVAL_ACCEPT.to_string(), + description: "Run the tool and continue.".to_string(), + }]; + if allow_remember_option { + options.push(RequestUserInputQuestionOption { + label: MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER.to_string(), + description: "Run the tool and remember this choice for this session.".to_string(), + }); + } + options.extend([ + RequestUserInputQuestionOption { + label: MCP_TOOL_APPROVAL_DECLINE.to_string(), + description: "Decline this tool call and continue.".to_string(), + }, + RequestUserInputQuestionOption { + label: MCP_TOOL_APPROVAL_CANCEL.to_string(), + description: "Cancel this tool call".to_string(), + }, + ]); + RequestUserInputQuestion { id: question_id, header: "Approve app tool call?".to_string(), question, is_other: false, is_secret: false, - options: Some(vec![ - RequestUserInputQuestionOption { - label: MCP_TOOL_APPROVAL_ACCEPT.to_string(), - description: "Run the tool and continue.".to_string(), - }, - RequestUserInputQuestionOption { - label: MCP_TOOL_APPROVAL_DECLINE.to_string(), - description: "Decline this tool call and continue.".to_string(), - }, - RequestUserInputQuestionOption { - label: MCP_TOOL_APPROVAL_CANCEL.to_string(), - description: "Cancel this tool call".to_string(), - }, - ]), + options: Some(options), } } @@ -317,6 +359,11 @@ fn parse_mcp_tool_approval_response( return McpToolApprovalDecision::Cancel; }; if answers + .iter() + .any(|answer| answer == MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER) + { + McpToolApprovalDecision::AcceptAndRemember + } else if answers .iter() .any(|answer| answer == MCP_TOOL_APPROVAL_ACCEPT) { @@ -331,6 +378,16 @@ fn parse_mcp_tool_approval_response( } } +async fn mcp_tool_approval_is_remembered(sess: &Session, key: &McpToolApprovalKey) -> bool { + let store = sess.services.tool_approvals.lock().await; + matches!(store.get(key), Some(ReviewDecision::ApprovedForSession)) +} + +async fn remember_mcp_tool_approval(sess: &Session, key: McpToolApprovalKey) { + let mut store = sess.services.tool_approvals.lock().await; + store.put(key, ReviewDecision::ApprovedForSession); +} + fn requires_mcp_tool_approval(annotations: &ToolAnnotations) -> bool { annotations.read_only_hint == Some(false) && (annotations.destructive_hint == Some(true) || annotations.open_world_hint == Some(true))