diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index a00c6eba517..dda4760bd79 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -27,6 +27,7 @@ pub use mcp_resource::McpResourceHandler; pub use plan::PlanHandler; pub use read_file::ReadFileHandler; pub use request_user_input::RequestUserInputHandler; +pub(crate) use request_user_input::request_user_input_tool_description; pub use shell::ShellCommandHandler; pub use shell::ShellHandler; pub use test_sync::TestSyncHandler; diff --git a/codex-rs/core/src/tools/handlers/request_user_input.rs b/codex-rs/core/src/tools/handlers/request_user_input.rs index 0164d974669..6d014755b56 100644 --- a/codex-rs/core/src/tools/handlers/request_user_input.rs +++ b/codex-rs/core/src/tools/handlers/request_user_input.rs @@ -10,6 +10,56 @@ use crate::tools::registry::ToolKind; use codex_protocol::config_types::ModeKind; use codex_protocol::request_user_input::RequestUserInputArgs; +const REQUEST_USER_INPUT_ALLOWED_MODES: [ModeKind; 1] = [ModeKind::Plan]; + +fn request_user_input_mode_name(mode: ModeKind) -> &'static str { + match mode { + ModeKind::Plan => "Plan", + ModeKind::Default => "Default", + ModeKind::Execute => "Execute", + ModeKind::PairProgramming => "Pair Programming", + } +} + +fn format_allowed_modes() -> String { + let mut mode_names = Vec::with_capacity(REQUEST_USER_INPUT_ALLOWED_MODES.len()); + for mode in REQUEST_USER_INPUT_ALLOWED_MODES { + let name = request_user_input_mode_name(mode); + if !mode_names.contains(&name) { + mode_names.push(name); + } + } + + match mode_names.as_slice() { + [] => "no modes".to_string(), + [mode] => format!("{mode} mode"), + [first, second] => format!("{first} or {second} mode"), + [..] => format!("modes: {}", mode_names.join(",")), + } +} + +fn request_user_input_is_available_in_mode(mode: ModeKind) -> bool { + REQUEST_USER_INPUT_ALLOWED_MODES.contains(&mode) +} + +pub(crate) fn request_user_input_unavailable_message(mode: ModeKind) -> Option { + if request_user_input_is_available_in_mode(mode) { + None + } else { + let mode_name = request_user_input_mode_name(mode); + Some(format!( + "request_user_input is unavailable in {mode_name} mode" + )) + } +} + +pub(crate) fn request_user_input_tool_description() -> String { + let allowed_modes = format_allowed_modes(); + format!( + "Request user input for one to three short questions and wait for the response. This tool is only available in {allowed_modes}." + ) +} + pub struct RequestUserInputHandler; #[async_trait] @@ -37,14 +87,8 @@ impl ToolHandler for RequestUserInputHandler { }; let mode = session.collaboration_mode().await.mode; - if !matches!(mode, ModeKind::Plan | ModeKind::PairProgramming) { - let mode_name = match mode { - ModeKind::Default | ModeKind::Execute => "Default", - ModeKind::Plan | ModeKind::PairProgramming => unreachable!(), - }; - return Err(FunctionCallError::RespondToModel(format!( - "request_user_input is unavailable in {mode_name} mode" - ))); + if let Some(message) = request_user_input_unavailable_message(mode) { + return Err(FunctionCallError::RespondToModel(message)); } let mut args: RequestUserInputArgs = parse_arguments(&arguments)?; @@ -82,3 +126,54 @@ impl ToolHandler for RequestUserInputHandler { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn request_user_input_mode_availability_is_plan_only() { + assert_eq!( + request_user_input_is_available_in_mode(ModeKind::Plan), + true + ); + assert_eq!( + request_user_input_is_available_in_mode(ModeKind::Default), + false + ); + assert_eq!( + request_user_input_is_available_in_mode(ModeKind::Execute), + false + ); + assert_eq!( + request_user_input_is_available_in_mode(ModeKind::PairProgramming), + false + ); + } + + #[test] + fn request_user_input_unavailable_messages_use_default_name_for_default_modes() { + assert_eq!(request_user_input_unavailable_message(ModeKind::Plan), None); + assert_eq!( + request_user_input_unavailable_message(ModeKind::Default), + Some("request_user_input is unavailable in Default mode".to_string()) + ); + assert_eq!( + request_user_input_unavailable_message(ModeKind::Execute), + Some("request_user_input is unavailable in Execute mode".to_string()) + ); + assert_eq!( + request_user_input_unavailable_message(ModeKind::PairProgramming), + Some("request_user_input is unavailable in Pair Programming mode".to_string()) + ); + } + + #[test] + fn request_user_input_tool_description_mentions_plan_only() { + assert_eq!( + request_user_input_tool_description(), + "Request user input for one to three short questions and wait for the response. This tool is only available in Plan mode.".to_string() + ); + } +} diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index c0fd1d02213..8851a157a20 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -9,6 +9,7 @@ use crate::tools::handlers::apply_patch::create_apply_patch_json_tool; use crate::tools::handlers::collab::DEFAULT_WAIT_TIMEOUT_MS; use crate::tools::handlers::collab::MAX_WAIT_TIMEOUT_MS; use crate::tools::handlers::collab::MIN_WAIT_TIMEOUT_MS; +use crate::tools::handlers::request_user_input_tool_description; use crate::tools::registry::ToolRegistryBuilder; use codex_protocol::config_types::WebSearchMode; use codex_protocol::dynamic_tools::DynamicToolSpec; @@ -623,9 +624,7 @@ fn create_request_user_input_tool() -> ToolSpec { ToolSpec::Function(ResponsesApiTool { name: "request_user_input".to_string(), - description: - "Request user input for one to three short questions and wait for the response." - .to_string(), + description: request_user_input_tool_description(), strict: false, parameters: JsonSchema::Object { properties, diff --git a/codex-rs/core/templates/collaboration_mode/default.md b/codex-rs/core/templates/collaboration_mode/default.md index 90b80453fa1..c8154d10d99 100644 --- a/codex-rs/core/templates/collaboration_mode/default.md +++ b/codex-rs/core/templates/collaboration_mode/default.md @@ -1 +1,9 @@ -you are now in default mode. +# Collaboration Mode: Default + +You are now in Default mode. Any previous instructions for other modes (e.g. Plan mode) are no longer active. + +## request_user_input availability + +The `request_user_input` tool is unavailable in Default mode. If you call it while in Default mode, it will return an error. + +If a decision is necessary and cannot be discovered from local context, ask the user directly. However, in Default mode you should strongly prefer executing the user's request rather than stopping to ask questions. diff --git a/codex-rs/core/tests/suite/request_user_input.rs b/codex-rs/core/tests/suite/request_user_input.rs index b5e0396d8a1..6d8b8cb0354 100644 --- a/codex-rs/core/tests/suite/request_user_input.rs +++ b/codex-rs/core/tests/suite/request_user_input.rs @@ -74,11 +74,6 @@ async fn request_user_input_round_trip_resolves_pending() -> anyhow::Result<()> request_user_input_round_trip_for_mode(ModeKind::Plan).await } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn request_user_input_round_trip_works_in_pair_mode() -> anyhow::Result<()> { - request_user_input_round_trip_for_mode(ModeKind::PairProgramming).await -} - async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -216,7 +211,7 @@ where .build(&server) .await?; - let mode_slug = mode_name.to_lowercase(); + let mode_slug = mode_name.to_lowercase().replace(' ', "-"); let call_id = format!("user-input-{mode_slug}-call"); let request_args = json!({ "questions": [{ @@ -283,7 +278,7 @@ where #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn request_user_input_rejected_in_execute_mode_alias() -> anyhow::Result<()> { - assert_request_user_input_rejected("Default", |model| CollaborationMode { + assert_request_user_input_rejected("Execute", |model| CollaborationMode { mode: ModeKind::Execute, settings: Settings { model, @@ -306,3 +301,16 @@ async fn request_user_input_rejected_in_default_mode() -> anyhow::Result<()> { }) .await } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn request_user_input_rejected_in_pair_mode_alias() -> anyhow::Result<()> { + assert_request_user_input_rejected("Pair Programming", |model| CollaborationMode { + mode: ModeKind::PairProgramming, + settings: Settings { + model, + reasoning_effort: None, + developer_instructions: None, + }, + }) + .await +}