diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 8567c9286aa..dc7d5902d65 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1686,6 +1686,17 @@ impl Session { .map(|task| Arc::clone(&task.turn_context)) } + async fn active_turn_context_and_cancellation_token( + &self, + ) -> Option<(Arc, CancellationToken)> { + let active = self.active_turn.lock().await; + let (_, task) = active.as_ref()?.tasks.first()?; + Some(( + Arc::clone(&task.turn_context), + task.cancellation_token.child_token(), + )) + } + pub(crate) async fn record_execpolicy_amendment_message( &self, sub_id: &str, @@ -2716,7 +2727,9 @@ mod handlers { use crate::tasks::CompactTask; use crate::tasks::RegularTask; use crate::tasks::UndoTask; + use crate::tasks::UserShellCommandMode; use crate::tasks::UserShellCommandTask; + use crate::tasks::execute_user_shell_command; use codex_protocol::custom_prompts::CustomPrompt; use codex_protocol::protocol::CodexErrorInfo; use codex_protocol::protocol::ErrorEvent; @@ -2863,6 +2876,23 @@ mod handlers { command: String, previous_context: &mut Option>, ) { + if let Some((turn_context, cancellation_token)) = + sess.active_turn_context_and_cancellation_token().await + { + let session = Arc::clone(sess); + tokio::spawn(async move { + execute_user_shell_command( + session, + turn_context, + command, + cancellation_token, + UserShellCommandMode::ActiveTurnAuxiliary, + ) + .await; + }); + return; + } + let turn_context = sess.new_default_turn_with_sub_id(sub_id).await; sess.spawn_task( Arc::clone(&turn_context), @@ -4807,6 +4837,7 @@ mod tests { use codex_app_server_protocol::AuthMode; use codex_protocol::models::BaseInstructions; use codex_protocol::models::ContentItem; + use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; use std::path::Path; use std::time::Duration; @@ -6044,6 +6075,50 @@ mod tests { assert!(rx.try_recv().is_err()); } + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn task_finish_persists_leftover_pending_input() { + let (sess, tc, _rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + sess.inject_response_items(vec![ResponseInputItem::Message { + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "late pending input".to_string(), + }], + }]) + .await + .expect("inject pending input into active turn"); + + sess.on_task_finished(Arc::clone(&tc), None).await; + + let history = sess.clone_history().await; + let expected = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "late pending input".to_string(), + }], + end_turn: None, + phase: None, + }; + assert!( + history.raw_items().iter().any(|item| item == &expected), + "expected pending input to be persisted into history on turn completion" + ); + } + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn abort_review_task_emits_exited_then_aborted_and_records_history() { let (sess, tc, rx) = make_session_and_context_with_rx().await; diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 0ab6e2f49b5..0821ba7bc23 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -31,6 +31,7 @@ use crate::state::ActiveTurn; use crate::state::RunningTask; use crate::state::TaskKind; use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::RolloutItem; use codex_protocol::user_input::UserInput; @@ -40,7 +41,9 @@ pub(crate) use ghost_snapshot::GhostSnapshotTask; pub(crate) use regular::RegularTask; pub(crate) use review::ReviewTask; pub(crate) use undo::UndoTask; +pub(crate) use user_shell::UserShellCommandMode; pub(crate) use user_shell::UserShellCommandTask; +pub(crate) use user_shell::execute_user_shell_command; const GRACEFULL_INTERRUPTION_TIMEOUT_MS: u64 = 100; const TURN_ABORTED_INTERRUPTED_GUIDANCE: &str = "The user interrupted the previous turn on purpose. If any tools/commands were aborted, they may have partially executed; verify current state before retrying."; @@ -187,15 +190,27 @@ impl Session { last_agent_message: Option, ) { let mut active = self.active_turn.lock().await; - let should_close_processes = if let Some(at) = active.as_mut() + let mut pending_input = Vec::::new(); + let mut should_close_processes = false; + if let Some(at) = active.as_mut() && at.remove_task(&turn_context.sub_id) { + let mut ts = at.turn_state.lock().await; + pending_input = ts.take_pending_input(); + should_close_processes = true; + } + if should_close_processes { *active = None; - true - } else { - false - }; + } drop(active); + if !pending_input.is_empty() { + let pending_response_items = pending_input + .into_iter() + .map(ResponseItem::from) + .collect::>(); + self.record_conversation_items(turn_context.as_ref(), &pending_response_items) + .await; + } if should_close_processes { self.close_unified_exec_processes().await; } diff --git a/codex-rs/core/src/tasks/user_shell.rs b/codex-rs/core/src/tasks/user_shell.rs index 6c37cb150e8..d626057d924 100644 --- a/codex-rs/core/src/tasks/user_shell.rs +++ b/codex-rs/core/src/tasks/user_shell.rs @@ -32,9 +32,22 @@ use crate::user_shell_command::user_shell_command_record_item; use super::SessionTask; use super::SessionTaskContext; +use crate::codex::Session; +use codex_protocol::models::ResponseInputItem; +use codex_protocol::models::ResponseItem; const USER_SHELL_TIMEOUT_MS: u64 = 60 * 60 * 1000; // 1 hour +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum UserShellCommandMode { + /// Executes as an independent turn lifecycle (emits TurnStarted/TurnComplete + /// via task lifecycle plumbing). + StandaloneTurn, + /// Executes while another turn is already active. This mode must not emit a + /// second TurnStarted/TurnComplete pair for the same active turn. + ActiveTurnAuxiliary, +} + #[derive(Clone)] pub(crate) struct UserShellCommandTask { command: String, @@ -59,198 +72,246 @@ impl SessionTask for UserShellCommandTask { _input: Vec, cancellation_token: CancellationToken, ) -> Option { - let _ = session - .session - .services - .otel_manager - .counter("codex.task.user_shell", 1, &[]); + execute_user_shell_command( + session.clone_session(), + turn_context, + self.command.clone(), + cancellation_token, + UserShellCommandMode::StandaloneTurn, + ) + .await; + None + } +} + +pub(crate) async fn execute_user_shell_command( + session: Arc, + turn_context: Arc, + command: String, + cancellation_token: CancellationToken, + mode: UserShellCommandMode, +) { + session + .services + .otel_manager + .counter("codex.task.user_shell", 1, &[]); + if mode == UserShellCommandMode::StandaloneTurn { + // Auxiliary mode runs within an existing active turn. That turn already + // emitted TurnStarted, so emitting another TurnStarted here would create + // duplicate turn lifecycle events and confuse clients. let event = EventMsg::TurnStarted(TurnStartedEvent { model_context_window: turn_context.model_context_window(), collaboration_mode_kind: turn_context.collaboration_mode.mode, }); - let session = session.clone_session(); session.send_event(turn_context.as_ref(), event).await; + } - // Execute the user's script under their default shell when known; this - // allows commands that use shell features (pipes, &&, redirects, etc.). - // We do not source rc files or otherwise reformat the script. - let use_login_shell = true; - let session_shell = session.user_shell(); - let display_command = session_shell.derive_exec_args(&self.command, use_login_shell); - let exec_command = - maybe_wrap_shell_lc_with_snapshot(&display_command, session_shell.as_ref()); + // Execute the user's script under their default shell when known; this + // allows commands that use shell features (pipes, &&, redirects, etc.). + // We do not source rc files or otherwise reformat the script. + let use_login_shell = true; + let session_shell = session.user_shell(); + let display_command = session_shell.derive_exec_args(&command, use_login_shell); + let exec_command = maybe_wrap_shell_lc_with_snapshot(&display_command, session_shell.as_ref()); - let call_id = Uuid::new_v4().to_string(); - let raw_command = self.command.clone(); - let cwd = turn_context.cwd.clone(); + let call_id = Uuid::new_v4().to_string(); + let raw_command = command; + let cwd = turn_context.cwd.clone(); - let parsed_cmd = parse_command(&display_command); - session - .send_event( + let parsed_cmd = parse_command(&display_command); + session + .send_event( + turn_context.as_ref(), + EventMsg::ExecCommandBegin(ExecCommandBeginEvent { + call_id: call_id.clone(), + process_id: None, + turn_id: turn_context.sub_id.clone(), + command: display_command.clone(), + cwd: cwd.clone(), + parsed_cmd: parsed_cmd.clone(), + source: ExecCommandSource::UserShell, + interaction_input: None, + }), + ) + .await; + + let exec_env = ExecEnv { + command: exec_command.clone(), + cwd: cwd.clone(), + env: create_env( + &turn_context.shell_environment_policy, + Some(session.conversation_id), + ), + // TODO(zhao-oai): Now that we have ExecExpiration::Cancellation, we + // should use that instead of an "arbitrarily large" timeout here. + expiration: USER_SHELL_TIMEOUT_MS.into(), + sandbox: SandboxType::None, + windows_sandbox_level: turn_context.windows_sandbox_level, + sandbox_permissions: SandboxPermissions::UseDefault, + justification: None, + arg0: None, + }; + + let stdout_stream = Some(StdoutStream { + sub_id: turn_context.sub_id.clone(), + call_id: call_id.clone(), + tx_event: session.get_tx_event(), + }); + + let sandbox_policy = SandboxPolicy::DangerFullAccess; + let exec_result = execute_exec_env(exec_env, &sandbox_policy, stdout_stream) + .or_cancel(&cancellation_token) + .await; + + match exec_result { + Err(CancelErr::Cancelled) => { + let aborted_message = "command aborted by user".to_string(); + let exec_output = ExecToolCallOutput { + exit_code: -1, + stdout: StreamOutput::new(String::new()), + stderr: StreamOutput::new(aborted_message.clone()), + aggregated_output: StreamOutput::new(aborted_message.clone()), + duration: Duration::ZERO, + timed_out: false, + }; + persist_user_shell_output( + &session, turn_context.as_ref(), - EventMsg::ExecCommandBegin(ExecCommandBeginEvent { - call_id: call_id.clone(), - process_id: None, - turn_id: turn_context.sub_id.clone(), - command: display_command.clone(), - cwd: cwd.clone(), - parsed_cmd: parsed_cmd.clone(), - source: ExecCommandSource::UserShell, - interaction_input: None, - }), + &raw_command, + &exec_output, + mode, ) .await; + session + .send_event( + turn_context.as_ref(), + EventMsg::ExecCommandEnd(ExecCommandEndEvent { + call_id, + process_id: None, + turn_id: turn_context.sub_id.clone(), + command: display_command.clone(), + cwd: cwd.clone(), + parsed_cmd: parsed_cmd.clone(), + source: ExecCommandSource::UserShell, + interaction_input: None, + stdout: String::new(), + stderr: aborted_message.clone(), + aggregated_output: aborted_message.clone(), + exit_code: -1, + duration: Duration::ZERO, + formatted_output: aborted_message, + }), + ) + .await; + } + Ok(Ok(output)) => { + session + .send_event( + turn_context.as_ref(), + EventMsg::ExecCommandEnd(ExecCommandEndEvent { + call_id: call_id.clone(), + process_id: None, + turn_id: turn_context.sub_id.clone(), + command: display_command.clone(), + cwd: cwd.clone(), + parsed_cmd: parsed_cmd.clone(), + source: ExecCommandSource::UserShell, + interaction_input: None, + stdout: output.stdout.text.clone(), + stderr: output.stderr.text.clone(), + aggregated_output: output.aggregated_output.text.clone(), + exit_code: output.exit_code, + duration: output.duration, + formatted_output: format_exec_output_str( + &output, + turn_context.truncation_policy, + ), + }), + ) + .await; - let exec_env = ExecEnv { - command: exec_command.clone(), - cwd: cwd.clone(), - env: create_env( - &turn_context.shell_environment_policy, - Some(session.conversation_id), - ), - // TODO(zhao-oai): Now that we have ExecExpiration::Cancellation, we - // should use that instead of an "arbitrarily large" timeout here. - expiration: USER_SHELL_TIMEOUT_MS.into(), - sandbox: SandboxType::None, - windows_sandbox_level: turn_context.windows_sandbox_level, - sandbox_permissions: SandboxPermissions::UseDefault, - justification: None, - arg0: None, - }; + persist_user_shell_output(&session, turn_context.as_ref(), &raw_command, &output, mode) + .await; + } + Ok(Err(err)) => { + error!("user shell command failed: {err:?}"); + let message = format!("execution error: {err:?}"); + let exec_output = ExecToolCallOutput { + exit_code: -1, + stdout: StreamOutput::new(String::new()), + stderr: StreamOutput::new(message.clone()), + aggregated_output: StreamOutput::new(message.clone()), + duration: Duration::ZERO, + timed_out: false, + }; + session + .send_event( + turn_context.as_ref(), + EventMsg::ExecCommandEnd(ExecCommandEndEvent { + call_id, + process_id: None, + turn_id: turn_context.sub_id.clone(), + command: display_command, + cwd, + parsed_cmd, + source: ExecCommandSource::UserShell, + interaction_input: None, + stdout: exec_output.stdout.text.clone(), + stderr: exec_output.stderr.text.clone(), + aggregated_output: exec_output.aggregated_output.text.clone(), + exit_code: exec_output.exit_code, + duration: exec_output.duration, + formatted_output: format_exec_output_str( + &exec_output, + turn_context.truncation_policy, + ), + }), + ) + .await; + persist_user_shell_output( + &session, + turn_context.as_ref(), + &raw_command, + &exec_output, + mode, + ) + .await; + } + } +} - let stdout_stream = Some(StdoutStream { - sub_id: turn_context.sub_id.clone(), - call_id: call_id.clone(), - tx_event: session.get_tx_event(), - }); +async fn persist_user_shell_output( + session: &Session, + turn_context: &TurnContext, + raw_command: &str, + exec_output: &ExecToolCallOutput, + mode: UserShellCommandMode, +) { + let output_item = user_shell_command_record_item(raw_command, exec_output, turn_context); - let sandbox_policy = SandboxPolicy::DangerFullAccess; - let exec_result = execute_exec_env(exec_env, &sandbox_policy, stdout_stream) - .or_cancel(&cancellation_token) + if mode == UserShellCommandMode::StandaloneTurn { + session + .record_conversation_items(turn_context, std::slice::from_ref(&output_item)) .await; + return; + } - match exec_result { - Err(CancelErr::Cancelled) => { - let aborted_message = "command aborted by user".to_string(); - let exec_output = ExecToolCallOutput { - exit_code: -1, - stdout: StreamOutput::new(String::new()), - stderr: StreamOutput::new(aborted_message.clone()), - aggregated_output: StreamOutput::new(aborted_message.clone()), - duration: Duration::ZERO, - timed_out: false, - }; - let output_items = [user_shell_command_record_item( - &raw_command, - &exec_output, - &turn_context, - )]; - session - .record_conversation_items(turn_context.as_ref(), &output_items) - .await; - session - .send_event( - turn_context.as_ref(), - EventMsg::ExecCommandEnd(ExecCommandEndEvent { - call_id, - process_id: None, - turn_id: turn_context.sub_id.clone(), - command: display_command.clone(), - cwd: cwd.clone(), - parsed_cmd: parsed_cmd.clone(), - source: ExecCommandSource::UserShell, - interaction_input: None, - stdout: String::new(), - stderr: aborted_message.clone(), - aggregated_output: aborted_message.clone(), - exit_code: -1, - duration: Duration::ZERO, - formatted_output: aborted_message, - }), - ) - .await; - } - Ok(Ok(output)) => { - session - .send_event( - turn_context.as_ref(), - EventMsg::ExecCommandEnd(ExecCommandEndEvent { - call_id: call_id.clone(), - process_id: None, - turn_id: turn_context.sub_id.clone(), - command: display_command.clone(), - cwd: cwd.clone(), - parsed_cmd: parsed_cmd.clone(), - source: ExecCommandSource::UserShell, - interaction_input: None, - stdout: output.stdout.text.clone(), - stderr: output.stderr.text.clone(), - aggregated_output: output.aggregated_output.text.clone(), - exit_code: output.exit_code, - duration: output.duration, - formatted_output: format_exec_output_str( - &output, - turn_context.truncation_policy, - ), - }), - ) - .await; + let response_input_item = match output_item { + ResponseItem::Message { role, content, .. } => ResponseInputItem::Message { role, content }, + _ => unreachable!("user shell command output record should always be a message"), + }; - let output_items = [user_shell_command_record_item( - &raw_command, - &output, - &turn_context, - )]; - session - .record_conversation_items(turn_context.as_ref(), &output_items) - .await; - } - Ok(Err(err)) => { - error!("user shell command failed: {err:?}"); - let message = format!("execution error: {err:?}"); - let exec_output = ExecToolCallOutput { - exit_code: -1, - stdout: StreamOutput::new(String::new()), - stderr: StreamOutput::new(message.clone()), - aggregated_output: StreamOutput::new(message.clone()), - duration: Duration::ZERO, - timed_out: false, - }; - session - .send_event( - turn_context.as_ref(), - EventMsg::ExecCommandEnd(ExecCommandEndEvent { - call_id, - process_id: None, - turn_id: turn_context.sub_id.clone(), - command: display_command, - cwd, - parsed_cmd, - source: ExecCommandSource::UserShell, - interaction_input: None, - stdout: exec_output.stdout.text.clone(), - stderr: exec_output.stderr.text.clone(), - aggregated_output: exec_output.aggregated_output.text.clone(), - exit_code: exec_output.exit_code, - duration: exec_output.duration, - formatted_output: format_exec_output_str( - &exec_output, - turn_context.truncation_policy, - ), - }), - ) - .await; - let output_items = [user_shell_command_record_item( - &raw_command, - &exec_output, - &turn_context, - )]; - session - .record_conversation_items(turn_context.as_ref(), &output_items) - .await; - } - } - None + if let Err(items) = session + .inject_response_items(vec![response_input_item]) + .await + { + let response_items = items + .into_iter() + .map(ResponseItem::from) + .collect::>(); + session + .record_conversation_items(turn_context, &response_items) + .await; } } diff --git a/codex-rs/core/tests/suite/user_shell_cmd.rs b/codex-rs/core/tests/suite/user_shell_cmd.rs index 45c91126d13..88a5caba4c2 100644 --- a/codex-rs/core/tests/suite/user_shell_cmd.rs +++ b/codex-rs/core/tests/suite/user_shell_cmd.rs @@ -1,5 +1,6 @@ use anyhow::Context; use codex_core::features::Feature; +use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::ExecCommandEndEvent; use codex_core::protocol::ExecCommandSource; @@ -7,6 +8,8 @@ use codex_core::protocol::ExecOutputStream; use codex_core::protocol::Op; use codex_core::protocol::SandboxPolicy; use codex_core::protocol::TurnAbortReason; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::user_input::UserInput; use core_test_support::assert_regex_match; use core_test_support::responses; use core_test_support::responses::ev_assistant_message; @@ -23,6 +26,8 @@ use core_test_support::wait_for_event_match; use regex_lite::escape; use std::path::PathBuf; use tempfile::TempDir; +use tokio::time::Duration; +use tokio::time::timeout; #[tokio::test] async fn user_shell_cmd_ls_and_cat_in_temp_dir() { @@ -119,6 +124,115 @@ async fn user_shell_cmd_can_be_interrupted() { assert_eq!(ev.reason, TurnAbortReason::Interrupted); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn user_shell_command_does_not_replace_active_turn() -> anyhow::Result<()> { + let server = start_mock_server().await; + let mut builder = test_codex().with_model("gpt-5.1"); + let fixture = builder.build(&server).await?; + + let call_id = "active-turn-shell-call"; + let args = if cfg!(windows) { + serde_json::json!({ + "command": "Start-Sleep -Seconds 2; Write-Output model-shell", + "timeout_ms": 10_000, + }) + } else { + serde_json::json!({ + "command": "sleep 2; echo model-shell", + "timeout_ms": 10_000, + }) + }; + let first = sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]); + let second = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]); + let mock = responses::mount_sse_sequence(&server, vec![first, second]).await; + + fixture + .codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "run model shell command".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: fixture.session_configured.model.clone(), + effort: None, + summary: ReasoningSummary::Auto, + collaboration_mode: None, + personality: None, + }) + .await?; + + let _ = wait_for_event_match(&fixture.codex, |ev| match ev { + EventMsg::ExecCommandBegin(event) if event.source == ExecCommandSource::Agent => { + Some(event.clone()) + } + _ => None, + }) + .await; + + #[cfg(windows)] + let user_shell_command = "Write-Output user-shell".to_string(); + #[cfg(not(windows))] + let user_shell_command = "printf user-shell".to_string(); + fixture + .codex + .submit(Op::RunUserShellCommand { + command: user_shell_command, + }) + .await?; + + let mut saw_replaced_abort = false; + let mut saw_user_shell_end = false; + let mut saw_turn_complete = false; + for _ in 0..200 { + let event = timeout(Duration::from_secs(20), fixture.codex.next_event()) + .await + .context("timed out waiting for event")? + .context("event stream ended unexpectedly")?; + match event.msg { + EventMsg::TurnAborted(ev) if ev.reason == TurnAbortReason::Replaced => { + saw_replaced_abort = true; + } + EventMsg::ExecCommandEnd(ev) if ev.source == ExecCommandSource::UserShell => { + saw_user_shell_end = true; + } + EventMsg::TurnComplete(_) => { + saw_turn_complete = true; + break; + } + _ => {} + } + } + + assert!(saw_turn_complete, "expected turn to complete"); + assert!( + saw_user_shell_end, + "expected user shell command to finish while turn was active" + ); + assert!( + !saw_replaced_abort, + "user shell command should not replace the active turn" + ); + + assert_eq!( + mock.requests().len(), + 2, + "active turn should continue and issue the follow-up model request" + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn user_shell_command_history_is_persisted_and_shared_with_model() -> anyhow::Result<()> { let server = responses::start_mock_server().await;