diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index f415c3604a8..98fce871fd9 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -65,6 +65,7 @@ use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::WireApi; use crate::tools::spec::create_tools_json_for_chat_completions_api; use crate::tools::spec::create_tools_json_for_responses_api; +use crate::transport_manager::TransportManager; pub const WEB_SEARCH_ELIGIBLE_HEADER: &str = "x-oai-web-search-eligible"; pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state"; @@ -80,6 +81,7 @@ struct ModelClientState { effort: Option, summary: ReasoningSummaryConfig, session_source: SessionSource, + transport_manager: TransportManager, } #[derive(Debug, Clone)] @@ -91,6 +93,7 @@ pub struct ModelClientSession { state: Arc, connection: Option, websocket_last_items: Vec, + transport_manager: TransportManager, /// Turn state for sticky routing. /// /// This is an `OnceLock` that stores the turn state value received from the server @@ -116,6 +119,7 @@ impl ModelClient { summary: ReasoningSummaryConfig, conversation_id: ThreadId, session_source: SessionSource, + transport_manager: TransportManager, ) -> Self { Self { state: Arc::new(ModelClientState { @@ -128,6 +132,7 @@ impl ModelClient { effort, summary, session_source, + transport_manager, }), } } @@ -137,6 +142,7 @@ impl ModelClient { state: Arc::clone(&self.state), connection: None, websocket_last_items: Vec::new(), + transport_manager: self.state.transport_manager.clone(), turn_state: Arc::new(OnceLock::new()), } } @@ -171,6 +177,10 @@ impl ModelClient { self.state.session_source.clone() } + pub(crate) fn transport_manager(&self) -> TransportManager { + self.state.transport_manager.clone() + } + /// Returns the currently configured model slug. pub fn get_model(&self) -> String { self.state.model_info.slug.clone() @@ -250,7 +260,10 @@ impl ModelClientSession { /// For Chat providers, the underlying stream is optionally aggregated /// based on the `show_raw_agent_reasoning` flag in the config. pub async fn stream(&mut self, prompt: &Prompt) -> Result { - match self.state.provider.wire_api { + let wire_api = self + .transport_manager + .effective_wire_api(self.state.provider.wire_api); + match wire_api { WireApi::Responses => self.stream_responses_api(prompt).await, WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await, WireApi::Chat => { @@ -271,6 +284,24 @@ impl ModelClientSession { } } + pub(crate) fn try_switch_fallback_transport(&mut self) -> bool { + let activated = self + .transport_manager + .activate_http_fallback(self.state.provider.wire_api); + if activated { + warn!("falling back to HTTP"); + self.state.otel_manager.counter( + "codex.transport.fallback_to_http", + 1, + &[("from_wire_api", "responses_websocket")], + ); + + self.connection = None; + self.websocket_last_items.clear(); + } + activated + } + fn build_responses_request(&self, prompt: &Prompt) -> Result { let instructions = prompt.base_instructions.text.clone(); let tools_json: Vec = create_tools_json_for_responses_api(&prompt.tools)?; diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index c2518052723..a8e010d796c 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -30,6 +30,7 @@ use crate::stream_events_utils::HandleOutputCtx; use crate::stream_events_utils::handle_non_tool_response_item; use crate::stream_events_utils::handle_output_item_done; use crate::terminal; +use crate::transport_manager::TransportManager; use crate::truncate::TruncationPolicy; use crate::user_notification::UserNotifier; use crate::util::error_or_panic; @@ -611,6 +612,7 @@ impl Session { model_info: ModelInfo, conversation_id: ThreadId, sub_id: String, + transport_manager: TransportManager, ) -> TurnContext { let otel_manager = otel_manager.clone().with_model( session_configuration.collaboration_mode.model(), @@ -627,6 +629,7 @@ impl Session { session_configuration.model_reasoning_summary, conversation_id, session_configuration.session_source.clone(), + transport_manager, ); let tools_config = ToolsConfig::new(&ToolsConfigParams { @@ -856,6 +859,7 @@ impl Session { skills_manager, agent_control, state_db: state_db_ctx.clone(), + transport_manager: TransportManager::new(), }; let sess = Arc::new(Session { @@ -1175,6 +1179,7 @@ impl Session { model_info, self.conversation_id, sub_id, + self.services.transport_manager.clone(), ); if let Some(final_schema) = final_output_json_schema { turn_context.final_output_json_schema = final_schema; @@ -3029,6 +3034,7 @@ async fn spawn_review_thread( per_turn_config.model_reasoning_summary, sess.conversation_id, parent_turn_context.client.get_session_source(), + parent_turn_context.client.transport_manager(), ); let review_turn_context = TurnContext { @@ -3445,7 +3451,9 @@ async fn run_sampling_request( ) .await { - Ok(output) => return Ok(output), + Ok(output) => { + return Ok(output); + } Err(CodexErr::ContextWindowExceeded) => { sess.set_total_tokens_full(&turn_context).await; return Err(CodexErr::ContextWindowExceeded); @@ -3466,6 +3474,17 @@ async fn run_sampling_request( // Use the configured provider-specific stream retry budget. let max_retries = turn_context.client.get_provider().stream_max_retries(); + if retries >= max_retries && client_session.try_switch_fallback_transport() { + sess.send_event( + &turn_context, + EventMsg::Warning(WarningEvent { + message: format!("Falling back from WebSockets to HTTPS transport. {err:#}"), + }), + ) + .await; + retries = 0; + continue; + } if retries < max_retries { retries += 1; let delay = match &err { @@ -4599,6 +4618,7 @@ mod tests { skills_manager, agent_control, state_db: None, + transport_manager: TransportManager::new(), }; let turn_context = Session::make_turn_context( @@ -4610,6 +4630,7 @@ mod tests { model_info, conversation_id, "turn_id".to_string(), + services.transport_manager.clone(), ); let session = Session { @@ -4711,6 +4732,7 @@ mod tests { skills_manager, agent_control, state_db: None, + transport_manager: TransportManager::new(), }; let turn_context = Arc::new(Session::make_turn_context( @@ -4722,6 +4744,7 @@ mod tests { model_info, conversation_id, "turn_id".to_string(), + services.transport_manager.clone(), )); let session = Arc::new(Session { diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index eaf25d14e8e..454fb22b329 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -38,6 +38,7 @@ pub mod landlock; pub mod mcp; mod mcp_connection_manager; pub mod models_manager; +mod transport_manager; pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY; pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD; pub use mcp_connection_manager::SandboxState; @@ -112,6 +113,7 @@ pub use rollout::list::parse_cursor; pub use rollout::list::read_head_for_summary; pub use rollout::list::read_session_meta_line; pub use rollout::rollout_date_parts; +pub use transport_manager::TransportManager; mod function_tool; mod state; mod tasks; diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index 6559ec2e182..e036b29b7d7 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -9,6 +9,7 @@ use crate::models_manager::manager::ModelsManager; use crate::skills::SkillsManager; use crate::state_db::StateDbHandle; use crate::tools::sandboxing::ApprovalStore; +use crate::transport_manager::TransportManager; use crate::unified_exec::UnifiedExecProcessManager; use crate::user_notification::UserNotifier; use codex_otel::OtelManager; @@ -32,4 +33,5 @@ pub(crate) struct SessionServices { pub(crate) skills_manager: Arc, pub(crate) agent_control: AgentControl, pub(crate) state_db: Option, + pub(crate) transport_manager: TransportManager, } diff --git a/codex-rs/core/src/tools/handlers/collab.rs b/codex-rs/core/src/tools/handlers/collab.rs index b1666949f61..61ccc1932e9 100644 --- a/codex-rs/core/src/tools/handlers/collab.rs +++ b/codex-rs/core/src/tools/handlers/collab.rs @@ -781,6 +781,7 @@ mod tests { turn.client.get_reasoning_summary(), session.conversation_id, session_source, + session.services.transport_manager.clone(), ); let invocation = invocation( @@ -1221,6 +1222,7 @@ mod tests { let mut base_config = (*turn.client.config()).clone(); base_config.user_instructions = Some("base-user".to_string()); turn.user_instructions = Some("resolved-user".to_string()); + let transport_manager = turn.client.transport_manager(); turn.client = ModelClient::new( Arc::new(base_config.clone()), Some(session.services.auth_manager.clone()), @@ -1231,6 +1233,7 @@ mod tests { turn.client.get_reasoning_summary(), session.conversation_id, session_source, + transport_manager, ); let base_instructions = BaseInstructions { text: "base".to_string(), diff --git a/codex-rs/core/src/transport_manager.rs b/codex-rs/core/src/transport_manager.rs new file mode 100644 index 00000000000..f1fe68faf4a --- /dev/null +++ b/codex-rs/core/src/transport_manager.rs @@ -0,0 +1,31 @@ +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; + +use crate::model_provider_info::WireApi; + +#[derive(Clone, Debug, Default)] +pub struct TransportManager { + fallback_to_http: Arc, +} + +impl TransportManager { + pub fn new() -> Self { + Self::default() + } + + pub fn effective_wire_api(&self, provider_wire_api: WireApi) -> WireApi { + if self.fallback_to_http.load(Ordering::Relaxed) + && provider_wire_api == WireApi::ResponsesWebsocket + { + WireApi::Responses + } else { + provider_wire_api + } + } + + pub fn activate_http_fallback(&self, provider_wire_api: WireApi) -> bool { + provider_wire_api == WireApi::ResponsesWebsocket + && !self.fallback_to_http.swap(true, Ordering::Relaxed) + } +} diff --git a/codex-rs/core/tests/chat_completions_payload.rs b/codex-rs/core/tests/chat_completions_payload.rs index cdb92fe672c..5e8c895a3fb 100644 --- a/codex-rs/core/tests/chat_completions_payload.rs +++ b/codex-rs/core/tests/chat_completions_payload.rs @@ -11,6 +11,7 @@ use codex_core::ModelClient; use codex_core::ModelProviderInfo; use codex_core::Prompt; use codex_core::ResponseItem; +use codex_core::TransportManager; use codex_core::WireApi; use codex_core::models_manager::manager::ModelsManager; use codex_otel::OtelManager; @@ -98,6 +99,7 @@ async fn run_request(input: Vec) -> Value { summary, conversation_id, SessionSource::Exec, + TransportManager::new(), ) .new_session(); diff --git a/codex-rs/core/tests/chat_completions_sse.rs b/codex-rs/core/tests/chat_completions_sse.rs index 05ef476a057..84d0ede0855 100644 --- a/codex-rs/core/tests/chat_completions_sse.rs +++ b/codex-rs/core/tests/chat_completions_sse.rs @@ -10,6 +10,7 @@ use codex_core::ModelProviderInfo; use codex_core::Prompt; use codex_core::ResponseEvent; use codex_core::ResponseItem; +use codex_core::TransportManager; use codex_core::WireApi; use codex_core::models_manager::manager::ModelsManager; use codex_otel::OtelManager; @@ -99,6 +100,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec { summary, conversation_id, SessionSource::Exec, + TransportManager::new(), ) .new_session(); diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs index 22d9fa8b796..b3c0165f46b 100644 --- a/codex-rs/core/tests/responses_headers.rs +++ b/codex-rs/core/tests/responses_headers.rs @@ -9,6 +9,7 @@ use codex_core::ModelProviderInfo; use codex_core::Prompt; use codex_core::ResponseEvent; use codex_core::ResponseItem; +use codex_core::TransportManager; use codex_core::WEB_SEARCH_ELIGIBLE_HEADER; use codex_core::WireApi; use codex_core::models_manager::manager::ModelsManager; @@ -94,6 +95,7 @@ async fn responses_stream_includes_subagent_header_on_review() { summary, conversation_id, session_source, + TransportManager::new(), ) .new_session(); @@ -191,6 +193,7 @@ async fn responses_stream_includes_subagent_header_on_other() { summary, conversation_id, session_source, + TransportManager::new(), ) .new_session(); @@ -346,6 +349,7 @@ async fn responses_respects_model_info_overrides_from_config() { summary, conversation_id, session_source, + TransportManager::new(), ) .new_session(); diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index bdc2f311c01..8b8adf06abb 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -11,6 +11,7 @@ use codex_core::Prompt; use codex_core::ResponseEvent; use codex_core::ResponseItem; use codex_core::ThreadManager; +use codex_core::TransportManager; use codex_core::WireApi; use codex_core::auth::AuthCredentialsStoreMode; use codex_core::built_in_model_providers; @@ -1186,6 +1187,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() { summary, conversation_id, SessionSource::Exec, + TransportManager::new(), ) .new_session(); diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs index d55d71d0d88..60229e01d6e 100644 --- a/codex-rs/core/tests/suite/client_websockets.rs +++ b/codex-rs/core/tests/suite/client_websockets.rs @@ -8,6 +8,7 @@ use codex_core::ModelProviderInfo; use codex_core::Prompt; use codex_core::ResponseEvent; use codex_core::ResponseItem; +use codex_core::TransportManager; use codex_core::WireApi; use codex_core::models_manager::manager::ModelsManager; use codex_core::protocol::SessionSource; @@ -228,6 +229,7 @@ async fn websocket_harness(server: &WebSocketTestServer) -> WebsocketTestHarness ReasoningSummary::Auto, conversation_id, SessionSource::Exec, + TransportManager::new(), ); WebsocketTestHarness { diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index e054070b344..0744624a928 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -80,3 +80,4 @@ mod user_notification; mod user_shell_cmd; mod view_image; mod web_search_cached; +mod websocket_fallback; diff --git a/codex-rs/core/tests/suite/websocket_fallback.rs b/codex-rs/core/tests/suite/websocket_fallback.rs new file mode 100644 index 00000000000..d52236a8a2d --- /dev/null +++ b/codex-rs/core/tests/suite/websocket_fallback.rs @@ -0,0 +1,98 @@ +use anyhow::Result; +use codex_core::WireApi; +use core_test_support::responses; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::mount_sse_once; +use core_test_support::responses::mount_sse_sequence; +use core_test_support::responses::sse; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::test_codex; +use pretty_assertions::assert_eq; +use wiremock::http::Method; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let response_mock = mount_sse_once( + &server, + sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]), + ) + .await; + + let mut builder = test_codex().with_config({ + let base_url = format!("{}/v1", server.uri()); + move |config| { + config.model_provider.base_url = Some(base_url); + config.model_provider.wire_api = WireApi::ResponsesWebsocket; + config.model_provider.stream_max_retries = Some(0); + config.model_provider.request_max_retries = Some(0); + } + }); + let test = builder.build(&server).await?; + + test.submit_turn("hello").await?; + + let requests = server.received_requests().await.unwrap_or_default(); + let websocket_attempts = requests + .iter() + .filter(|req| req.method == Method::GET && req.url.path().ends_with("/responses")) + .count(); + let http_attempts = requests + .iter() + .filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses")) + .count(); + + assert_eq!(websocket_attempts, 1); + assert_eq!(http_attempts, 1); + assert_eq!(response_mock.requests().len(), 1); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn websocket_fallback_is_sticky_across_turns() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let response_mock = mount_sse_sequence( + &server, + vec![ + sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]), + sse(vec![ev_response_created("resp-2"), ev_completed("resp-2")]), + ], + ) + .await; + + let mut builder = test_codex().with_config({ + let base_url = format!("{}/v1", server.uri()); + move |config| { + config.model_provider.base_url = Some(base_url); + config.model_provider.wire_api = WireApi::ResponsesWebsocket; + config.model_provider.stream_max_retries = Some(0); + config.model_provider.request_max_retries = Some(0); + } + }); + let test = builder.build(&server).await?; + + test.submit_turn("first").await?; + test.submit_turn("second").await?; + + let requests = server.received_requests().await.unwrap_or_default(); + let websocket_attempts = requests + .iter() + .filter(|req| req.method == Method::GET && req.url.path().ends_with("/responses")) + .count(); + let http_attempts = requests + .iter() + .filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses")) + .count(); + + assert_eq!(websocket_attempts, 1); + assert_eq!(http_attempts, 2); + assert_eq!(response_mock.requests().len(), 2); + + Ok(()) +}