Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 7 additions & 54 deletions codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::RwLock;

use crate::api_bridge::CoreAuthProvider;
use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
use crate::auth::UnauthorizedRecovery;
use crate::turn_metadata::build_turn_metadata_header;
use codex_api::CompactClient as ApiCompactClient;
use codex_api::CompactionInput as ApiCompactionInput;
use codex_api::Prompt as ApiPrompt;
Expand Down Expand Up @@ -73,12 +70,6 @@ pub const WEB_SEARCH_ELIGIBLE_HEADER: &str = "x-oai-web-search-eligible";
pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";

#[derive(Debug, Default)]
struct TurnMetadataCache {
cwd: Option<PathBuf>,
header: Option<HeaderValue>,
}

#[derive(Debug)]
struct ModelClientState {
config: Arc<Config>,
Expand All @@ -91,7 +82,6 @@ struct ModelClientState {
summary: ReasoningSummaryConfig,
session_source: SessionSource,
transport_manager: TransportManager,
turn_metadata_cache: Arc<RwLock<TurnMetadataCache>>,
}

#[derive(Debug, Clone)]
Expand All @@ -104,6 +94,7 @@ pub struct ModelClientSession {
connection: Option<ApiWebSocketConnection>,
websocket_last_items: Vec<ResponseItem>,
transport_manager: TransportManager,
turn_metadata_header: Option<String>,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
Expand Down Expand Up @@ -143,53 +134,20 @@ impl ModelClient {
summary,
session_source,
transport_manager,
turn_metadata_cache: Arc::new(RwLock::new(TurnMetadataCache::default())),
}),
}
}

pub fn new_session(&self, turn_metadata_cwd: Option<PathBuf>) -> ModelClientSession {
self.prewarm_turn_metadata_header(turn_metadata_cwd);
pub fn new_session(&self, turn_metadata_header: Option<String>) -> ModelClientSession {
ModelClientSession {
state: Arc::clone(&self.state),
connection: None,
websocket_last_items: Vec::new(),
transport_manager: self.state.transport_manager.clone(),
turn_metadata_header,
turn_state: Arc::new(OnceLock::new()),
}
}

/// Refresh turn metadata in the background and update a cached header that request
/// builders can read without blocking.
fn prewarm_turn_metadata_header(&self, turn_metadata_cwd: Option<PathBuf>) {
let turn_metadata_cwd =
turn_metadata_cwd.map(|cwd| std::fs::canonicalize(&cwd).unwrap_or(cwd));

if let Ok(mut cache) = self.state.turn_metadata_cache.write()
&& cache.cwd != turn_metadata_cwd
{
cache.cwd = turn_metadata_cwd.clone();
cache.header = None;
}

let Some(cwd) = turn_metadata_cwd else {
return;
};
let turn_metadata_cache = Arc::clone(&self.state.turn_metadata_cache);
if let Ok(handle) = tokio::runtime::Handle::try_current() {
let _task = handle.spawn(async move {
let header = build_turn_metadata_header(cwd.as_path())
.await
.and_then(|value| HeaderValue::from_str(value.as_str()).ok());

if let Ok(mut cache) = turn_metadata_cache.write()
&& cache.cwd.as_ref() == Some(&cwd)
{
cache.header = header;
}
});
}
}
}

