Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 75 additions & 18 deletions codex-rs/core/src/mcp_tool_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ use codex_protocol::mcp::CallToolResult;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::ReviewDecision;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::request_user_input::RequestUserInputArgs;
use codex_protocol::request_user_input::RequestUserInputQuestion;
use codex_protocol::request_user_input::RequestUserInputQuestionOption;
use codex_protocol::request_user_input::RequestUserInputResponse;
use rmcp::model::ToolAnnotations;
use serde::Serialize;
use std::sync::Arc;

/// Handles the specified tool call dispatches the appropriate
Expand Down Expand Up @@ -64,7 +66,7 @@ pub(crate) async fn handle_mcp_tool_call(
.await
{
let result = match decision {
McpToolApprovalDecision::Accept => {
McpToolApprovalDecision::Accept | McpToolApprovalDecision::AcceptAndRemember => {
let tool_call_begin_event = EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
call_id: call_id.clone(),
invocation: invocation.clone(),
Expand Down Expand Up @@ -167,21 +169,31 @@ async fn notify_mcp_tool_call_event(sess: &Session, turn_context: &TurnContext,
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum McpToolApprovalDecision {
Accept,
AcceptAndRemember,
Decline,
Cancel,
}

struct McpToolApprovalMetadata {
annotations: ToolAnnotations,
connector_id: Option<String>,
connector_name: Option<String>,
tool_title: Option<String>,
}

const MCP_TOOL_APPROVAL_QUESTION_ID_PREFIX: &str = "mcp_tool_call_approval";
const MCP_TOOL_APPROVAL_ACCEPT: &str = "Accept";
const MCP_TOOL_APPROVAL_DECLINE: &str = "Decline";
const MCP_TOOL_APPROVAL_ACCEPT: &str = "Approve Once";
const MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER: &str = "Approve this Session";
const MCP_TOOL_APPROVAL_DECLINE: &str = "Deny";
const MCP_TOOL_APPROVAL_CANCEL: &str = "Cancel";

#[derive(Debug, Serialize)]
struct McpToolApprovalKey {
server: String,
connector_id: String,
tool_name: String,
}

async fn maybe_request_mcp_tool_approval(
sess: &Session,
turn_context: &TurnContext,
Expand All @@ -200,6 +212,19 @@ async fn maybe_request_mcp_tool_approval(
if !requires_mcp_tool_approval(&metadata.annotations) {
return None;
}
let approval_key = metadata
.connector_id
.as_deref()
.map(|connector_id| McpToolApprovalKey {
server: server.to_string(),
connector_id: connector_id.to_string(),
tool_name: tool_name.to_string(),
});
if let Some(key) = approval_key.as_ref()
&& mcp_tool_approval_is_remembered(sess, key).await
{
return Some(McpToolApprovalDecision::Accept);
}

let question_id = format!("{MCP_TOOL_APPROVAL_QUESTION_ID_PREFIX}_{call_id}");
let question = build_mcp_tool_approval_question(
Expand All @@ -208,14 +233,21 @@ async fn maybe_request_mcp_tool_approval(
metadata.tool_title.as_deref(),
metadata.connector_name.as_deref(),
&metadata.annotations,
approval_key.is_some(),
);
let args = RequestUserInputArgs {
questions: vec![question],
};
let response = sess
.request_user_input(turn_context, call_id.to_string(), args)
.await;
Some(parse_mcp_tool_approval_response(response, &question_id))
let decision = parse_mcp_tool_approval_response(response, &question_id);
if matches!(decision, McpToolApprovalDecision::AcceptAndRemember)
&& let Some(key) = approval_key
{
remember_mcp_tool_approval(sess, key).await;
}
Some(decision)
}

fn is_full_access_mode(turn_context: &TurnContext) -> bool {
Expand Down Expand Up @@ -246,6 +278,7 @@ async fn lookup_mcp_tool_metadata(
.annotations
.map(|annotations| McpToolApprovalMetadata {
annotations,
connector_id: tool_info.connector_id,
connector_name: tool_info.connector_name,
tool_title: tool_info.tool.title,
})
Expand All @@ -261,6 +294,7 @@ fn build_mcp_tool_approval_question(
tool_title: Option<&str>,
connector_name: Option<&str>,
annotations: &ToolAnnotations,
allow_remember_option: bool,
) -> RequestUserInputQuestion {
let destructive = annotations.destructive_hint == Some(true);
let open_world = annotations.open_world_hint == Some(true);
Expand All @@ -279,26 +313,34 @@ fn build_mcp_tool_approval_question(
"{app_label} wants to run the tool \"{tool_label}\", which {reason}. Allow this action?"
);

let mut options = vec![RequestUserInputQuestionOption {
label: MCP_TOOL_APPROVAL_ACCEPT.to_string(),
description: "Run the tool and continue.".to_string(),
}];
if allow_remember_option {
options.push(RequestUserInputQuestionOption {
label: MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER.to_string(),
description: "Run the tool and remember this choice for this session.".to_string(),
});
}
options.extend([
RequestUserInputQuestionOption {
label: MCP_TOOL_APPROVAL_DECLINE.to_string(),
description: "Decline this tool call and continue.".to_string(),
},
RequestUserInputQuestionOption {
label: MCP_TOOL_APPROVAL_CANCEL.to_string(),
description: "Cancel this tool call".to_string(),
},
]);

RequestUserInputQuestion {
id: question_id,
header: "Approve app tool call?".to_string(),
question,
is_other: false,
is_secret: false,
options: Some(vec![
RequestUserInputQuestionOption {
label: MCP_TOOL_APPROVAL_ACCEPT.to_string(),
description: "Run the tool and continue.".to_string(),
},
RequestUserInputQuestionOption {
label: MCP_TOOL_APPROVAL_DECLINE.to_string(),
description: "Decline this tool call and continue.".to_string(),
},
RequestUserInputQuestionOption {
label: MCP_TOOL_APPROVAL_CANCEL.to_string(),
description: "Cancel this tool call".to_string(),
},
]),
options: Some(options),
}
}

Expand All @@ -317,6 +359,11 @@ fn parse_mcp_tool_approval_response(
return McpToolApprovalDecision::Cancel;
};
if answers
.iter()
.any(|answer| answer == MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER)
{
McpToolApprovalDecision::AcceptAndRemember
} else if answers
.iter()
.any(|answer| answer == MCP_TOOL_APPROVAL_ACCEPT)
{
Expand All @@ -331,6 +378,16 @@ fn parse_mcp_tool_approval_response(
}
}

async fn mcp_tool_approval_is_remembered(sess: &Session, key: &McpToolApprovalKey) -> bool {
let store = sess.services.tool_approvals.lock().await;
matches!(store.get(key), Some(ReviewDecision::ApprovedForSession))
}

async fn remember_mcp_tool_approval(sess: &Session, key: McpToolApprovalKey) {
let mut store = sess.services.tool_approvals.lock().await;
store.put(key, ReviewDecision::ApprovedForSession);
}

fn requires_mcp_tool_approval(annotations: &ToolAnnotations) -> bool {
annotations.read_only_hint == Some(false)
&& (annotations.destructive_hint == Some(true) || annotations.open_world_hint == Some(true))
Expand Down
Loading