Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::tools::spec::create_tools_json_for_responses_api;
use crate::transport_manager::TransportManager;

pub const WEB_SEARCH_ELIGIBLE_HEADER: &str = "x-oai-web-search-eligible";
pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
Expand All @@ -80,6 +81,7 @@ struct ModelClientState {
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
session_source: SessionSource,
transport_manager: TransportManager,
}

#[derive(Debug, Clone)]
Expand All @@ -91,6 +93,7 @@ pub struct ModelClientSession {
state: Arc<ModelClientState>,
connection: Option<ApiWebSocketConnection>,
websocket_last_items: Vec<ResponseItem>,
transport_manager: TransportManager,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
Expand All @@ -116,6 +119,7 @@ impl ModelClient {
summary: ReasoningSummaryConfig,
conversation_id: ThreadId,
session_source: SessionSource,
transport_manager: TransportManager,
) -> Self {
Self {
state: Arc::new(ModelClientState {
Expand All @@ -128,6 +132,7 @@ impl ModelClient {
effort,
summary,
session_source,
transport_manager,
}),
}
}
Expand All @@ -137,6 +142,7 @@ impl ModelClient {
state: Arc::clone(&self.state),
connection: None,
websocket_last_items: Vec::new(),
transport_manager: self.state.transport_manager.clone(),
turn_state: Arc::new(OnceLock::new()),
}
}
Expand Down Expand Up @@ -171,6 +177,10 @@ impl ModelClient {
self.state.session_source.clone()
}

pub(crate) fn transport_manager(&self) -> TransportManager {
self.state.transport_manager.clone()
}

/// Returns the currently configured model slug.
pub fn get_model(&self) -> String {
self.state.model_info.slug.clone()
Expand Down Expand Up @@ -250,7 +260,10 @@ impl ModelClientSession {
/// For Chat providers, the underlying stream is optionally aggregated
/// based on the `show_raw_agent_reasoning` flag in the config.
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
match self.state.provider.wire_api {
let wire_api = self
.transport_manager
.effective_wire_api(self.state.provider.wire_api);
match wire_api {
WireApi::Responses => self.stream_responses_api(prompt).await,
WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await,
WireApi::Chat => {
Expand All @@ -271,6 +284,24 @@ impl ModelClientSession {
}
}

/// Attempts to switch this session to the HTTP fallback transport.
///
/// Asks the transport manager to activate the HTTP fallback for the
/// provider's configured wire API. When activation succeeds, a warning is
/// logged, the `codex.transport.fallback_to_http` counter is incremented, and
/// the session's websocket state (`connection` and `websocket_last_items`)
/// is reset.
///
/// Returns `true` if this call activated the fallback, `false` otherwise
/// (in which case nothing on the session is modified).
pub(crate) fn try_switch_fallback_transport(&mut self) -> bool {
    let provider_wire_api = self.state.provider.wire_api;
    if !self
        .transport_manager
        .activate_http_fallback(provider_wire_api)
    {
        return false;
    }

    warn!("falling back to HTTP");
    self.state.otel_manager.counter(
        "codex.transport.fallback_to_http",
        1,
        &[("from_wire_api", "responses_websocket")],
    );

    // Discard any websocket state held by this session.
    self.connection = None;
    self.websocket_last_items.clear();
    true
}

fn build_responses_request(&self, prompt: &Prompt) -> Result<ApiPrompt> {
let instructions = prompt.base_instructions.text.clone();
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
Expand Down
25 changes: 24 additions & 1 deletion codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::stream_events_utils::HandleOutputCtx;
use crate::stream_events_utils::handle_non_tool_response_item;
use crate::stream_events_utils::handle_output_item_done;
use crate::terminal;
use crate::transport_manager::TransportManager;
use crate::truncate::TruncationPolicy;
use crate::user_notification::UserNotifier;
use crate::util::error_or_panic;
Expand Down Expand Up @@ -611,6 +612,7 @@ impl Session {
model_info: ModelInfo,
conversation_id: ThreadId,
sub_id: String,
transport_manager: TransportManager,
) -> TurnContext {
let otel_manager = otel_manager.clone().with_model(
session_configuration.collaboration_mode.model(),
Expand All @@ -627,6 +629,7 @@ impl Session {
session_configuration.model_reasoning_summary,
conversation_id,
session_configuration.session_source.clone(),
transport_manager,
);

let tools_config = ToolsConfig::new(&ToolsConfigParams {
Expand Down Expand Up @@ -856,6 +859,7 @@ impl Session {
skills_manager,
agent_control,
state_db: state_db_ctx.clone(),
transport_manager: TransportManager::new(),
};

let sess = Arc::new(Session {
Expand Down Expand Up @@ -1175,6 +1179,7 @@ impl Session {
model_info,
self.conversation_id,
sub_id,
self.services.transport_manager.clone(),
);
if let Some(final_schema) = final_output_json_schema {
turn_context.final_output_json_schema = final_schema;
Expand Down Expand Up @@ -3029,6 +3034,7 @@ async fn spawn_review_thread(
per_turn_config.model_reasoning_summary,
sess.conversation_id,
parent_turn_context.client.get_session_source(),
parent_turn_context.client.transport_manager(),
);

let review_turn_context = TurnContext {
Expand Down Expand Up @@ -3445,7 +3451,9 @@ async fn run_sampling_request(
)
.await
{
Ok(output) => return Ok(output),
Ok(output) => {
return Ok(output);
}
Err(CodexErr::ContextWindowExceeded) => {
sess.set_total_tokens_full(&turn_context).await;
return Err(CodexErr::ContextWindowExceeded);
Expand All @@ -3466,6 +3474,17 @@ async fn run_sampling_request(

// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.client.get_provider().stream_max_retries();
if retries >= max_retries && client_session.try_switch_fallback_transport() {
sess.send_event(
&turn_context,
EventMsg::Warning(WarningEvent {
message: format!("Falling back from WebSockets to HTTPS transport. {err:#}"),
}),
)
.await;
retries = 0;
continue;
}
if retries < max_retries {
retries += 1;
let delay = match &err {
Expand Down Expand Up @@ -4599,6 +4618,7 @@ mod tests {
skills_manager,
agent_control,
state_db: None,
transport_manager: TransportManager::new(),
};

let turn_context = Session::make_turn_context(
Expand All @@ -4610,6 +4630,7 @@ mod tests {
model_info,
conversation_id,
"turn_id".to_string(),
services.transport_manager.clone(),
);

let session = Session {
Expand Down Expand Up @@ -4711,6 +4732,7 @@ mod tests {
skills_manager,
agent_control,
state_db: None,
transport_manager: TransportManager::new(),
};

let turn_context = Arc::new(Session::make_turn_context(
Expand All @@ -4722,6 +4744,7 @@ mod tests {
model_info,
conversation_id,
"turn_id".to_string(),
services.transport_manager.clone(),
));

let session = Arc::new(Session {
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub mod landlock;
pub mod mcp;
mod mcp_connection_manager;
pub mod models_manager;
mod transport_manager;
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD;
pub use mcp_connection_manager::SandboxState;
Expand Down Expand Up @@ -112,6 +113,7 @@ pub use rollout::list::parse_cursor;
pub use rollout::list::read_head_for_summary;
pub use rollout::list::read_session_meta_line;
pub use rollout::rollout_date_parts;
pub use transport_manager::TransportManager;
mod function_tool;
mod state;
mod tasks;
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/core/src/state/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::models_manager::manager::ModelsManager;
use crate::skills::SkillsManager;
use crate::state_db::StateDbHandle;
use crate::tools::sandboxing::ApprovalStore;
use crate::transport_manager::TransportManager;
use crate::unified_exec::UnifiedExecProcessManager;
use crate::user_notification::UserNotifier;
use codex_otel::OtelManager;
Expand All @@ -32,4 +33,5 @@ pub(crate) struct SessionServices {
pub(crate) skills_manager: Arc<SkillsManager>,
pub(crate) agent_control: AgentControl,
pub(crate) state_db: Option<StateDbHandle>,
pub(crate) transport_manager: TransportManager,
}
3 changes: 3 additions & 0 deletions codex-rs/core/src/tools/handlers/collab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ mod tests {
turn.client.get_reasoning_summary(),
session.conversation_id,
session_source,
session.services.transport_manager.clone(),
);

let invocation = invocation(
Expand Down Expand Up @@ -1221,6 +1222,7 @@ mod tests {
let mut base_config = (*turn.client.config()).clone();
base_config.user_instructions = Some("base-user".to_string());
turn.user_instructions = Some("resolved-user".to_string());
let transport_manager = turn.client.transport_manager();
turn.client = ModelClient::new(
Arc::new(base_config.clone()),
Some(session.services.auth_manager.clone()),
Expand All @@ -1231,6 +1233,7 @@ mod tests {
turn.client.get_reasoning_summary(),
session.conversation_id,
session_source,
transport_manager,
);
let base_instructions = BaseInstructions {
text: "base".to_string(),
Expand Down
31 changes: 31 additions & 0 deletions codex-rs/core/src/transport_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

use crate::model_provider_info::WireApi;

/// Tracks whether the HTTP fallback transport has been activated.
///
/// The flag lives behind an `Arc`, so cloning a `TransportManager` yields a
/// handle to the same underlying flag rather than an independent copy.
#[derive(Clone, Debug, Default)]
pub struct TransportManager {
// Set to `true` once the HTTP fallback has been activated; `Default` starts it at `false`.
fallback_to_http: Arc<AtomicBool>,
}

impl TransportManager {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in love with this single-field struct that has to be dragged everywhere, but we need a bit that persists throughout the session.

pub fn new() -> Self {
Self::default()
}

/// Resolves the wire API that should actually be used for a request.
///
/// When the HTTP fallback flag is set and the provider is configured for
/// `WireApi::ResponsesWebsocket`, this returns `WireApi::Responses`.
/// In every other case `provider_wire_api` is returned unchanged.
pub fn effective_wire_api(&self, provider_wire_api: WireApi) -> WireApi {
    let fallback_active = self.fallback_to_http.load(Ordering::Relaxed);
    match provider_wire_api {
        WireApi::ResponsesWebsocket if fallback_active => WireApi::Responses,
        unchanged => unchanged,
    }
}

/// Sets the HTTP fallback flag for a websocket-configured provider.
///
/// Returns `true` only when `provider_wire_api` is
/// `WireApi::ResponsesWebsocket` and the flag was not already set, i.e. when
/// this call is the one that flipped it. For any other wire API the flag is
/// left untouched and `false` is returned.
pub fn activate_http_fallback(&self, provider_wire_api: WireApi) -> bool {
    if provider_wire_api != WireApi::ResponsesWebsocket {
        return false;
    }
    // `swap` hands back the previous value, so only the first caller sees `false`.
    let already_active = self.fallback_to_http.swap(true, Ordering::Relaxed);
    !already_active
}
}
2 changes: 2 additions & 0 deletions codex-rs/core/tests/chat_completions_payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use codex_core::ModelClient;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseItem;
use codex_core::TransportManager;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
use codex_otel::OtelManager;
Expand Down Expand Up @@ -98,6 +99,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
summary,
conversation_id,
SessionSource::Exec,
TransportManager::new(),
)
.new_session();

Expand Down
2 changes: 2 additions & 0 deletions codex-rs/core/tests/chat_completions_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::TransportManager;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
use codex_otel::OtelManager;
Expand Down Expand Up @@ -99,6 +100,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
summary,
conversation_id,
SessionSource::Exec,
TransportManager::new(),
)
.new_session();

Expand Down
4 changes: 4 additions & 0 deletions codex-rs/core/tests/responses_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::TransportManager;
use codex_core::WEB_SEARCH_ELIGIBLE_HEADER;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
Expand Down Expand Up @@ -94,6 +95,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
summary,
conversation_id,
session_source,
TransportManager::new(),
)
.new_session();

Expand Down Expand Up @@ -191,6 +193,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
summary,
conversation_id,
session_source,
TransportManager::new(),
)
.new_session();

Expand Down Expand Up @@ -346,6 +349,7 @@ async fn responses_respects_model_info_overrides_from_config() {
summary,
conversation_id,
session_source,
TransportManager::new(),
)
.new_session();

Expand Down
2 changes: 2 additions & 0 deletions codex-rs/core/tests/suite/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::ThreadManager;
use codex_core::TransportManager;
use codex_core::WireApi;
use codex_core::auth::AuthCredentialsStoreMode;
use codex_core::built_in_model_providers;
Expand Down Expand Up @@ -1186,6 +1187,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
summary,
conversation_id,
SessionSource::Exec,
TransportManager::new(),
)
.new_session();

Expand Down
2 changes: 2 additions & 0 deletions codex-rs/core/tests/suite/client_websockets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::TransportManager;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
use codex_core::protocol::SessionSource;
Expand Down Expand Up @@ -228,6 +229,7 @@ async fn websocket_harness(server: &WebSocketTestServer) -> WebsocketTestHarness
ReasoningSummary::Auto,
conversation_id,
SessionSource::Exec,
TransportManager::new(),
);

WebsocketTestHarness {
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/tests/suite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,4 @@ mod user_notification;
mod user_shell_cmd;
mod view_image;
mod web_search_cached;
mod websocket_fallback;
Loading
Loading