Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1686,6 +1686,17 @@ impl Session {
.map(|task| Arc::clone(&task.turn_context))
}

async fn active_turn_context_and_cancellation_token(
&self,
) -> Option<(Arc<TurnContext>, CancellationToken)> {
let active = self.active_turn.lock().await;
let (_, task) = active.as_ref()?.tasks.first()?;
Some((
Arc::clone(&task.turn_context),
task.cancellation_token.child_token(),
))
}

pub(crate) async fn record_execpolicy_amendment_message(
&self,
sub_id: &str,
Expand Down Expand Up @@ -2716,7 +2727,9 @@ mod handlers {
use crate::tasks::CompactTask;
use crate::tasks::RegularTask;
use crate::tasks::UndoTask;
use crate::tasks::UserShellCommandMode;
use crate::tasks::UserShellCommandTask;
use crate::tasks::execute_user_shell_command;
use codex_protocol::custom_prompts::CustomPrompt;
use codex_protocol::protocol::CodexErrorInfo;
use codex_protocol::protocol::ErrorEvent;
Expand Down Expand Up @@ -2863,6 +2876,23 @@ mod handlers {
command: String,
previous_context: &mut Option<Arc<TurnContext>>,
) {
if let Some((turn_context, cancellation_token)) =
sess.active_turn_context_and_cancellation_token().await
{
let session = Arc::clone(sess);
tokio::spawn(async move {
execute_user_shell_command(
session,
turn_context,
command,
cancellation_token,
UserShellCommandMode::ActiveTurnAuxiliary,
)
.await;
});
return;
}

let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
sess.spawn_task(
Arc::clone(&turn_context),
Expand Down Expand Up @@ -4807,6 +4837,7 @@ mod tests {
use codex_app_server_protocol::AuthMode;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use std::path::Path;
use std::time::Duration;
Expand Down Expand Up @@ -6044,6 +6075,50 @@ mod tests {
assert!(rx.try_recv().is_err());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn task_finish_persists_leftover_pending_input() {
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
let input = vec![UserInput::Text {
text: "hello".to_string(),
text_elements: Vec::new(),
}];
sess.spawn_task(
Arc::clone(&tc),
input,
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: false,
},
)
.await;

sess.inject_response_items(vec![ResponseInputItem::Message {
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "late pending input".to_string(),
}],
}])
.await
.expect("inject pending input into active turn");

sess.on_task_finished(Arc::clone(&tc), None).await;

let history = sess.clone_history().await;
let expected = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "late pending input".to_string(),
}],
end_turn: None,
phase: None,
};
assert!(
history.raw_items().iter().any(|item| item == &expected),
"expected pending input to be persisted into history on turn completion"
);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
let (sess, tc, rx) = make_session_and_context_with_rx().await;
Expand Down
25 changes: 20 additions & 5 deletions codex-rs/core/src/tasks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::state::ActiveTurn;
use crate::state::RunningTask;
use crate::state::TaskKind;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::user_input::UserInput;
Expand All @@ -40,7 +41,9 @@ pub(crate) use ghost_snapshot::GhostSnapshotTask;
pub(crate) use regular::RegularTask;
pub(crate) use review::ReviewTask;
pub(crate) use undo::UndoTask;
pub(crate) use user_shell::UserShellCommandMode;
pub(crate) use user_shell::UserShellCommandTask;
pub(crate) use user_shell::execute_user_shell_command;

const GRACEFULL_INTERRUPTION_TIMEOUT_MS: u64 = 100;
const TURN_ABORTED_INTERRUPTED_GUIDANCE: &str = "The user interrupted the previous turn on purpose. If any tools/commands were aborted, they may have partially executed; verify current state before retrying.";
Expand Down Expand Up @@ -187,15 +190,27 @@ impl Session {
last_agent_message: Option<String>,
) {
let mut active = self.active_turn.lock().await;
let should_close_processes = if let Some(at) = active.as_mut()
let mut pending_input = Vec::<ResponseInputItem>::new();
let mut should_close_processes = false;
if let Some(at) = active.as_mut()
&& at.remove_task(&turn_context.sub_id)
{
let mut ts = at.turn_state.lock().await;
pending_input = ts.take_pending_input();
should_close_processes = true;
}
if should_close_processes {
*active = None;
true
} else {
false
};
}
drop(active);
if !pending_input.is_empty() {
let pending_response_items = pending_input
.into_iter()
.map(ResponseItem::from)
.collect::<Vec<_>>();
self.record_conversation_items(turn_context.as_ref(), &pending_response_items)
.await;
}
if should_close_processes {
self.close_unified_exec_processes().await;
}
Expand Down
Loading
Loading