Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/codex-api/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ pub struct ResponsesApiRequest<'a> {
pub struct ResponseCreateWsRequest {
pub model: String,
pub instructions: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
pub input: Vec<ResponseItem>,
pub tools: Vec<Value>,
pub tool_choice: String,
Expand Down
8 changes: 7 additions & 1 deletion codex-rs/core/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@
"responses_websockets": {
"type": "boolean"
},
"responses_websockets_v2": {
"type": "boolean"
},
"runtime_metrics": {
"type": "boolean"
},
Expand Down Expand Up @@ -1279,6 +1282,9 @@
"responses_websockets": {
"type": "boolean"
},
"responses_websockets_v2": {
"type": "boolean"
},
"runtime_metrics": {
"type": "boolean"
},
Expand Down Expand Up @@ -1623,4 +1629,4 @@
},
"title": "ConfigToml",
"type": "object"
}
}
126 changes: 112 additions & 14 deletions codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ use reqwest::StatusCode;
use serde_json::Value;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::Error;
use tokio_tungstenite::tungstenite::Message;
Expand Down Expand Up @@ -117,6 +119,7 @@ pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str =
"x-responsesapi-include-timing-metrics";
const RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=2026-02-06";

/// Session-scoped state shared by all [`ModelClient`] clones.
///
Expand All @@ -129,6 +132,7 @@ struct ModelClientState {
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
enable_responses_websockets: bool,
enable_responses_websockets_v2: bool,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
Expand Down Expand Up @@ -238,6 +242,8 @@ pub struct ModelClientSession {
client: ModelClient,
connection: Option<ApiWebSocketConnection>,
websocket_last_items: Vec<ResponseItem>,
websocket_last_response_id: Option<String>,
websocket_last_response_id_rx: Option<oneshot::Receiver<String>>,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
Expand All @@ -264,6 +270,7 @@ impl ModelClient {
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
enable_responses_websockets: bool,
enable_responses_websockets_v2: bool,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
Expand All @@ -276,6 +283,7 @@ impl ModelClient {
session_source,
model_verbosity,
enable_responses_websockets,
enable_responses_websockets_v2,
enable_request_compression,
include_timing_metrics,
beta_features_header,
Expand All @@ -295,6 +303,8 @@ impl ModelClient {
client: self.clone(),
connection: None,
websocket_last_items: Vec::new(),
websocket_last_response_id: None,
websocket_last_response_id_rx: None,
turn_state: Arc::new(OnceLock::new()),
}
}
Expand Down Expand Up @@ -479,6 +489,10 @@ impl ModelClient {
self.state.provider.supports_websockets && self.state.enable_responses_websockets
}

fn responses_websockets_v2_enabled(&self) -> bool {
self.state.enable_responses_websockets_v2
}

/// Returns whether websocket transport has been permanently disabled for this session.
///
/// Once set by fallback activation, subsequent turns must stay on HTTP transport.
Expand Down Expand Up @@ -544,9 +558,14 @@ impl ModelClient {
headers.extend(build_conversation_headers(Some(
self.state.conversation_id.to_string(),
)));
let responses_websockets_beta_header = if self.responses_websockets_v2_enabled() {
RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE
} else {
OPENAI_BETA_RESPONSES_WEBSOCKETS
};
headers.insert(
OPENAI_BETA_HEADER,
HeaderValue::from_static(OPENAI_BETA_RESPONSES_WEBSOCKETS),
HeaderValue::from_static(responses_websockets_beta_header),
);
if self.state.include_timing_metrics {
headers.insert(
Expand Down Expand Up @@ -789,18 +808,37 @@ impl ModelClientSession {
}
}

fn prepare_websocket_request(
fn refresh_websocket_last_response_id(&mut self) {
if let Some(mut receiver) = self.websocket_last_response_id_rx.take() {
match receiver.try_recv() {
Ok(response_id) if !response_id.is_empty() => {
self.websocket_last_response_id = Some(response_id);
}
Ok(_) | Err(TryRecvError::Closed) => {
self.websocket_last_response_id = None;
}
Err(TryRecvError::Empty) => {
self.websocket_last_response_id_rx = Some(receiver);
}
}
}
}

fn websocket_previous_response_id(&mut self) -> Option<String> {
self.refresh_websocket_last_response_id();
self.websocket_last_response_id
.clone()
.filter(|id| !id.is_empty())
}

fn prepare_websocket_create_request(
&self,
model_slug: &str,
api_prompt: &ApiPrompt,
options: &ApiResponsesOptions,
input: Vec<ResponseItem>,
previous_response_id: Option<String>,
) -> ResponsesWsRequest {
if let Some(append_items) = self.get_incremental_items(&api_prompt.input) {
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
input: append_items,
});
}

let ApiResponsesOptions {
reasoning,
include,
Expand All @@ -814,7 +852,8 @@ impl ModelClientSession {
let payload = ResponseCreateWsRequest {
model: model_slug.to_string(),
instructions: api_prompt.instructions.clone(),
input: api_prompt.input.clone(),
previous_response_id,
input,
tools: api_prompt.tools.clone(),
tool_choice: "auto".to_string(),
parallel_tool_calls: api_prompt.parallel_tool_calls,
Expand All @@ -829,6 +868,43 @@ impl ModelClientSession {
ResponsesWsRequest::ResponseCreate(payload)
}

fn prepare_websocket_request(
&mut self,
model_slug: &str,
api_prompt: &ApiPrompt,
options: &ApiResponsesOptions,
) -> ResponsesWsRequest {
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
let incremental_items = self.get_incremental_items(&api_prompt.input);
if let Some(append_items) = incremental_items {
if responses_websockets_v2_enabled
&& let Some(previous_response_id) = self.websocket_previous_response_id()
{
return self.prepare_websocket_create_request(
model_slug,
api_prompt,
options,
append_items,
Some(previous_response_id),
);
}

if !responses_websockets_v2_enabled {
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
input: append_items,
});
}
}

self.prepare_websocket_create_request(
model_slug,
api_prompt,
options,
api_prompt.input.clone(),
None,
)
}

/// Returns a websocket connection for this turn, reusing preconnect when possible.
///
/// This method first tries to adopt the session-level preconnect slot, then falls back to a
Expand Down Expand Up @@ -863,6 +939,9 @@ impl ModelClientSession {

if needs_new {
self.client.clear_preconnect();
self.websocket_last_items.clear();
self.websocket_last_response_id = None;
self.websocket_last_response_id_rx = None;
let turn_state = options
.turn_state
.clone()
Expand Down Expand Up @@ -1023,9 +1102,8 @@ impl ModelClientSession {
turn_metadata_header,
compression,
);
let request = self.prepare_websocket_request(&model_info.slug, &api_prompt, &options);

let connection = match self
match self
.websocket_connection(
otel_manager,
client_setup.api_provider,
Expand All @@ -1035,21 +1113,41 @@ impl ModelClientSession {
)
.await
{
Ok(connection) => connection,
Ok(_) => {}
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
handle_unauthorized(unauthorized_transport, &mut auth_recovery).await?;
continue;
}
Err(err) => return Err(map_api_error(err)),
};
}

let request = self.prepare_websocket_request(&model_info.slug, &api_prompt, &options);

let stream_result = connection
let stream_result = self
.connection
.as_ref()
.ok_or_else(|| {
map_api_error(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
})?
.stream_request(request)
.await
.map_err(map_api_error)?;
self.websocket_last_items = api_prompt.input.clone();
let (last_response_id_sender, last_response_id_receiver) = oneshot::channel();
self.websocket_last_response_id_rx = Some(last_response_id_receiver);
let mut last_response_id_sender = Some(last_response_id_sender);
let stream_result = stream_result.inspect(move |event| {
if let Ok(ResponseEvent::Completed { response_id, .. }) = event
&& !response_id.is_empty()
&& let Some(sender) = last_response_id_sender.take()
{
let _ = sender.send(response_id.clone());
}
});

return Ok(map_response_stream(stream_result, otel_manager.clone()));
}
Expand Down
12 changes: 9 additions & 3 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,9 @@ impl Session {
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::ResponsesWebsockets),
config.features.enabled(Feature::ResponsesWebsockets)
|| config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Self::build_model_client_beta_features_header(config.as_ref()),
Expand Down Expand Up @@ -5866,7 +5868,9 @@ mod tests {
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::ResponsesWebsockets),
config.features.enabled(Feature::ResponsesWebsockets)
|| config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
Expand Down Expand Up @@ -5996,7 +6000,9 @@ mod tests {
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::ResponsesWebsockets),
config.features.enabled(Feature::ResponsesWebsockets)
|| config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
Expand Down
36 changes: 19 additions & 17 deletions codex-rs/core/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2685,25 +2685,27 @@ profile = "project"
}

#[test]
fn responses_websockets_feature_does_not_change_wire_api() -> std::io::Result<()> {
let codex_home = TempDir::new()?;
let mut entries = BTreeMap::new();
entries.insert("responses_websockets".to_string(), true);
let cfg = ConfigToml {
features: Some(crate::features::FeaturesToml { entries }),
..Default::default()
};
fn responses_websocket_features_do_not_change_wire_api() -> std::io::Result<()> {
for feature_key in ["responses_websockets", "responses_websockets_v2"] {
let codex_home = TempDir::new()?;
let mut entries = BTreeMap::new();
entries.insert(feature_key.to_string(), true);
let cfg = ConfigToml {
features: Some(crate::features::FeaturesToml { entries }),
..Default::default()
};

let config = Config::load_from_base_config_with_overrides(
cfg,
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)?;
let config = Config::load_from_base_config_with_overrides(
cfg,
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)?;

assert_eq!(
config.model_provider.wire_api,
crate::model_provider_info::WireApi::Responses
);
assert_eq!(
config.model_provider.wire_api,
crate::model_provider_info::WireApi::Responses
);
}

Ok(())
}
Expand Down
8 changes: 8 additions & 0 deletions codex-rs/core/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ pub enum Feature {
Personality,
/// Use the Responses API WebSocket transport for OpenAI by default.
ResponsesWebsockets,
/// Enable Responses API websocket v2 mode.
ResponsesWebsocketsV2,
}

impl Feature {
Expand Down Expand Up @@ -569,6 +571,12 @@ pub const FEATURES: &[FeatureSpec] = &[
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::ResponsesWebsocketsV2,
key: "responses_websockets_v2",
stage: Stage::UnderDevelopment,
default_enabled: false,
},
];

/// Push a warning event if any under-development features are enabled.
Expand Down
Loading
Loading