impl ModelClient {
Expand Down Expand Up @@ -298,14 +256,6 @@ impl ModelClient {
}

impl ModelClientSession {
fn turn_metadata_header(&self) -> Option<HeaderValue> {
self.state
.turn_metadata_cache
.try_read()
.ok()
.and_then(|cache| cache.header.clone())
}

/// Streams a single model turn using the configured Responses transport.
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
let wire_api = self.state.provider.wire_api;
Expand Down Expand Up @@ -362,7 +312,10 @@ impl ModelClientSession {
prompt: &Prompt,
compression: Compression,
) -> ApiResponsesOptions {
let turn_metadata_header = self.turn_metadata_header();
let turn_metadata_header = self
.turn_metadata_header
.as_deref()
.and_then(|value| HeaderValue::from_str(value).ok());
let model_info = &self.state.model_info;

let default_reasoning_effort = model_info.default_reasoning_level;
Expand Down
51 changes: 46 additions & 5 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use crate::stream_events_utils::last_assistant_message_from_item;
use crate::terminal;
use crate::transport_manager::TransportManager;
use crate::truncate::TruncationPolicy;
use crate::turn_metadata::build_turn_metadata_header;
use crate::user_notification::UserNotifier;
use crate::util::error_or_panic;
use async_channel::Receiver;
Expand Down Expand Up @@ -80,6 +81,7 @@ use rmcp::model::RequestId;
use serde_json;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::OnceCell;
use tokio::sync::RwLock;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
Expand All @@ -90,6 +92,7 @@ use tracing::field;
use tracing::info;
use tracing::info_span;
use tracing::instrument;
use tracing::trace;
use tracing::trace_span;
use tracing::warn;

Expand Down Expand Up @@ -501,6 +504,7 @@ pub(crate) struct TurnContext {
pub(crate) tool_call_gate: Arc<ReadinessFlag>,
pub(crate) truncation_policy: TruncationPolicy,
pub(crate) dynamic_tools: Vec<DynamicToolSpec>,
turn_metadata_header: OnceCell<Option<String>>,
}
impl TurnContext {
pub(crate) fn resolve_path(&self, path: Option<String>) -> PathBuf {
Expand All @@ -514,6 +518,38 @@ impl TurnContext {
.as_deref()
.unwrap_or(compact::SUMMARIZATION_PROMPT)
}

async fn build_turn_metadata_header(&self) -> Option<String> {
self.turn_metadata_header
.get_or_init(|| async { build_turn_metadata_header(self.cwd.as_path()).await })
.await
.clone()
}

pub async fn resolve_turn_metadata_header(&self) -> Option<String> {
const TURN_METADATA_HEADER_TIMEOUT_MS: u64 = 250;
match tokio::time::timeout(
std::time::Duration::from_millis(TURN_METADATA_HEADER_TIMEOUT_MS),
self.build_turn_metadata_header(),
)
.await
{
Ok(header) => header,
Err(_) => {
warn!("timed out after 250ms while building turn metadata header");
self.turn_metadata_header.get().cloned().flatten()
}
}
}

pub fn spawn_turn_metadata_header_task(self: &Arc<Self>) {
let context = Arc::clone(self);
tokio::spawn(async move {
trace!("Spawning turn metadata calculation task");
context.build_turn_metadata_header().await;
trace!("Turn metadata calculation task completed");
});
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -682,10 +718,11 @@ impl Session {
web_search_mode: per_turn_config.web_search_mode,
});

let cwd = session_configuration.cwd.clone();
TurnContext {
sub_id,
client,
cwd: session_configuration.cwd.clone(),
cwd,
developer_instructions: session_configuration.developer_instructions.clone(),
compact_prompt: session_configuration.compact_prompt.clone(),
user_instructions: session_configuration.user_instructions.clone(),
Expand All @@ -702,6 +739,7 @@ impl Session {
tool_call_gate: Arc::new(ReadinessFlag::new()),
truncation_policy: model_info.truncation_policy.into(),
dynamic_tools: session_configuration.dynamic_tools.clone(),
turn_metadata_header: OnceCell::new(),
}
}

Expand Down Expand Up @@ -1246,10 +1284,13 @@ impl Session {
sub_id,
self.services.transport_manager.clone(),
);

if let Some(final_schema) = final_output_json_schema {
turn_context.final_output_json_schema = final_schema;
}
Arc::new(turn_context)
let turn_context = Arc::new(turn_context);
turn_context.spawn_turn_metadata_header_task();
turn_context
}

pub(crate) async fn new_default_turn(&self) -> Arc<TurnContext> {
Expand Down Expand Up @@ -3274,6 +3315,7 @@ async fn spawn_review_thread(
tool_call_gate: Arc::new(ReadinessFlag::new()),
dynamic_tools: parent_turn_context.dynamic_tools.clone(),
truncation_policy: model_info.truncation_policy.into(),
turn_metadata_header: parent_turn_context.turn_metadata_header.clone(),
};

// Seed the child task with the review prompt as the initial user message.
Expand Down Expand Up @@ -3478,9 +3520,8 @@ pub(crate) async fn run_turn(
// many turns, from the perspective of the user, it is a single turn.
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));

let mut client_session = turn_context
.client
.new_session(Some(turn_context.cwd.clone()));
let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
let mut client_session = turn_context.client.new_session(turn_metadata_header);

loop {
// Note that pending_input would be something like a message the user
Expand Down
5 changes: 2 additions & 3 deletions codex-rs/core/src/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,8 @@ async fn drain_to_completed(
turn_context: &TurnContext,
prompt: &Prompt,
) -> CodexResult<()> {
let mut client_session = turn_context
.client
.new_session(Some(turn_context.cwd.clone()));
let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
let mut client_session = turn_context.client.new_session(turn_metadata_header);
let mut stream = client_session.stream(prompt).await?;
loop {
let maybe_event = stream.next().await;
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ pub use exec_policy::check_execpolicy_for_warnings;
pub use exec_policy::load_exec_policy;
pub use safety::get_platform_sandbox;
pub use tools::spec::parse_tool_input_schema;
pub use turn_metadata::build_turn_metadata_header;
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
// `codex_core::protocol::...` references continue to work across the workspace.
pub use codex_protocol::protocol;
Expand Down
2 changes: 1 addition & 1 deletion codex-rs/core/src/turn_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct TurnMetadata {
workspaces: BTreeMap<String, TurnMetadataWorkspace>,
}

pub(crate) async fn build_turn_metadata_header(cwd: &Path) -> Option<String> {
pub async fn build_turn_metadata_header(cwd: &Path) -> Option<String> {
let repo_root = get_git_repo_root(cwd)?;

let (latest_git_commit_hash, associated_remote_urls) = tokio::join!(
Expand Down
Loading
Loading