Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 24 additions & 22 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2511,17 +2511,17 @@ fn errors_to_info(errors: &[SkillError]) -> Vec<SkillErrorInfo> {
.collect()
}

/// Takes a user message as input and runs a loop where, at each turn, the model
/// Takes a user message as input and runs a loop where, at each sampling request, the model
/// replies with either:
///
/// - requested function calls
/// - an assistant message
///
/// While it is possible for the model to return multiple of these items in a
/// single turn, in practice, we generally one item per turn:
/// single sampling request, in practice, we generally see one item per sampling request:
///
/// - If the model requests a function call, we execute it and send the output
/// back to the model in the next turn.
/// back to the model in the next sampling request.
/// - If the model sends only an assistant message, we record it in the
/// conversation history and consider the turn complete.
///
Expand Down Expand Up @@ -2594,35 +2594,35 @@ pub(crate) async fn run_turn(
.collect::<Vec<ResponseItem>>();

// Construct the input that we will send to the model.
let turn_input: Vec<ResponseItem> = {
let sampling_request_input: Vec<ResponseItem> = {
sess.record_conversation_items(&turn_context, &pending_input)
.await;
sess.clone_history().await.for_prompt()
};

let turn_input_messages = turn_input
let sampling_request_input_messages = sampling_request_input
.iter()
.filter_map(|item| match parse_turn_item(item) {
Some(TurnItem::UserMessage(user_message)) => Some(user_message),
_ => None,
})
.map(|user_message| user_message.message())
.collect::<Vec<String>>();
match run_model_turn(
match run_sampling_request(
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
&mut client_session,
turn_input,
sampling_request_input,
cancellation_token.child_token(),
)
.await
{
Ok(turn_output) => {
let TurnRunResult {
Ok(sampling_request_output) => {
let SamplingRequestResult {
needs_follow_up,
last_agent_message: turn_last_agent_message,
} = turn_output;
last_agent_message: sampling_request_last_agent_message,
} = sampling_request_output;
let total_usage_tokens = sess.get_total_token_usage().await;
let token_limit_reached = total_usage_tokens >= auto_compact_limit;

Expand All @@ -2633,13 +2633,13 @@ pub(crate) async fn run_turn(
}

if !needs_follow_up {
last_agent_message = turn_last_agent_message;
last_agent_message = sampling_request_last_agent_message;
sess.notifier()
.notify(&UserNotification::AgentTurnComplete {
thread_id: sess.conversation_id.to_string(),
turn_id: turn_context.sub_id.clone(),
cwd: turn_context.cwd.display().to_string(),
input_messages: turn_input_messages,
input_messages: sampling_request_input_messages,
last_assistant_message: last_agent_message.clone(),
});
break;
Expand Down Expand Up @@ -2695,14 +2695,14 @@ async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>)
cwd = %turn_context.cwd.display()
)
)]
async fn run_model_turn(
async fn run_sampling_request(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
client_session: &mut ModelClientSession,
input: Vec<ResponseItem>,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
) -> CodexResult<SamplingRequestResult> {
let mcp_tools = sess
.services
.mcp_connection_manager
Expand Down Expand Up @@ -2736,7 +2736,7 @@ async fn run_model_turn(

let mut retries = 0;
loop {
let err = match try_run_turn(
let err = match try_run_sampling_request(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Expand Down Expand Up @@ -2776,7 +2776,9 @@ async fn run_model_turn(
}
_ => backoff(retries),
};
warn!("stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...",);
warn!(
"stream disconnected - retrying sampling request ({retries}/{max_retries} in {delay:?})...",
);

// Surface retry information to any UI/front‑end so the
// user understands what is happening instead of staring
Expand All @@ -2796,7 +2798,7 @@ async fn run_model_turn(
}

#[derive(Debug)]
struct TurnRunResult {
struct SamplingRequestResult {
needs_follow_up: bool,
last_agent_message: Option<String>,
}
Expand Down Expand Up @@ -2828,15 +2830,15 @@ async fn drain_in_flight(
model = %turn_context.client.get_model()
)
)]
async fn try_run_turn(
async fn try_run_sampling_request(
router: Arc<ToolRouter>,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
client_session: &mut ModelClientSession,
turn_diff_tracker: SharedTurnDiffTracker,
prompt: &Prompt,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
) -> CodexResult<SamplingRequestResult> {
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
Expand Down Expand Up @@ -2880,7 +2882,7 @@ async fn try_run_turn(
let mut active_item: Option<TurnItem> = None;
let mut should_emit_turn_diff = false;
let receiving_span = trace_span!("receiving_stream");
let outcome: CodexResult<TurnRunResult> = loop {
let outcome: CodexResult<SamplingRequestResult> = loop {
let handle_responses = trace_span!(
parent: &receiving_span,
"handle_responses",
Expand Down Expand Up @@ -2966,7 +2968,7 @@ async fn try_run_turn(

needs_follow_up |= sess.has_pending_input().await;

break Ok(TurnRunResult {
break Ok(SamplingRequestResult {
needs_follow_up,
last_agent_message,
});
Expand Down
Loading